tvm
transform_step.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
45 #ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
46 #define TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
47 
48 #include <dmlc/common.h>
49 #include <dmlc/json.h>
50 #include <tvm/node/node.h>
51 #include <tvm/te/schedule.h>
52 
53 #include <vector>
54 
55 namespace tvm {
56 namespace auto_scheduler {
57 
59 
/*!
 * \brief Presumably refreshes the entry for `stage` in the stage-to-axes mapping.
 * \param stage The te::Stage whose axes are recorded.
 * \param stage_to_axes The mapping to update, passed as a mutable pointer.
 * NOTE(review): this listing elides both this function's doc comment and the
 * `StageToAxesMap` alias it uses. The update semantics above are inferred from the
 * name only; confirm them in transform_step.cc.
 */
65 void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes);
66 
/*!
 * \brief The type of an iterator. It is stored in `IteratorNode::iter_kind`
 * ("The iterator type of this iterator").
 * NOTE(review): this listing elides the per-value doc comments. The enumerator
 * names suggest spatial, reduction, mixed and special loops; confirm against the header.
 */
68 enum class IteratorKind : int {
70  kSpatial = 0,
72  kReduction = 1,
74  kMixed = 2,
76  kSpecial = 3
77 };
78 
/*!
 * \brief The annotation type of an iterator. It is stored in `IteratorNode::annotation`
 * and applied by AnnotationStep, the step that corresponds to vectorize, parallel,
 * unroll and thread binding.
 * NOTE(review): this listing elides the per-value doc comments.
 */
80 enum class IteratorAnnotation : int {
82  kNone = 0,
84  kUnroll = 1,
86  kVectorize = 2,
88  kParallel = 3,
90  kVThread = 4,
92  kBlockX = 5,
94  kThreadX = 6,
96  kBlockY = 7,
98  kThreadY = 8,
100  kBlockZ = 9,
102  kThreadZ = 10,
104  kTensorize = 11
105 };
106 
/*!
 * \brief Presumably the printable names of the IteratorAnnotation values.
 * It is declared `extern` here and defined in another translation unit.
 * NOTE(review): presumably indexed by the integer value of IteratorAnnotation;
 * confirm that its ordering matches the enum.
 */
107 extern const char* IteratorAnnotationString[];
108 
109 // forward declaration
110 class Iterator;
111 
/*!
 * \brief An iterator of a for-loop.
 * Similar to tvm::IterVar in include/tvm/tir/expr.h.
 */
116 class IteratorNode : public Object {
117  public:
  // [elided in this listing, defined at source lines 119-125]
  //   String name;                    The name of this iterator.
  //   Range range;                    The range of this iterator.
  //   IteratorKind iter_kind;         The iterator type of this iterator.
  //   IteratorAnnotation annotation;  The annotation type of this iterator.
  // NOTE(review): orig_iters has no description on this page. It is a plain
  // std::vector and is NOT visited by VisitAttrs below, so it is not exposed
  // through AttrVisitor-based reflection.
127  std::vector<Iterator> orig_iters;
128 
  // [elided, source line 129] void VisitAttrs(tvm::AttrVisitor* v) {
  // Exposes name, range, iter_kind and annotation (but not orig_iters) to the visitor.
130  v->Visit("name", &name);
131  v->Visit("range", &range);
132  v->Visit("iter_kind", &iter_kind);
133  v->Visit("annotation", &annotation);
134  }
135 
136  static constexpr const char* _type_key = "auto_scheduler.Iterator";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object)
138 };
139 
/*!
 * \brief Managed reference to IteratorNode.
 */
144 class Iterator : public ObjectRef {
145  public:
  /*!
   * \brief The constructor.
   * The parameters presumably initialize the same-named IteratorNode fields.
   * \param name The name of this iterator.
   * \param range The range of this iterator.
   * \param iter_kind The iterator type of this iterator.
   * \param annotation The annotation type of this iterator.
   * \param orig_iters Optional pointer to a vector of iterators; defaults to nullptr.
   */
154  Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation,
155  const std::vector<Iterator>* orig_iters = nullptr);
156 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode)
158 };
159 
/*!
 * \brief The base class of transformation steps. Each step has its corresponding
 * tvm.te schedule primitives.
 */
164 class StepNode : public Object {
165  public:
  /*! \brief The index of the stage. */
167  int stage_id;
168 
  /*!
   * \brief Serialize the current step record to JSONWriter.
   * \param writer The output JSONWriter.
   * Pure virtual; every concrete step below overrides it as `final`.
   */
173  virtual void WriteToRecord(dmlc::JSONWriter* writer) const = 0;
174 
175  static constexpr const char* _type_key = "auto_scheduler.Step";
  // [elided in this listing] TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object)
177 };
178 
/*!
 * \brief Managed reference to StepNode.
 * NOTE(review): this listing elides the class body. Per the definitions index it
 * contains:
 *   - StepNode* CopyOnWrite(): "CopyOnWrite function for Step. This works almost the
 *     same as a normal ObjectRef.CopyOnWrite(), ..." (the description is truncated
 *     in the page).
 *   - TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode)
 */
183 class Step : public ObjectRef {
184  public:
200 
202 };
203 
204 // Forward declaration
// State: managed reference to StateNode (defined in loop_state.h).
// ComputeDAG: managed reference to ComputeDAGNode (defined in compute_dag.h).
205 class State;
206 class ComputeDAG;
207 
/*!
 * \brief Read a step record from a JSONReader and create the corresponding step.
 * \param reader The input JSONReader.
 * \return The deserialized Step.
 * NOTE(review): presumably dispatches on each step node's `record_prefix_str`
 * ("AN", "FU", "SP", ...) to the matching `XxxStep(dmlc::JSONReader*)` constructor.
 * Confirm in transform_step.cc.
 */
