45 #ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_ 46 #define TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_ 48 #include <dmlc/common.h> 49 #include <dmlc/json.h> 56 namespace auto_scheduler {
130 v->Visit(
"name", &name);
131 v->Visit(
"range", &range);
132 v->Visit(
"iter_kind", &iter_kind);
133 v->Visit(
"annotation", &annotation);
136 static constexpr
const char* _type_key =
"auto_scheduler.Iterator";
155 const std::vector<Iterator>* orig_iters =
nullptr);
173 virtual void WriteToRecord(dmlc::JSONWriter* writer)
const = 0;
175 static constexpr
const char* _type_key =
"auto_scheduler.Step";
259 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
273 void ApplyToSchedule(
Array<te::Stage>* stages, StageToAxesMap* stage_to_axes)
const;
283 static constexpr
const char* record_prefix_str =
"AN";
285 static constexpr
const char* _type_key =
"auto_scheduler.AnnotationStep";
319 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
346 static constexpr
const char* record_prefix_str =
"FU";
348 static constexpr
const char* _type_key =
"auto_scheduler.FuseStep";
370 explicit FuseStep(dmlc::JSONReader* reader);
383 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
389 void ApplyToState(
State* state)
const;
396 void ApplyToSchedule(
Array<te::Stage>* stages, StageToAxesMap* stage_to_axes)
const;
406 static constexpr
const char* record_prefix_str =
"PR";
408 static constexpr
const char* _type_key =
"auto_scheduler.PragmaStep";
431 explicit PragmaStep(dmlc::JSONReader* reader);
445 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
451 void ApplyToState(
State* state)
const;
458 void ApplyToSchedule(
Array<te::Stage>* stages, StageToAxesMap* stage_to_axes)
const;
468 static constexpr
const char* record_prefix_str =
"RE";
470 static constexpr
const char* _type_key =
"auto_scheduler.ReorderStep";
515 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
533 StageToAxesMap* stage_to_axes)
const;
543 static constexpr
const char* record_prefix_str =
"SP";
545 static constexpr
const char* _type_key =
"auto_scheduler.SplitStep";
571 explicit SplitStep(dmlc::JSONReader* reader);
587 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
623 static constexpr
const char* record_prefix_str =
"FSP";
625 static constexpr
const char* _type_key =
"auto_scheduler.FollowSplitStep";
642 FollowSplitStep(
int stage_id,
int iter_id,
int src_step_id,
int n_split);
668 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
704 static constexpr
const char* record_prefix_str =
"FFSP";
706 static constexpr
const char* _type_key =
"auto_scheduler.FollowFusedSplitStep";
725 bool factor_or_nparts);
747 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
753 void ApplyToState(
State* state)
const;
760 void ApplyToSchedule(
Array<te::Stage>* stages, StageToAxesMap* stage_to_axes)
const;
770 static constexpr
const char* record_prefix_str =
"SA";
772 static constexpr
const char* _type_key =
"auto_scheduler.StorageAlignStep";
811 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
821 void ApplyToState(
State* state)
const;
828 void ApplyToSchedule(
Array<te::Stage>* stages, StageToAxesMap* stage_to_axes)
const;
838 static constexpr
const char* record_prefix_str =
"CA";
840 static constexpr
const char* _type_key =
"auto_scheduler.ComputeAtStep";
856 ComputeAtStep(
int stage_id,
int target_stage_id,
int target_iter_id);
871 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
877 void ApplyToState(
State* state)
const;
885 void ApplyToSchedule(
Array<te::Stage>* stages, StageToAxesMap* stage_to_axes)
const;
895 static constexpr
const char* record_prefix_str =
"CI";
897 static constexpr
const char* _type_key =
"auto_scheduler.ComputeInlineStep";
926 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
936 void ApplyToState(
State* state)
const;
944 void ApplyToSchedule(
Array<te::Stage>* stages, StageToAxesMap* stage_to_axes)
const;
954 static constexpr
const char* record_prefix_str =
"CR";
956 static constexpr
const char* _type_key =
"auto_scheduler.ComputeRootStep";
996 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
1026 static constexpr
const char* record_prefix_str =
"CHR";
1028 static constexpr
const char* _type_key =
"auto_scheduler.CacheReadStep";
1067 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
1097 static constexpr
const char* record_prefix_str =
"CHW";
1099 static constexpr
const char* _type_key =
"auto_scheduler.CacheWriteStep";
1134 void WriteToRecord(dmlc::JSONWriter* writer)
const final;
1164 static constexpr
const char* record_prefix_str =
"RF";
1166 static constexpr
const char* _type_key =
"auto_scheduler.RfactorStep";
1182 RfactorStep(
int stage_id,
int iter_id,
int factor_iter_id);
1197 #endif // TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_ String-aware ObjectRef hash functor.
Definition: base.h:50
Managed reference to ComputeRootStepNode.
Definition: transform_step.h:964
Optional< PrimExpr > extent
The extent length of the axis to split.
Definition: transform_step.h:506
void StepApplyToSchedule(const Step &step, Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule, const Array< Step > &transform_steps)
Apply a general step to tvm.schedule with runtime dynamic dispatching.
Managed reference to SplitStepNode.
Definition: transform_step.h:553
Managed reference to ComputeDAGNode.
Definition: compute_dag.h:219
Global schedule container For operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:318
Definitions and helper macros for IR/AST nodes.
IteratorAnnotation annotation
The annotation type of this iterator.
Definition: transform_step.h:125
Pragma step that corresponds to te::Stage::pragma.
Definition: transform_step.h:376
int iter_id
The id of the iter to be split.
Definition: transform_step.h:581
int factor_iter_id
The position where the new iterator is placed.
Definition: transform_step.h:1132
void StepApplyToState(const Step &step, State *state, const ComputeDAG &dag)
Apply a general step to a State with runtime dynamic dispatching.
Array< Integer > after_ids
The iterator ids after reorder. This array should specify the order of all iterators.
Definition: transform_step.h:443
Fuse step that corresponds to te::Stage::fuse.
Definition: transform_step.h:314
Managed reference to AnnotationStepNode.
Definition: transform_step.h:293
Cache read step that corresponds to te::Schedule::cache_read.
Definition: transform_step.h:989
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Map< tvm::te::Stage, Array< tir::IterVar >, ObjectHash, ObjectEqual > StageToAxesMap
Definition: transform_step.h:58
This iterator has been bind to threadIdx.x.
int iter_id
The index of the iterator to be factored.
Definition: transform_step.h:1130
Managed reference to CacheReadStepNode.
Definition: transform_step.h:1036
Managed reference to ComputeAtStepNode.
Definition: transform_step.h:848
int iter_id
The id of the iter to split.
Definition: transform_step.h:660
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:308
Managed reference to StateNode.
Definition: loop_state.h:272
Cache write step that corresponds to te::Schedule::cache_write.
Definition: transform_step.h:1062
Stage, contains scheduling for a stage of computation.
Definition: schedule.h:58
String name
The name of this iterator.
Definition: transform_step.h:119
IteratorKind
The type of an iterator.
Definition: transform_step.h:68
This iterator has been bind to threadIdx.y.
Managed reference to ComputeInlineStepNode.
Definition: transform_step.h:905
Compute at step that corresponds to te::Stage::compute_at.
Definition: transform_step.h:804
This iterator has been bind to blockIdx.x.
base class of all object containers.
Definition: object.h:167
String StepPrintAsPythonAPI(const Step &step, Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule, const Array< Step > &transform_steps)
Print a general step as equivalent python schedule API with runtime dynamic dispatching.
String scope_name
The scope name of the newly added read stage. (e.g., local, shared, global)
Definition: transform_step.h:992
IteratorKind iter_kind
The iterator type of this iterator.
Definition: transform_step.h:123
int iter_id
The index of the iterator to add pragma.
Definition: transform_step.h:379
bool factor_or_nparts
If this is true, use factor. Otherwise, use nparts.
Definition: transform_step.h:666
int iter_id
The iterator to be aligned.
Definition: transform_step.h:741
Array< Optional< Integer > > lengths
The split factors.
Definition: transform_step.h:508
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
int level
Use the length in this split level.
Definition: transform_step.h:664
Similar to SplitStepNode, but uses split factors from another step (i.e. Follow another split step) ...
Definition: transform_step.h:578
String scope_name
The scope name of the newly added compute stage. (e.g. local, shared, global)
Definition: transform_step.h:1065
Range constainer.
Definition: expr.h:715
Similar to FollowSplitStep, but uses split factors from multiple steps.
Definition: transform_step.h:657
An iterator of a for-loop Similar to tvm::IterVar in include/tvm/tir/expr.h
Definition: transform_step.h:116
IteratorAnnotation
The type of an iterator's annotation.
Definition: transform_step.h:80
std::vector< Iterator > orig_iters
Definition: transform_step.h:127
Fused spatial and reduction iterator.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Reorder step that corresponds to te::Stage::reorder.
Definition: transform_step.h:437
void VisitAttrs(tvm::AttrVisitor *v)
Definition: transform_step.h:129
IteratorAnnotation annotation
The annotation type of this step.
Definition: transform_step.h:257
bool inner_to_outer
If true, the lengths denote the lengths of iterators from inner level to outer level.
Definition: transform_step.h:513
This iterator has been bind to threadIdx.y.
void UpdateStageToAxesMap(const te::Stage &stage, StageToAxesMap *stage_to_axes)
Update the current stage IterVar information to StageToAxesMap.
The base class of transformation steps. Each step has its corresponding tvm.te schedule primitives...
Definition: transform_step.h:164
Reference to string objects.
Definition: string.h:98
Managed reference to CacheWriteStepNode.
Definition: transform_step.h:1107
String pragma_type
The pragma string.
Definition: transform_step.h:381
Array< Integer > fused_ids
The ids of iterators to fuse.
Definition: transform_step.h:317
This iterator has been vectorized.
String-aware ObjectRef equal functor.
Definition: base.h:40
This iterator has been paralleld.
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Annotation step that corresponds to vectorize, parallel, unroll and thread binding. (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::vectorize, te::Stage::bind)
Definition: transform_step.h:252
Compute root step that corresponds to te::Stage::compute_root.
Definition: transform_step.h:924
const char * IteratorAnnotationString[]
int iter_id
The id of the iter to split.
Definition: transform_step.h:504
This iterator has no annotation.
Array< Integer > reader_stage_ids
The indices of read stages.
Definition: transform_step.h:994
int offset
The offset in the alignment specification.
Definition: transform_step.h:745
Base class of all object reference.
Definition: object.h:511
int n_split
The number of split level.
Definition: transform_step.h:585
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
int target_stage_id
The index of stage that this step will compute at to.
Definition: transform_step.h:807
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
Managed reference to FollowFusedSplitStepNode.
Definition: transform_step.h:714
This iterator has been bind to vthread.
Reduction factor step that corresponds to te::Schedule::rfactor.
Definition: transform_step.h:1127
Managed reference to RfactorStepNode.
Definition: transform_step.h:1174
int target_iter_id
The index of iterator in target stage that this step will compute at to.
Definition: transform_step.h:809
Step StepReadFromRecord(dmlc::JSONReader *reader)
Read a step record from JSONReader and create the corresponding step.
int iter_id
The index of the iterator to add annotation.
Definition: transform_step.h:255
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
Managed reference to IteratorNode.
Definition: transform_step.h:144
Split step that corresponds to te::Stage::split with additional support of multiple-level of factors...
Definition: transform_step.h:501
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
int src_step_id
The index of the split step to be followed in the history.
Definition: transform_step.h:583
int factor
The factor in alignment specification.
Definition: transform_step.h:743
Managed reference to ReorderStepNode.
Definition: transform_step.h:478
Managed reference to FuseStepNode.
Definition: transform_step.h:356
Compute inline step that corresponds to te::Stage::compute_inline.
Definition: transform_step.h:869
Storage align step that corresponds to te::Stage::storage_align.
Definition: transform_step.h:738
This iterator has been mapped with a tensorize intrinsic.
Special iterator. (e.g. virtual root iterator)
This iterator has been unrolled.
Managed reference to StorageAlignStepNode.
Definition: transform_step.h:780
Managed reference to StepNode.
Definition: transform_step.h:183
Managed reference to FollowSplitStepNode.
Definition: transform_step.h:633
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:648
Range range
The range of this iterator.
Definition: transform_step.h:121
This iterator has been bind to blockIdx.y.
Array< Integer > src_step_ids
The indices of the split steps to be followed in the history.
Definition: transform_step.h:662
Managed reference to PragmaStepNode.
Definition: transform_step.h:416
int stage_id
The index of the stage.
Definition: transform_step.h:167
This iterator has been bind to blockIdx.y.