tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
schedule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: Many schedule primitives originate from Halide and Loopy.
25 #ifndef TVM_TE_SCHEDULE_H_
26 #define TVM_TE_SCHEDULE_H_
27 
28 #include <tvm/support/with.h>
29 #include <tvm/te/tensor.h>
30 #include <tvm/te/tensor_intrin.h>
31 #include <tvm/tir/expr.h>
32 #include <tvm/tir/index_map.h>
33 
34 #include <string>
35 #include <unordered_map>
36 
37 namespace tvm {
38 namespace te {
39 // Node container for Stage
40 class StageNode;
41 // Node container for Schedule
42 class ScheduleNode;
43 // Node container for IterVarRelation
44 class IterVarRelationNode;
45 // Attribute of itervar.
46 class IterVarAttrNode;
47 
49 enum AttachType : int {
51  kInline = 2,
53  kScope = 4,
54  kScanUpdate = 5
55 };
56 
58 class Stage : public ObjectRef {
59  public:
60  Stage() {}
61  explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
67  explicit Stage(Operation op, const ScheduleNode* sch);
72  inline const StageNode* operator->() const;
77  inline StageNode* operator->();
82  TVM_DLL Stage& set_scope(std::string scope); // NOLINT(*)
89  TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
94  TVM_DLL Stage& compute_inline(); // NOLINT(*)
99  TVM_DLL Stage& compute_root(); // NOLINT(*)
107  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
118  TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
127  TVM_DLL Stage& env_threads(Array<IterVar> threads);
136  TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer,
137  IterVar* p_inner); // NOLINT(*)
147  TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
148  IterVar* p_inner); // NOLINT(*)
156  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
170  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
176  TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
192  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
193  PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer,
194  IterVar* p_x_inner, IterVar* p_y_inner);
200  TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
208  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
214  TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
220  TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
230  TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type,
231  const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*)
239  TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset); // NOLINT(*)
250  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); // NOLINT(*)
255  TVM_DLL Stage& double_buffer(); // NOLINT(*)
260  TVM_DLL Stage& rolling_buffer(); // NOLINT(*)
285  TVM_DLL Stage& transform_layout(const Array<Var>& initial_indices,
286  const Array<PrimExpr>& final_indices,
287  Array<IterVar>* out_iter_vars = nullptr);
300  bool is_scheduled() const;
309  // declare container type
311 };
312 
318 class Schedule : public ObjectRef {
319  public:
320  Schedule() {}
326  TVM_DLL explicit Schedule(Array<Operation> ops);
331  Schedule copy() const;
336  TVM_DLL Stage operator[](const Operation& op);
342  TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); }
352  TVM_DLL Stage create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
353  bool include_inputs = false);
363  TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope,
364  const Array<Operation>& readers);
381  TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
398  TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
412  TVM_DLL Array<Tensor> rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0);
422 
433 
438  inline const ScheduleNode* operator->() const;
444  // declare container type
446 };
447 
456  private:
457  friend class With<ScheduleContext>;
458  ScheduleContext(const ScheduleNode* sch_node, String current_primitive_name);
459  void EnterWithScope();
460  void ExitWithScope();
461 
463  Schedule sch_;
465  String current_primitive_name_;
466 };
467 
472 class IterVarRelation : public ObjectRef {
473  public:
480  inline const IterVarRelationNode* operator->() const;
481 };
482 
486 class IterVarAttr : public ObjectRef {
487  public:
494  inline const IterVarAttrNode* operator->() const;
495 };
496 
512 class StageNode : public Object {
513  public:
572  std::string scope;
574  bool is_output{false};
576  bool double_buffer{false};
578  bool rolling_buffer{false};
594 
596  v->Visit("op", &op);
597  v->Visit("origin_op", &origin_op);
598  v->Visit("all_iter_vars", &all_iter_vars);
599  v->Visit("leaf_iter_vars", &leaf_iter_vars);
600  v->Visit("env_threads", &env_threads);
601  v->Visit("relations", &relations);
602  v->Visit("iter_var_attrs", &iter_var_attrs);
603  v->Visit("attach_type", &attach_type);
604  v->Visit("attach_ivar", &attach_ivar);
605  v->Visit("attach_stage", &attach_stage);
606  v->Visit("scope", &scope);
607  v->Visit("is_output", &is_output);
608  v->Visit("double_buffer", &double_buffer);
609  v->Visit("layout_transforms", &layout_transforms);
610  v->Visit("axis_separators", &axis_separators);
611  v->Visit("group", &group);
612  v->Visit("num_child_stages", &num_child_stages);
613  }
614 
615  static constexpr const char* _type_key = "Stage";
617 };
618 
620 class ScheduleNode : public Object {
621  public:
639  std::unordered_map<const Object*, Stage> op2stage_cache_;
655 
657  v->Visit("outputs", &outputs);
658  v->Visit("stages", &stages);
659  v->Visit("groups", &groups);
660  v->Visit("stage_map", &stage_map);
661  v->Visit("schedule_record", &schedule_record);
662  v->Visit("primitive_record", &primitive_record);
663  v->Visit("keep_schedule_record", &keep_schedule_record);
664  }
665 
667  void InitCache();
670 
676  TVM_DLL bool Contain(const Operation& op) const;
677 
683  TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }
684 
685  static constexpr const char* _type_key = "Schedule";
687 };
688 
694 inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }
695 
697 class IterVarAttrNode : public Object {
698  public:
724 
726  v->Visit("iter_type", &iter_type);
727  v->Visit("bind_thread", &bind_thread);
728  v->Visit("prefetch_data", &prefetch_data);
729  v->Visit("prefetch_offset", &prefetch_offset);
730  v->Visit("tensor_intrin", &tensor_intrin);
731  v->Visit("dim_align_factor", &dim_align_factor);
732  v->Visit("dim_align_offset", &dim_align_offset);
733  v->Visit("pragma_keys", &pragma_keys);
734  v->Visit("pragma_values", &pragma_values);
735  }
736 
737  static constexpr const char* _type_key = "IterVarAttr";
739 };
740 
742 class IterVarRelationNode : public Object {
743  public:
744  static constexpr const char* _type_key = "IterVarRelation";
746 };
747 
753  public:
764 
766  v->Visit("parent", &parent);
767  v->Visit("outer", &outer);
768  v->Visit("inner", &inner);
769  v->Visit("factor", &factor);
770  v->Visit("nparts", &nparts);
771  }
772 
773  static constexpr const char* _type_key = "Split";
775 };
776 
781 class Split : public IterVarRelation {
782  public:
783  TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);
784 
786 };
787 
792  public:
799 
801  v->Visit("outer", &outer);
802  v->Visit("inner", &inner);
803  v->Visit("fused", &fused);
804  }
805 
806  static constexpr const char* _type_key = "Fuse";
808 };
809 
814 class Fuse : public IterVarRelation {
815  public:
816  TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);
817 
819 };
820 
827  public:
832 
834  v->Visit("parent", &parent);
835  v->Visit("rebased", &rebased);
836  }
837 
838  static constexpr const char* _type_key = "Rebase";
840 };
841 
846 class Rebase : public IterVarRelation {
847  public:
848  TVM_DLL Rebase(IterVar parent, IterVar rebased);
849 
851 };
852 
857  public:
860 
861  void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }
862 
863  static constexpr const char* _type_key = "Singleton";
865 };
866 
871 class Singleton : public IterVarRelation {
872  public:
873  TVM_DLL explicit Singleton(IterVar iter);
874 
876 };
877 
882  public:
891 
899 
905 
912 
914  v->Visit("original_variables", &original_variables);
915  v->Visit("transformed_variables", &transformed_variables);
916  v->Visit("forward_transformation", &forward_transformation);
917  v->Visit("inverse_transformation", &inverse_transformation);
918  }
919 
920  static constexpr const char* _type_key = "Transform";
922 };
923 
924 class Transform : public IterVarRelation {
925  public:
926  TVM_DLL explicit Transform(Array<IterVar> original_variables,
927  Array<IterVar> transformed_variables, IndexMap forward_transformation,
928  IndexMap inverse_transformation);
929 
931 };
932 
935  public:
942 
943  void VisitAttrs(AttrVisitor* v) { v->Visit("clauses", &clauses); }
944 
945  static constexpr const char* _type_key = "SpecializedCondition";
947 };
948 
953  public:
958  TVM_DLL SpecializedCondition(Array<PrimExpr> conditions); // NOLINT(*)
959 
964  TVM_DLL static SpecializedCondition Current();
965 
967  class Internal;
968 
969  private:
970  // enable with syntax.
971  friend class Internal;
972  friend class With<SpecializedCondition>;
974  TVM_DLL void EnterWithScope();
976  TVM_DLL void ExitWithScope();
977 };
978 
979 // implementations
980 inline const StageNode* Stage::operator->() const { return static_cast<const StageNode*>(get()); }
981 inline StageNode* Stage::operator->() { return static_cast<StageNode*>(get_mutable()); }
982 
983 inline const ScheduleNode* Schedule::operator->() const {
984  return static_cast<const ScheduleNode*>(get());
985 }
986 inline ScheduleNode* Schedule::operator->() { return static_cast<ScheduleNode*>(get_mutable()); }
987 
989  return static_cast<const IterVarRelationNode*>(get());
990 }
991 
993  return static_cast<const IterVarAttrNode*>(get());
994 }
995 
996 } // namespace te
997 } // namespace tvm
998 #endif // TVM_TE_SCHEDULE_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:114
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:360
Base class of all object reference.
Definition: object.h:517
const Object * get() const
Definition: object.h:552
Object * get_mutable() const
Definition: object.h:605
base class of all object containers.
Definition: object.h:169
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Fuse two domains into one domain.
Definition: schedule.h:791
IterVar inner
The inner domain.
Definition: schedule.h:796
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:800
IterVar outer
The outer domain.
Definition: schedule.h:794
TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode)
IterVar fused
The target domain.
Definition: schedule.h:798
static constexpr const char * _type_key
Definition: schedule.h:806
Managed reference to FuseNode.
Definition: schedule.h:814
Fuse(IterVar outer, IterVar inner, IterVar fused)
TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode)
node container for IterVar attr
Definition: schedule.h:697
Array< PrimExpr > pragma_keys
Additional pragma keys, array of StringImm.
Definition: schedule.h:719
Array< Tensor > prefetch_data
List of tensor to be prefetched in this loop.
Definition: schedule.h:704
IterVarType iter_type
The iteration type.
Definition: schedule.h:700
Array< PrimExpr > prefetch_offset
The offset used in each prefetch.
Definition: schedule.h:706
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object)
TensorIntrin tensor_intrin
Tensor intrinsic used in tensorization, when the axis is marked as Tensorized.
Definition: schedule.h:711
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:725
static constexpr const char * _type_key
Definition: schedule.h:737
IterVar bind_thread
The thread this iter Var binds, can be null.
Definition: schedule.h:702
int dim_align_factor
Alignment factor of buffer dimension.
Definition: schedule.h:713
int dim_align_offset
Alignment offset of buffer dimension.
Definition: schedule.h:715
Array< PrimExpr > pragma_values
Additional values of pragma, if any.
Definition: schedule.h:723
Additional schedulable attributes about IterVar.
Definition: schedule.h:486
IterVarAttr()
Definition: schedule.h:488
const IterVarAttrNode * operator->() const
access the internal node container
Definition: schedule.h:992
IterVarAttr(ObjectPtr< Object > n)
Definition: schedule.h:489
base node of iteration var
Definition: schedule.h:742
static constexpr const char * _type_key
Definition: schedule.h:744
TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object)
The schedule relation between IterVars; can be Split, Fuse, Rebase, Singleton, or Transform.
Definition: schedule.h:472
IterVarRelation(ObjectPtr< Object > n)
Definition: schedule.h:475
IterVarRelation()
Definition: schedule.h:474
const IterVarRelationNode * operator->() const
access the internal node container
Definition: schedule.h:988
Operation that produces tensors.
Definition: tensor.h:47
Rebase the iteration so that its min is 0. This is useful to normalize the Schedule to make every leaf...
Definition: schedule.h:826
IterVar rebased
The rebased domain.
Definition: schedule.h:831
IterVar parent
The parent domain.
Definition: schedule.h:829
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:833
TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode)
static constexpr const char * _type_key
Definition: schedule.h:838
Managed reference to RebaseNode.
Definition: schedule.h:846
TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode)
Rebase(IterVar parent, IterVar rebased)
Context helper to collect debug information of Schedule.
Definition: schedule.h:455
node container for schedule
Definition: schedule.h:620
Array< Operation > outputs
The output operations in original data flow graph.
Definition: schedule.h:623
void InvalidateCache()
Invalidate temp cache.
void InitCache()
Initialize temp cache.
bool Contain(const Operation &op) const
Check if the schedule contains an Operation.
bool Contain(const Tensor &tensor) const
Check if the schedule contains a Tensor.
Definition: schedule.h:683
Array< Schedule > schedule_record
List of all transformed schedules. User can display the optimization strategy via TEDD step by step to...
Definition: schedule.h:646
Map< Operation, Stage > stage_map
map of original operation to the stages
Definition: schedule.h:634
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object)
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:656
static constexpr const char * _type_key
Definition: schedule.h:685
Array< Stage > groups
List of all stage groups.
Definition: schedule.h:632
Optional< Bool > keep_schedule_record
Flag to keep schedule record or not.
Definition: schedule.h:654
Array< Stage > stages
list of all stages for ops. The stages are sorted in dependency order.
Definition: schedule.h:628
std::unordered_map< const Object *, Stage > op2stage_cache_
Internal stage map to map internal ops to stages. This is created on demand and can be invalidated.
Definition: schedule.h:639
Array< String > primitive_record
list of all applied primitive names.
Definition: schedule.h:650
Global schedule container for operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:318
Tensor cache_write(const Tensor &tensor, const std::string &scope)
Create a cache write tensor for producing tensor. The tensor will take over body of original tensor o...
Schedule normalize_for_feature_extraction()
Normalize the schedule for feature extraction in auto-scheduler. This is similar to Schedule::normali...
Schedule()
Definition: schedule.h:320
Stage operator[](const Tensor &tensor)
Shorthand for getting the stage of a tensor's operation.
Definition: schedule.h:342
Array< Tensor > rfactor(const Tensor &tensor, const IterVar &axis, int factor_axis=0)
Factor a reduction axis in tensor's schedule to be an explicit axis. This will create a new stage tha...
Tensor cache_read(const Tensor &tensor, const std::string &scope, const Array< Operation > &readers)
create a cache read of original tensor for readers. This will mutate the body of the readers....
Stage operator[](const Operation &op)
Get the stage corresponding to the op.
Schedule normalize()
Normalize the schedule. This is needed before bound inference. Insert necessary RebaseNode to make su...
const ScheduleNode * operator->() const
access the internal node container
Definition: schedule.h:983
Schedule(ObjectPtr< Object > n)
Definition: schedule.h:321
Stage create_group(const Array< Tensor > &outputs, const Array< Tensor > &inputs, bool include_inputs=false)
Create a new stage group for all intermediate operations between inputs and outputs.
Schedule(Array< Operation > ops)
Create a schedule for an array of ops (and their dependencies).
Schedule copy() const
Get a copy of current schedule.
Array< Tensor > cache_write(const Array< Tensor > &tensor, const std::string &scope)
Create a cache write tensor for producing tensor. The tensor will take over body of original tensor o...
Singleton iterator [0, 1)
Definition: schedule.h:856
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:861
IterVar iter
The singleton iterator.
Definition: schedule.h:859
static constexpr const char * _type_key
Definition: schedule.h:863
TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode)
Managed reference to SingletonNode.
Definition: schedule.h:871
TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode)
Singleton(IterVar iter)
Container for specialization conditions.
Definition: schedule.h:934
static constexpr const char * _type_key
Definition: schedule.h:945
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:943
Array< PrimExpr > clauses
List of conditions in conjunctive normal form (CNF). Each condition should be a simple expression,...
Definition: schedule.h:941
TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object)
Specialized condition to enable op specialization.
Definition: schedule.h:952
SpecializedCondition(Array< PrimExpr > conditions)
construct from conditions
static SpecializedCondition Current()
Get the current specialized condition.
friend class Internal
Definition: schedule.h:967
TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode)
Split the parent domain into a product of outer and inner.
Definition: schedule.h:752
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:765
PrimExpr nparts
Number of parts; only one of factor or nparts can be given.
Definition: schedule.h:763
IterVar inner
The inner domain.
Definition: schedule.h:759
PrimExpr factor
The split factor.
Definition: schedule.h:761
IterVar outer
The outer domain.
Definition: schedule.h:757
TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode)
static constexpr const char * _type_key
Definition: schedule.h:773
IterVar parent
The parent domain.
Definition: schedule.h:755
Managed reference to SplitNode.
Definition: schedule.h:781
Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts)
TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode)
represents a stage.
Definition: schedule.h:512
Stage group
The parent group of the current stage. The stage cannot be assigned to stages outside the group.
Definition: schedule.h:591
const ScheduleNode * attach_sch
The schedule current stage is attached to.
Definition: schedule.h:570
Map< IterVar, IterVarAttr > iter_var_attrs
additional attributes about iter var.
Definition: schedule.h:562
AttachType attach_type
The attachment type of the schedule.
Definition: schedule.h:564
Operation op
The operation of the stage; can be different from the original op. If it is null, then this stage is a group ...
Definition: schedule.h:518
Operation origin_op
The original operator. The op field can change during scheduling to alter the dataflow,...
Definition: schedule.h:524
TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object)
Array< IntImm > axis_separators
List of axes after which to divide physical axes.
Definition: schedule.h:586
std::string scope
The thread storage scope level of the stage.
Definition: schedule.h:572
Stage attach_stage
The stage this node attaches to.
Definition: schedule.h:568
Array< IterVar > leaf_iter_vars
The current active leaf iter vars in the stage.
Definition: schedule.h:546
bool rolling_buffer
Whether to apply rolling buffer optimization to this stage.
Definition: schedule.h:578
PrimExpr store_predicate
The predicate under which store can happen. Use this when there can be duplicated threads doing the sa...
Definition: schedule.h:558
int num_child_stages
Number of direct child stages, only used for group stage.
Definition: schedule.h:593
Array< IndexMap > layout_transforms
Layout transformations to be applied onto the stage's tensors.
Definition: schedule.h:580
Array< IterVar > env_threads
Specify threads to be launched at the stage. This is only valid for composite ops such as Scan.
Definition: schedule.h:552
Array< IterVarRelation > relations
The relations between IterVars.
Definition: schedule.h:560
IterVar attach_ivar
The attach point of this schedule.
Definition: schedule.h:566
Array< IterVar > all_iter_vars
All the nodes in the iter var.
Definition: schedule.h:538
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:595
static constexpr const char * _type_key
Definition: schedule.h:615
bool double_buffer
Whether to apply double buffer optimization to this stage.
Definition: schedule.h:576
bool is_output
Whether this is an output stage.
Definition: schedule.h:574
Stage, contains scheduling for a stage of computation.
Definition: schedule.h:58
Stage & set_store_predicate(PrimExpr predicate)
Set the predicate to determine whether a store to the array should be performed. Use this when there ...
Stage & compute_at(Stage parent, IterVar scope)
specify the schedule to be computed at the parent schedule's scope.
Stage & fuse(const Array< IterVar > &axes, IterVar *p_target)
Fuse all the axes together into a single axis.
Stage & double_buffer()
Compute current stage with double buffering.
bool is_scheduled() const
whether the stage has been scheduled.
Stage & set_scope(std::string scope)
set the memory scope of the stage
Stage & compute_inline()
Compute the function inline.
Stage & vectorize(IterVar var)
Vectorize iteration.
Stage(Operation op, const ScheduleNode *sch)
create a new schedule for op.
Stage & split_by_nparts(IterVar parent, PrimExpr nparts, IterVar *p_outer, IterVar *p_inner)
Split the iteration with given number of parts.
Stage & fuse(IterVar outer, IterVar inner, IterVar *p_target)
Fuse the inner and outer domains into the target.
Stage & split(IterVar parent, PrimExpr factor, IterVar *p_outer, IterVar *p_inner)
Split the parent by factor, generating the outer and inner iteration variables.
Stage & parallel(IterVar var)
Parallelize iteration.
Stage & prefetch(const Tensor &domain, IterVar var, PrimExpr offset)
Fetch data in advance.
Stage GetAttachSpec() const
Get the attachment spec of the current stage. If the stage is computed at the group root, this function will traverse...
Stage & pragma(IterVar var, const std::string &pragma_type, const PrimExpr &pragma_value=PrimExpr())
Annotate the iteration with pragma.
Stage & tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor, IterVar *p_x_outer, IterVar *p_y_outer, IterVar *p_x_inner, IterVar *p_y_inner)
Perform tiling on two dimensions. The final loop order from outermost to innermost is [x_outer,...
const StageNode * operator->() const
access the internal node container
Definition: schedule.h:980
Stage & compute_root()
Compute the function at group root.
Stage & rolling_buffer()
Compute current stage with rolling buffering.
Stage(ObjectPtr< Object > n)
Definition: schedule.h:61
Stage & storage_align(IterVar axis, int factor, int offset)
Set alignment requirement for specific dimension.
Stage & bind(IterVar ivar, IterVar thread_ivar)
Bind the IterVar to thread index.
Stage & tensorize(IterVar var, TensorIntrin f)
Replace computation of the current stage by tensor intrinsic f.
Stage & env_threads(Array< IterVar > threads)
Specify environment threads that are launched around the group's scope. This can only be used in group st...
Stage & transform_layout(const Array< Var > &initial_indices, const Array< PrimExpr > &final_indices, Array< IterVar > *out_iter_vars=nullptr)
Defines a layout transformation to be applied to the buffer.
Stage & reorder(const Array< IterVar > &order)
Reorder the iteration.
Stage & set_axis_separators(const Array< IntImm > &axis_separators)
Defines separators between groups of axes.
Stage & unroll(IterVar var)
Unroll iteration.
Stage()
Definition: schedule.h:60
Managed reference to TensorIntrinNode.
Definition: tensor_intrin.h:93
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Transform iterator according to some arbitrary expression.
Definition: schedule.h:881
Array< IterVar > transformed_variables
The variables generated by the transformation.
Definition: schedule.h:898
IndexMap forward_transformation
Map from the original variables to the transformed variables.
Definition: schedule.h:904
static constexpr const char * _type_key
Definition: schedule.h:920
IndexMap inverse_transformation
Map from transformed variables to the original variables.
Definition: schedule.h:911
Array< IterVar > original_variables
The loop variables that were replaced by the transformation.
Definition: schedule.h:890
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:913
TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode)
Definition: schedule.h:924
Transform(Array< IterVar > original_variables, Array< IterVar > transformed_variables, IndexMap forward_transformation, IndexMap inverse_transformation)
TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode)
Definition: index_map.h:176
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:314
Defines a remapping of buffer indices.
Schedule create_schedule(Array< Operation > ops)
Create a schedule for an array of ops (and their dependencies).
Definition: schedule.h:694
AttachType
the attachment type
Definition: schedule.h:49
@ kScanUpdate
Definition: schedule.h:54
@ kInline
Definition: schedule.h:51
@ kScope
Definition: schedule.h:53
@ kGroupRoot
Definition: schedule.h:50
@ kInlinedAlready
Definition: schedule.h:52
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1454
IterVarType
Type of iteration variable. Each IterVar has a specific type.
Definition: var.h:191
@ kDataPar
Data parallel iteration. This normally corresponds to an axis of a Tensor. Allows all IterVar manipulations...
Definition: var.h:200
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Dataflow tensor object.
Tensor intrinsic operations.
TIR expressions.
RAII wrapper function to enter and exit a context object similar to python's with syntax.