212 Step StepReadFromRecord(dmlc::JSONReader* reader);
213 
/*!
 * \brief Apply a generic step to a State.
 * \param step The step to apply.
 * \param state The State to apply the step to, passed as a mutable pointer.
 * \param dag The ComputeDAG. Only the cache_read, cache_write and rfactor steps
 *        take it in their ApplyToState.
 * NOTE(review): presumably dispatches to the concrete step node's ApplyToState.
 */
220 void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
221 
/*!
 * \brief Apply a generic step to tvm.schedule.
 * \param step The step to apply.
 * \param stages The list of current stages.
 * \param stage_to_axes The stage-to-axes mapping.
 * \param schedule The te::Schedule. It is needed by the cache_read, cache_write and
 *        rfactor steps, whose ApplyToSchedule takes it.
 * \param transform_steps The step history. It is needed by FollowSplitStep and
 *        FollowFusedSplitStep, whose ApplyToSchedule takes it.
 */
230 void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
231  te::Schedule* schedule, const Array<Step>& transform_steps);
232 
// NOTE(review): this listing elides the first line of this declaration (source
// line 242) and its doc comment. Its parameters mirror StepApplyToSchedule. It is
// presumably the generic dispatcher to each step node's PrintAsPythonAPI
// ("Print the current step as equivalent python schedule API"). Confirm in the header.
243  StageToAxesMap* stage_to_axes, te::Schedule* schedule,
244  const Array<Step>& transform_steps);
245 
246 /********** Steps working on single stage **********/
247 
/*!
 * \brief Annotation step that corresponds to vectorize, parallel, unroll and
 * thread binding.
 */
252 class AnnotationStepNode : public StepNode {
253  public:
  /*! \brief The index of the iterator to annotate. */
255  int iter_id;
  // [elided in this listing, source line 257]
  //   IteratorAnnotation annotation;  The annotation type of this step.
258 
  /*! \brief Serialize the current step record to JSONWriter. */
259  void WriteToRecord(dmlc::JSONWriter* writer) const final;
260 
  /*! \brief Apply the current step to State. \return The resulting Iterator. */
266  Iterator ApplyToState(State* state) const;
267 
  /*! \brief Apply the current step to tvm.schedule. */
273  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
274 
  /*! \brief Print the current step as equivalent python schedule API. */
281  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
282 
283  static constexpr const char* record_prefix_str = "AN";
284 
285  static constexpr const char* _type_key = "auto_scheduler.AnnotationStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, StepNode)
287 };
288 
/*!
 * \brief Managed reference to AnnotationStepNode.
 */
293 class AnnotationStep : public Step {
294  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param iter_id The index of the iterator to annotate.
   * \param ann The annotation type of this step.
   */
301  AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann);
302 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
308  explicit AnnotationStep(dmlc::JSONReader* reader);
309 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode)
311 };
312 
/*! \brief Fuse step that corresponds to te::Stage::fuse. */
314 class FuseStepNode : public StepNode {
315  public:
  // [elided in this listing, source line 317]
  //   Array<Integer> fused_ids;  The ids of iterators to fuse.
318 
  /*! \brief Serialize the current step record to JSONWriter. */
319  void WriteToRecord(dmlc::JSONWriter* writer) const final;
320 
  /*! \brief Apply the current step to State. \return The resulting Iterator. */
328  Iterator ApplyToState(State* state) const;
329 
  // [elided in this listing]
  //   tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
  //   Apply the current step to tvm.schedule.
337 
  /*! \brief Print the current step as equivalent python schedule API. */
344  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
345 
346  static constexpr const char* record_prefix_str = "FU";
347 
348  static constexpr const char* _type_key = "auto_scheduler.FuseStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, StepNode)
350 };
351 
/*!
 * \brief Managed reference to FuseStepNode.
 */
356 class FuseStep : public Step {
357  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param fused_ids The ids of iterators to fuse.
   */
363  FuseStep(int stage_id, const Array<Integer>& fused_ids);
364 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
370  explicit FuseStep(dmlc::JSONReader* reader);
371 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode)
373 };
374 
/*! \brief Pragma step that corresponds to te::Stage::pragma. */
376 class PragmaStepNode : public StepNode {
377  public:
  /*! \brief The index of the iterator to which the pragma is added. */
379  int iter_id;
  // [elided in this listing, source line 381]
  //   String pragma_type;  The pragma string.
382 
  /*! \brief Serialize the current step record to JSONWriter. */
383  void WriteToRecord(dmlc::JSONWriter* writer) const final;
384 
  /*! \brief Apply the current step to State. */
389  void ApplyToState(State* state) const;
390 
  /*! \brief Apply the current step to tvm.schedule. */
396  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
397 
  /*! \brief Print the current step as equivalent python schedule API. */
404  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
405 
406  static constexpr const char* record_prefix_str = "PR";
407 
408  static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode)
410 };
411 
/*!
 * \brief Managed reference to PragmaStepNode.
 */
416 class PragmaStep : public Step {
417  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param iter_id The index of the iterator to which the pragma is added.
   * \param pragma_type The pragma string.
   */
424  PragmaStep(int stage_id, int iter_id, String pragma_type);
425 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
431  explicit PragmaStep(dmlc::JSONReader* reader);
432 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode)
434 };
435 
/*! \brief Reorder step that corresponds to te::Stage::reorder. */
437 class ReorderStepNode : public StepNode {
438  public:
  // [elided in this listing, source line 443]
  //   Array<Integer> after_ids;  The iterator ids after reorder. This array should
  //                              specify the order of all iterators.
444 
  /*! \brief Serialize the current step record to JSONWriter. */
445  void WriteToRecord(dmlc::JSONWriter* writer) const final;
446 
  /*! \brief Apply the current step to State. */
451  void ApplyToState(State* state) const;
452 
  /*! \brief Apply the current step to tvm.schedule. */
458  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
459 
  /*! \brief Print the current step as equivalent python schedule API. */
466  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
467 
468  static constexpr const char* record_prefix_str = "RE";
469 
470  static constexpr const char* _type_key = "auto_scheduler.ReorderStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, StepNode)
472 };
473 
/*!
 * \brief Managed reference to ReorderStepNode.
 */
478 class ReorderStep : public Step {
479  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param after_ids The iterator ids after reorder; should specify the order of
   *        all iterators.
   */
485  ReorderStep(int stage_id, const Array<Integer>& after_ids);
486 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
492  explicit ReorderStep(dmlc::JSONReader* reader);
493 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode)
495 };
496 
/*!
 * \brief Split step that corresponds to te::Stage::split, with additional support
 * for multiple levels of factors.
 */
501 class SplitStepNode : public StepNode {
502  public:
  /*! \brief The id of the iterator to split. */
504  int iter_id;
  // [elided in this listing, source lines 506-513]
  //   Optional<PrimExpr> extent;         The extent length of the axis to split.
  //   Array<Optional<Integer>> lengths;  The split factors.
  //   bool inner_to_outer;               If true, the lengths denote the lengths of
  //                                      iterators from inner level to outer level.
514 
  /*! \brief Serialize the current step record to JSONWriter. */
515  void WriteToRecord(dmlc::JSONWriter* writer) const final;
516 
  // [elided in this listing]
  //   Array<Iterator> ApplyToState(State* state) const;  Apply the current step to State.
525 
  // [elided in this listing: first line of the next declaration]
  //   Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages,
  //   Apply the current step to tvm.schedule.
533  StageToAxesMap* stage_to_axes) const;
534 
  /*! \brief Print the current step as equivalent python schedule API. */
541  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
542 
543  static constexpr const char* record_prefix_str = "SP";
544 
545  static constexpr const char* _type_key = "auto_scheduler.SplitStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, StepNode)
547 };
548 
/*!
 * \brief Managed reference to SplitStepNode.
 */
553 class SplitStep : public Step {
554  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param iter_id The id of the iterator to split.
   * \param extent The extent length of the axis to split (optional).
   * \param lengths The split factors.
   * \param inner_to_outer If true, the lengths denote the lengths of iterators from
   *        inner level to outer level.
   */
563  SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
564  const Array<Optional<Integer>>& lengths, bool inner_to_outer);
565 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
571  explicit SplitStep(dmlc::JSONReader* reader);
572 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode)
574 };
575 
/*!
 * \brief Similar to SplitStepNode, but uses split factors from another step
 * (i.e., follows another split step).
 * NOTE(review): this listing elides the class header (source line 578):
 *   class FollowSplitStepNode : public StepNode {
 */
579  public:
  /*! \brief The id of the iterator to be split. */
581  int iter_id;
  // [elided in this listing, source line 583]
  //   int src_step_id;  The index of the split step to be followed in the history.
  /*! \brief The number of split levels. */
585  int n_split;
586 
  /*! \brief Serialize the current step record to JSONWriter. */
587  void WriteToRecord(dmlc::JSONWriter* writer) const final;
588 
  // [elided in this listing]
  //   Array<Optional<Integer>> ExtractSplitLengths(const Array<Step>& transform_steps) const;
  //   Extract split lengths.
595 
  // [elided in this listing]
  //   Array<Iterator> ApplyToState(State* state) const;  Apply the current step to State.
602 
  // [elided in this listing: first line of the next declaration]
  //   Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Apply the current step to tvm.schedule.
611  const Array<Step>& transform_steps) const;
612 
  // [elided in this listing: first line of the next declaration]
  //   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Print the current step as equivalent python schedule API.
621  const Array<Step>& transform_steps) const;
622 
623  static constexpr const char* record_prefix_str = "FSP";
624 
625  static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, StepNode)
627 };
628 
/*!
 * \brief Managed reference to FollowSplitStepNode.
 */
633 class FollowSplitStep : public Step {
634  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param iter_id The id of the iterator to be split.
   * \param src_step_id The index of the split step to be followed in the history.
   * \param n_split The number of split levels.
   */
642  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
643 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
649  explicit FollowSplitStep(dmlc::JSONReader* reader);
650 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode)
652 };
653 
/*!
 * \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
 * NOTE(review): this listing elides the class header (source line 657):
 *   class FollowFusedSplitStepNode : public StepNode {
 */
658  public:
  /*! \brief The id of the iterator to split. */
660  int iter_id;
  // [elided in this listing, source line 662]
  //   Array<Integer> src_step_ids;  The indices of the split steps to be followed in the history.
  /*! \brief The split level whose length is used. */
664  int level;
  // [elided in this listing, source line 666]
  //   bool factor_or_nparts;  If this is true, use factor. Otherwise, use nparts.
667 
  /*! \brief Serialize the current step record to JSONWriter. */
668  void WriteToRecord(dmlc::JSONWriter* writer) const final;
669 
  /*! \brief Extract split length. */
675  Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
676 
  // [elided in this listing]
  //   Array<Iterator> ApplyToState(State* state) const;  Apply the current step to State.
683 
  // [elided in this listing: first line of the next declaration]
  //   Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Apply the current step to tvm.schedule.
692  const Array<Step>& transform_steps) const;
693 
  // [elided in this listing: first line of the next declaration]
  //   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Print the current step as equivalent python schedule API.
702  const Array<Step>& transform_steps) const;
703 
704  static constexpr const char* record_prefix_str = "FFSP";
705 
706  static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, StepNode)
708 };
709 
/*!
 * \brief Managed reference to FollowFusedSplitStepNode.
 */
714 class FollowFusedSplitStep : public Step {
715  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param iter_id The id of the iterator to split.
   * \param src_step_ids The indices of the split steps to be followed in the history.
   * \param level The split level whose length is used.
   * \param factor_or_nparts If true, use factor; otherwise, use nparts.
   */
724  FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level,
725  bool factor_or_nparts);
726 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
732  explicit FollowFusedSplitStep(dmlc::JSONReader* reader);
733 
  // [elided in this listing]
  //   TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode)
735 };
736 
/*!
 * \brief Storage align step that corresponds to te::Stage::storage_align.
 * NOTE(review): this listing elides the class header (source line 738):
 *   class StorageAlignStepNode : public StepNode {
 */
739  public:
  /*! \brief The index of the iterator to be aligned. */
741  int iter_id;
  /*! \brief The factor in the alignment specification. */
743  int factor;
  /*! \brief The offset in the alignment specification. */
745  int offset;
746 
  /*! \brief Serialize the current step record to JSONWriter. */
747  void WriteToRecord(dmlc::JSONWriter* writer) const final;
748 
  /*! \brief Apply the current step to State. */
753  void ApplyToState(State* state) const;
754 
  /*! \brief Apply the current step to tvm.schedule. */
760  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
761 
  /*! \brief Print the current step as equivalent python schedule API. */
768  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
769 
770  static constexpr const char* record_prefix_str = "SA";
771 
772  static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode)
774 };
775 
/*!
 * \brief Managed reference to StorageAlignStepNode.
 */
780 class StorageAlignStep : public Step {
781  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param iter_id The index of the iterator to be aligned.
   * \param factor The factor in the alignment specification.
   * \param offset The offset in the alignment specification.
   */
789  StorageAlignStep(int stage_id, int iter_id, int factor, int offset);
790 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
796  explicit StorageAlignStep(dmlc::JSONReader* reader);
797 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode)
799 };
800 
801 /********** Steps working on multiple stages **********/
802 
/*! \brief Compute at step that corresponds to te::Stage::compute_at. */
804 class ComputeAtStepNode : public StepNode {
805  public:
  // [elided in this listing, source lines 807-809]
  //   int target_stage_id;  The index of the target stage at which this step computes.
  //   int target_iter_id;   The index of the iterator in the target stage at which
  //                         this step computes.
810 
  /*! \brief Serialize the current step record to JSONWriter. */
811  void WriteToRecord(dmlc::JSONWriter* writer) const final;
812 
  /*! \brief Apply the current step to State. */
821  void ApplyToState(State* state) const;
822 
  /*! \brief Apply the current step to tvm.schedule. */
828  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
829 
  /*! \brief Print the current step as equivalent python schedule API. */
836  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
837 
838  static constexpr const char* record_prefix_str = "CA";
839 
840  static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, StepNode)
842 };
843 
/*!
 * \brief Managed reference to ComputeAtStepNode.
 */
848 class ComputeAtStep : public Step {
849  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param target_stage_id The index of the target stage at which this step computes.
   * \param target_iter_id The index of the iterator in the target stage at which this
   *        step computes.
   */
856  ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id);
857 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
863  explicit ComputeAtStep(dmlc::JSONReader* reader);
864 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode)
866 };
867 
/*!
 * \brief Compute inline step that corresponds to te::Stage::compute_inline.
 * NOTE(review): this listing elides the class header (source line 869):
 *   class ComputeInlineStepNode : public StepNode {
 * It declares no data members beyond the inherited stage_id.
 */
870  public:
  /*! \brief Serialize the current step record to JSONWriter. */
871  void WriteToRecord(dmlc::JSONWriter* writer) const final;
872 
  /*! \brief Apply the current step to State. */
877  void ApplyToState(State* state) const;
878 
  /*! \brief Apply the current step to tvm.schedule. */
884  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
885 
  /*! \brief Print the current step as equivalent python schedule API. */
892  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
893 
894  static constexpr const char* record_prefix_str = "CI";
895 
896  static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, StepNode)
898 };
899 
/*!
 * \brief Managed reference to ComputeInlineStepNode.
 */
904 class ComputeInlineStep : public Step {
905  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   */
910  explicit ComputeInlineStep(int stage_id);
911 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
917  explicit ComputeInlineStep(dmlc::JSONReader* reader);
918 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode)
920 };
921 
/*!
 * \brief Compute root step that corresponds to te::Stage::compute_root.
 * NOTE(review): this listing elides the class header (source line 923):
 *   class ComputeRootStepNode : public StepNode {
 * It declares no data members beyond the inherited stage_id.
 */
924  public:
  /*! \brief Serialize the current step record to JSONWriter. */
925  void WriteToRecord(dmlc::JSONWriter* writer) const final;
926 
  /*! \brief Apply the current step to State. */
935  void ApplyToState(State* state) const;
936 
  /*! \brief Apply the current step to tvm.schedule. */
942  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
943 
  /*! \brief Print the current step as equivalent python schedule API. */
950  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
951 
952  static constexpr const char* record_prefix_str = "CR";
953 
954  static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, StepNode)
956 };
957 
/*!
 * \brief Managed reference to ComputeRootStepNode.
 */
962 class ComputeRootStep : public Step {
963  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   */
968  explicit ComputeRootStep(int stage_id);
969 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
975  explicit ComputeRootStep(dmlc::JSONReader* reader);
976 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode)
978 };
979 
980 /********** Steps adding new stages **********/
981 
/*! \brief Cache read step that corresponds to te::Schedule::cache_read. */
987 class CacheReadStepNode : public StepNode {
988  public:
  // [elided in this listing, source lines 990-992]
  //   String scope_name;               The scope name of the newly added read stage
  //                                    (e.g., local, shared, global).
  //   Array<Integer> reader_stage_ids; The indices of read stages.
993 
  /*! \brief Serialize the current step record to JSONWriter. */
994  void WriteToRecord(dmlc::JSONWriter* writer) const final;
995 
  /*!
   * \brief Apply the current step to State.
   * NOTE(review): this page does not document the meaning of the int return value
   * (presumably the id of the newly added stage). Confirm.
   */
1002  int ApplyToState(State* state, const ComputeDAG& dag) const;
1003 
  // [elided in this listing: first line of the next declaration]
  //   te::Tensor ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Apply the current step to tvm.schedule.
1012  te::Schedule* schedule) const;
1013 
  // [elided in this listing: first line of the next declaration]
  //   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Print the current step as equivalent python schedule API.
1022  te::Schedule* schedule) const;
1023 
1024  static constexpr const char* record_prefix_str = "CHR";
1025 
1026  static constexpr const char* _type_key = "auto_scheduler.CacheReadStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, StepNode)
1028 };
1029 
/*!
 * \brief Managed reference to CacheReadStepNode.
 */
1034 class CacheReadStep : public Step {
1035  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param scope_name The scope name of the newly added read stage
   *        (e.g., local, shared, global).
   * \param reader_stage_ids The indices of read stages.
   */
1042  CacheReadStep(int stage_id, String scope_name, const Array<Integer>& reader_stage_ids);
1043 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
1049  explicit CacheReadStep(dmlc::JSONReader* reader);
1050 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode)
1052 };
1053 
/*!
 * \brief Cache write step that corresponds to te::Schedule::cache_write.
 * NOTE(review): this listing elides the class header (source line 1060):
 *   class CacheWriteStepNode : public StepNode {
 */
1061  public:
  // [elided in this listing, source line 1063]
  //   String scope_name;  The scope name of the newly added compute stage
  //                       (e.g., local, shared, global).
1064 
  /*! \brief Serialize the current step record to JSONWriter. */
1065  void WriteToRecord(dmlc::JSONWriter* writer) const final;
1066 
  /*!
   * \brief Apply the current step to State.
   * NOTE(review): this page does not document the meaning of the int return value.
   */
1073  int ApplyToState(State* state, const ComputeDAG& dag) const;
1074 
  // [elided in this listing: first line of the next declaration]
  //   Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Apply the current step to tvm.schedule.
1083  te::Schedule* schedule) const;
1084 
  // [elided in this listing: first line of the next declaration]
  //   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Print the current step as equivalent python schedule API.
1093  te::Schedule* schedule) const;
1094 
1095  static constexpr const char* record_prefix_str = "CHW";
1096 
1097  static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, StepNode)
1099 };
1100 
/*!
 * \brief Managed reference to CacheWriteStepNode.
 */
1105 class CacheWriteStep : public Step {
1106  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param scope_name The scope name of the newly added compute stage
   *        (e.g., local, shared, global).
   */
1112  CacheWriteStep(int stage_id, String scope_name);
1113 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
1119  explicit CacheWriteStep(dmlc::JSONReader* reader);
1120 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode)
1122 };
1123 
/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor. */
1125 class RfactorStepNode : public StepNode {
1126  public:
  /*! \brief The index of the iterator to be factored. */
1128  int iter_id;
  // [elided in this listing, source line 1130]
  //   int factor_iter_id;  The position where the new iterator is placed.
1131 
  /*! \brief Serialize the current step record to JSONWriter. */
1132  void WriteToRecord(dmlc::JSONWriter* writer) const final;
1133 
  /*!
   * \brief Apply the current step to State.
   * NOTE(review): this page does not document the meaning of the int return value.
   */
1140  int ApplyToState(State* state, const ComputeDAG& dag) const;
1141 
  // [elided in this listing: first line of the next declaration]
  //   Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Apply the current step to tvm.schedule.
1150  te::Schedule* schedule) const;
1151 
  // [elided in this listing: first line of the next declaration]
  //   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
  //   Print the current step as equivalent python schedule API.
1160  te::Schedule* schedule) const;
1161 
1162  static constexpr const char* record_prefix_str = "RF";
1163 
1164  static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
  // [elided in this listing] TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode)
1166 };
1167 
/*!
 * \brief Managed reference to RfactorStepNode.
 */
1172 class RfactorStep : public Step {
1173  public:
  /*!
   * \brief The constructor.
   * \param stage_id The index of the stage.
   * \param iter_id The index of the iterator to be factored.
   * \param factor_iter_id The position where the new iterator is placed.
   */
1180  RfactorStep(int stage_id, int iter_id, int factor_iter_id);
1181 
  /*!
   * \brief The constructor used to read a step record from JSONReader and create
   * the corresponding step.
   */
1187  explicit RfactorStep(dmlc::JSONReader* reader);
1188 
  // [elided in this listing] TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode)
1190 };
1191 
1192 } // namespace auto_scheduler
1193 } // namespace tvm
1194 
1195 #endif // TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
Visitor class to get the attributes of an AST/IR node. The visitor is called for each field of the node.
Definition: reflection.h:52
Range container
Definition: expr.h:725
Annotation step that corresponds to vectorize, parallel, unroll and thread binding....
Definition: transform_step.h:252
static constexpr const char * record_prefix_str
Definition: transform_step.h:283
void ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Apply the current step to tvm.schedule.
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, StepNode)
Iterator ApplyToState(State *state) const
Apply the current step to State.
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Print the current step as equivalent python schedule API.
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
IteratorAnnotation annotation
The annotation type of this step.
Definition: transform_step.h:257
int iter_id
The index of the iterator to annotate.
Definition: transform_step.h:255
static constexpr const char * _type_key
Definition: transform_step.h:285
Managed reference to AnnotationStepNode.
Definition: transform_step.h:293
AnnotationStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode)
AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann)
The constructor.
Cache read step that corresponds to te::Schedule::cache_read.
Definition: transform_step.h:987
Array< Integer > reader_stage_ids
The indices of read stages.
Definition: transform_step.h:992
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule) const
Print the current step as equivalent python schedule API.
int ApplyToState(State *state, const ComputeDAG &dag) const
Apply the current step to State.
String scope_name
The scope name of the newly added read stage (e.g., local, shared, global).
Definition: transform_step.h:990
static constexpr const char * _type_key
Definition: transform_step.h:1026
static constexpr const char * record_prefix_str
Definition: transform_step.h:1024
TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, StepNode)
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
te::Tensor ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule) const
Apply the current step to tvm.schedule.
Managed reference to CacheReadStepNode.
Definition: transform_step.h:1034
CacheReadStep(int stage_id, String scope_name, const Array< Integer > &reader_stage_ids)
The constructor.
CacheReadStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode)
Cache write step that corresponds to te::Schedule::cache_write.
Definition: transform_step.h:1060
int ApplyToState(State *state, const ComputeDAG &dag) const
Apply the current step to State.
TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, StepNode)
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule) const
Print the current step as equivalent python schedule API.
Array< te::Tensor > ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule) const
Apply the current step to tvm.schedule.
static constexpr const char * _type_key
Definition: transform_step.h:1097
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
static constexpr const char * record_prefix_str
Definition: transform_step.h:1095
String scope_name
The scope name of the newly added compute stage (e.g., local, shared, global).
Definition: transform_step.h:1063
Managed reference to CacheWriteStepNode.
Definition: transform_step.h:1105
CacheWriteStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
CacheWriteStep(int stage_id, String scope_name)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode)
Compute at step that corresponds to te::Stage::compute_at.
Definition: transform_step.h:804
void ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Apply the current step to tvm.schedule.
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, StepNode)
int target_iter_id
The index of the iterator in the target stage at which this step computes.
Definition: transform_step.h:809
static constexpr const char * _type_key
Definition: transform_step.h:840
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
static constexpr const char * record_prefix_str
Definition: transform_step.h:838
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Print the current step as equivalent python schedule API.
int target_stage_id
The index of the target stage at which this step computes.
Definition: transform_step.h:807
void ApplyToState(State *state) const
Apply the current step to State.
Managed reference to ComputeAtStepNode.
Definition: transform_step.h:848
ComputeAtStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode)
ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id)
The constructor.
Managed reference to ComputeDAGNode.
Definition: compute_dag.h:219
Compute inline step that corresponds to te::Stage::compute_inline.
Definition: transform_step.h:869
static constexpr const char * record_prefix_str
Definition: transform_step.h:894
void ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Apply the current step to tvm.schedule.
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, StepNode)
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Print the current step as equivalent python schedule API.
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
void ApplyToState(State *state) const
Apply the current step to State.
static constexpr const char * _type_key
Definition: transform_step.h:896
Managed reference to ComputeInlineStepNode.
Definition: transform_step.h:904
ComputeInlineStep(int stage_id)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode)
ComputeInlineStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
Compute root step that corresponds to te::Stage::compute_root.
Definition: transform_step.h:923
static constexpr const char * _type_key
Definition: transform_step.h:954
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, StepNode)
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
void ApplyToState(State *state) const
Apply the current step to State.
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Print the current step as equivalent python schedule API.
static constexpr const char * record_prefix_str
Definition: transform_step.h:952
void ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Apply the current step to tvm.schedule.
Managed reference to ComputeRootStepNode.
Definition: transform_step.h:962
TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode)
ComputeRootStep(int stage_id)
The constructor.
ComputeRootStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
Similar to FollowSplitStep, but uses split factors from multiple steps.
Definition: transform_step.h:657
TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, StepNode)
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, const Array< Step > &transform_steps) const
Print the current step as equivalent python schedule API.
Optional< Integer > ExtractSplitLength(const Array< Step > &transform_steps) const
Extract split length.
int level
The split level whose length is used.
Definition: transform_step.h:664
Array< tir::IterVar > ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, const Array< Step > &transform_steps) const
Apply the current step to tvm.schedule.
static constexpr const char * _type_key
Definition: transform_step.h:706
static constexpr const char * record_prefix_str
Definition: transform_step.h:704
Array< Iterator > ApplyToState(State *state) const
Apply the current step to State.
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
Array< Integer > src_step_ids
The indices of the split steps to be followed in the history.
Definition: transform_step.h:662
int iter_id
The id of the iterator to split.
Definition: transform_step.h:660
bool factor_or_nparts
If this is true, use factor. Otherwise, use nparts.
Definition: transform_step.h:666
Managed reference to FollowFusedSplitStepNode.
Definition: transform_step.h:714
FollowFusedSplitStep(int stage_id, int iter_id, const Array< Integer > &src_step_ids, int level, bool factor_or_nparts)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode)
FollowFusedSplitStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
Similar to SplitStepNode, but uses split factors from another step (i.e., follows another split step).
Definition: transform_step.h:578
int src_step_id
The index of the split step to be followed in the history.
Definition: transform_step.h:583
int iter_id
The id of the iterator to be split.
Definition: transform_step.h:581
int n_split
The number of split levels.
Definition: transform_step.h:585
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, const Array< Step > &transform_steps) const
Print the current step as equivalent python schedule API.
static constexpr const char * _type_key
Definition: transform_step.h:625
static constexpr const char * record_prefix_str
Definition: transform_step.h:623
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, StepNode)
Array< tir::IterVar > ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, const Array< Step > &transform_steps) const
Apply the current step to tvm.schedule.
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
Array< Iterator > ApplyToState(State *state) const
Apply the current step to State.
Array< Optional< Integer > > ExtractSplitLengths(const Array< Step > &transform_steps) const
Extract split lengths.
Managed reference to FollowSplitStepNode.
Definition: transform_step.h:633
FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode)
FollowSplitStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
Fuse step that corresponds to te::Stage::fuse.
Definition: transform_step.h:314
Array< Integer > fused_ids
The ids of iterators to fuse.
Definition: transform_step.h:317
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
Iterator ApplyToState(State *state) const
Apply the current step to State.
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Print the current step as equivalent python schedule API.
TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, StepNode)
tir::IterVar ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Apply the current step to tvm.schedule.
static constexpr const char * record_prefix_str
Definition: transform_step.h:346
static constexpr const char * _type_key
Definition: transform_step.h:348
Managed reference to FuseStepNode.
Definition: transform_step.h:356
FuseStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
FuseStep(int stage_id, const Array< Integer > &fused_ids)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode)
An iterator of a for-loop. Similar to tvm::IterVar in include/tvm/tir/expr.h.
Definition: transform_step.h:116
Range range
The range of this iterator.
Definition: transform_step.h:121
void VisitAttrs(tvm::AttrVisitor *v)
Definition: transform_step.h:129
IteratorAnnotation annotation
The annotation type of this iterator.
Definition: transform_step.h:125
IteratorKind iter_kind
The iterator type of this iterator.
Definition: transform_step.h:123
TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object)
static constexpr const char * _type_key
Definition: transform_step.h:136
String name
The name of this iterator.
Definition: transform_step.h:119
std::vector< Iterator > orig_iters
Definition: transform_step.h:127
Managed reference to IteratorNode.
Definition: transform_step.h:144
TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode)
Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation, const std::vector< Iterator > *orig_iters=nullptr)
The constructor.
Pragma step that corresponds to te::Stage::pragma.
Definition: transform_step.h:376
static constexpr const char * _type_key
Definition: transform_step.h:408
TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode)
void ApplyToState(State *state) const
Apply the current step to State.
String pragma_type
The pragma string.
Definition: transform_step.h:381
void ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Apply the current step to tvm.schedule.
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Print the current step as equivalent python schedule API.
int iter_id
The index of the iterator to which the pragma is added.
Definition: transform_step.h:379
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
static constexpr const char * record_prefix_str
Definition: transform_step.h:406
Managed reference to PragmaStepNode.
Definition: transform_step.h:416
TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode)
PragmaStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
PragmaStep(int stage_id, int iter_id, String pragma_type)
The constructor.
Reorder step that corresponds to te::Stage::reorder.
Definition: transform_step.h:437
static constexpr const char * record_prefix_str
Definition: transform_step.h:468
void ApplyToState(State *state) const
Apply the current step to State.
TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, StepNode)
Array< Integer > after_ids
The iterator ids after reorder. This array should specify the order of all iterators.
Definition: transform_step.h:443
void ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Apply the current step to tvm.schedule.
static constexpr const char * _type_key
Definition: transform_step.h:470
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Print the current step as equivalent python schedule API.
Managed reference to ReorderStepNode.
Definition: transform_step.h:478
TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode)
ReorderStep(int stage_id, const Array< Integer > &after_ids)
The constructor.
ReorderStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
Reduction factor step that corresponds to te::Schedule::rfactor.
Definition: transform_step.h:1125
int iter_id
The index of the iterator to be factored.
Definition: transform_step.h:1128
int factor_iter_id
The position where the new iterator is placed.
Definition: transform_step.h:1130
static constexpr const char * record_prefix_str
Definition: transform_step.h:1162
int ApplyToState(State *state, const ComputeDAG &dag) const
Apply the current step to State.
Array< te::Tensor > ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule) const
Apply the current step to tvm.schedule.
static constexpr const char * _type_key
Definition: transform_step.h:1164
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule) const
Print the current step as equivalent python schedule API.
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode)
Managed reference to RfactorStepNode.
Definition: transform_step.h:1172
RfactorStep(int stage_id, int iter_id, int factor_iter_id)
The constructor.
RfactorStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode)
Split step that corresponds to te::Stage::split, with additional support for multiple levels of factors.
Definition: transform_step.h:501
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Print the current step as equivalent python schedule API.
Array< Iterator > ApplyToState(State *state) const
Apply the current step to State.
Optional< PrimExpr > extent
The extent length of the axis to split.
Definition: transform_step.h:506
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
int iter_id
The id of the iterator to split.
Definition: transform_step.h:504
static constexpr const char * _type_key
Definition: transform_step.h:545
bool inner_to_outer
If true, the lengths denote the lengths of iterators from inner level to outer level.
Definition: transform_step.h:513
static constexpr const char * record_prefix_str
Definition: transform_step.h:543
Array< Optional< Integer > > lengths
The split factors.
Definition: transform_step.h:508
TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, StepNode)
Array< tir::IterVar > ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Apply the current step to tvm.schedule.
Managed reference to SplitStepNode.
Definition: transform_step.h:553
SplitStep(int stage_id, int iter_id, Optional< PrimExpr > extent, const Array< Optional< Integer >> &lengths, bool inner_to_outer)
The constructor.
SplitStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode)
Managed reference to StateNode.
Definition: loop_state.h:272
The base class of transformation steps. Each step has its corresponding tvm.te schedule primitives.
Definition: transform_step.h:164
static constexpr const char * _type_key
Definition: transform_step.h:175
virtual void WriteToRecord(dmlc::JSONWriter *writer) const =0
Serialize the current step record to JSONWriter.
TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object)
int stage_id
The index of the stage.
Definition: transform_step.h:167
Managed reference to StepNode.
Definition: transform_step.h:183
StepNode * CopyOnWrite()
CopyOnWrite function for Step. This works almost the same as a normal ObjectRef.CopyOnWrite(),...
TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode)
Storage align step that corresponds to te::Stage::storage_align.
Definition: transform_step.h:738
static constexpr const char * record_prefix_str
Definition: transform_step.h:770
int offset
The offset in the alignment specification.
Definition: transform_step.h:745
void ApplyToSchedule(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Apply the current step to tvm.schedule.
int factor
The factor in alignment specification.
Definition: transform_step.h:743
int iter_id
The iterator to be aligned.
Definition: transform_step.h:741
void WriteToRecord(dmlc::JSONWriter *writer) const final
Serialize the current step record to JSONWriter.
TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode)
String PrintAsPythonAPI(Array< te::Stage > *stages, StageToAxesMap *stage_to_axes) const
Print the current step as equivalent python schedule API.
void ApplyToState(State *state) const
Apply the current step to State.
static constexpr const char * _type_key
Definition: transform_step.h:772
Managed reference to StorageAlignStepNode.
Definition: transform_step.h:780
TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode)
StorageAlignStep(dmlc::JSONReader *reader)
The constructor used to read a step record from JSONReader and create the corresponding step.
StorageAlignStep(int stage_id, int iter_id, int factor, int offset)
The constructor.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Global schedule container For operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:326
Stage, contains scheduling for a stage of computation.
Definition: schedule.h:58
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:315
IteratorKind
The type of an iterator.
Definition: transform_step.h:68
@ kSpecial
Special iterator. (e.g. virtual root iterator)
@ kReduction
Reduction iterator.
@ kMixed
Fused spatial and reduction iterator.
String StepPrintAsPythonAPI(const Step &step, Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule, const Array< Step > &transform_steps)
Print a general step as equivalent python schedule API with runtime dynamic dispatching.
void StepApplyToState(const Step &step, State *state, const ComputeDAG &dag)
Apply a general step to a State with runtime dynamic dispatching.
const char * IteratorAnnotationString[]
Map< tvm::te::Stage, Array< tir::IterVar >, ObjectHash, ObjectEqual > StageToAxesMap
Definition: transform_step.h:58
Step StepReadFromRecord(dmlc::JSONReader *reader)
Read a step record from JSONReader and create the corresponding step.
void StepApplyToSchedule(const Step &step, Array< te::Stage > *stages, StageToAxesMap *stage_to_axes, te::Schedule *schedule, const Array< Step > &transform_steps)
Apply a general step to tvm.schedule with runtime dynamic dispatching.
IteratorAnnotation
The type of an iterator's annotation.
Definition: transform_step.h:80
@ kTensorize
This iterator has been mapped with a tensorize intrinsic.
@ kBlockX
This iterator has been bound to blockIdx.x.
@ kNone
This iterator has no annotation.
@ kVectorize
This iterator has been vectorized.
@ kParallel
This iterator has been parallelized.
@ kUnroll
This iterator has been unrolled.
@ kThreadY
This iterator has been bound to threadIdx.y.
@ kThreadX
This iterator has been bound to threadIdx.x.
@ kBlockZ
This iterator has been bound to blockIdx.z.
@ kBlockY
This iterator has been bound to blockIdx.y.
@ kThreadZ
This iterator has been bound to threadIdx.z.
@ kVThread
This iterator has been bound to vthread.
void UpdateStageToAxesMap(const te::Stage &stage, StageToAxesMap *stage_to_axes)
Update the current stage IterVar information to StageToAxesMap.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Definitions and helper macros for IR/AST nodes.
String-aware ObjectRef hash functor.
Definition: base.h:50
String-aware ObjectRef equal functor.
Definition: base.h:40
Define a schedule.