tvm
schedule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: Many schedule primitives originate from Halide and Loopy.
25 #ifndef TVM_TE_SCHEDULE_H_
26 #define TVM_TE_SCHEDULE_H_
27 
28 #include <tvm/support/with.h>
29 #include <tvm/te/tensor.h>
30 #include <tvm/te/tensor_intrin.h>
31 #include <tvm/tir/expr.h>
32 
33 #include <string>
34 #include <unordered_map>
35 
36 namespace tvm {
37 namespace te {
38 // Node container for Stage
39 class StageNode;
40 // Node container for Schedule
41 class ScheduleNode;
42 // Node container for IterVarRelation
43 class IterVarRelationNode;
44 // Attribute of itervar.
45 class IterVarAttrNode;
46 
48 enum AttachType : int {
50  kInline = 2,
52  kScope = 4,
54 };
55 
57 class Stage : public ObjectRef {
58  public:
59  Stage() {}
60  explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
65  explicit Stage(Operation op);
70  inline const StageNode* operator->() const;
75  inline StageNode* operator->();
80  TVM_DLL Stage& set_scope(std::string scope); // NOLINT(*)
87  TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
92  TVM_DLL Stage& compute_inline(); // NOLINT(*)
97  TVM_DLL Stage& compute_root(); // NOLINT(*)
105  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
116  TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
125  TVM_DLL Stage& env_threads(Array<IterVar> threads);
134  TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer,
135  IterVar* p_inner); // NOLINT(*)
145  TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
146  IterVar* p_inner); // NOLINT(*)
154  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
168  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
174  TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
190  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
191  PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer,
192  IterVar* p_x_inner, IterVar* p_y_inner);
198  TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
206  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
212  TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
218  TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
228  TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type,
229  const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*)
237  TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset); // NOLINT(*)
248  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); // NOLINT(*)
253  TVM_DLL Stage& double_buffer(); // NOLINT(*)
258  bool is_scheduled() const;
266  Stage GetAttachSpec() const;
267  // declare container type
269 };
270 
276 class Schedule : public ObjectRef {
277  public:
278  Schedule() {}
285  TVM_DLL explicit Schedule(Array<Operation> ops);
290  Schedule copy() const;
295  TVM_DLL Stage operator[](const Operation& op);
301  TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); }
311  TVM_DLL Stage create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
312  bool include_inputs = false);
322  TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope,
323  const Array<Operation>& readers);
340  TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
357  TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
371  TVM_DLL Array<Tensor> rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0);
380  Schedule normalize();
381 
391  Schedule normalize_for_feature_extraction();
392 
397  inline const ScheduleNode* operator->() const;
402  inline ScheduleNode* operator->();
403  // declare container type
405 };
406 
411 class IterVarRelation : public ObjectRef {
412  public:
419  inline const IterVarRelationNode* operator->() const;
420 };
421 
425 class IterVarAttr : public ObjectRef {
426  public:
433  inline const IterVarAttrNode* operator->() const;
434 };
435 
451 class StageNode : public Object {
452  public:
485  AttachType attach_type{kGroupRoot};
491  std::string scope;
493  bool is_output{false};
495  bool double_buffer{false};
502  int num_child_stages{0};
503 
505  v->Visit("op", &op);
506  v->Visit("origin_op", &origin_op);
507  v->Visit("all_iter_vars", &all_iter_vars);
508  v->Visit("leaf_iter_vars", &leaf_iter_vars);
509  v->Visit("env_threads", &env_threads);
510  v->Visit("relations", &relations);
511  v->Visit("iter_var_attrs", &iter_var_attrs);
512  v->Visit("attach_type", &attach_type);
513  v->Visit("attach_ivar", &attach_ivar);
514  v->Visit("attach_stage", &attach_stage);
515  v->Visit("scope", &scope);
516  v->Visit("is_output", &is_output);
517  v->Visit("double_buffer", &double_buffer);
518  v->Visit("group", &group);
519  v->Visit("num_child_stages", &num_child_stages);
520  }
521 
522  static constexpr const char* _type_key = "Stage";
524 };
525 
527 class ScheduleNode : public Object {
528  public:
546  std::unordered_map<const Object*, Stage> op2stage_cache_;
547 
549  v->Visit("outputs", &outputs);
550  v->Visit("stages", &stages);
551  v->Visit("groups", &groups);
552  v->Visit("stage_map", &stage_map);
553  }
554 
556  void InitCache();
558  void InvalidateCache();
559 
565  TVM_DLL bool Contain(const Operation& op) const;
566 
572  TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }
573 
574  static constexpr const char* _type_key = "Schedule";
576 };
577 
583 inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }
584 
586 class IterVarAttrNode : public Object {
587  public:
589  IterVarType iter_type{kDataPar};
602  int dim_align_factor{0};
604  int dim_align_offset{0};
613 
615  v->Visit("iter_type", &iter_type);
616  v->Visit("bind_thread", &bind_thread);
617  v->Visit("prefetch_data", &prefetch_data);
618  v->Visit("prefetch_offset", &prefetch_offset);
619  v->Visit("tensor_intrin", &tensor_intrin);
620  v->Visit("dim_align_factor", &dim_align_factor);
621  v->Visit("dim_align_offset", &dim_align_offset);
622  v->Visit("pragma_keys", &pragma_keys);
623  v->Visit("pragma_values", &pragma_values);
624  }
625 
626  static constexpr const char* _type_key = "IterVarAttr";
628 };
629 
631 class IterVarRelationNode : public Object {
632  public:
633  static constexpr const char* _type_key = "IterVarRelation";
635 };
636 
642  public:
653 
655  v->Visit("parent", &parent);
656  v->Visit("outer", &outer);
657  v->Visit("inner", &inner);
658  v->Visit("factor", &factor);
659  v->Visit("nparts", &nparts);
660  }
661 
662  static constexpr const char* _type_key = "Split";
664 };
665 
670 class Split : public IterVarRelation {
671  public:
672  TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);
673 
675 };
676 
681  public:
688 
690  v->Visit("outer", &outer);
691  v->Visit("inner", &inner);
692  v->Visit("fused", &fused);
693  }
694 
695  static constexpr const char* _type_key = "Fuse";
697 };
698 
703 class Fuse : public IterVarRelation {
704  public:
705  TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);
706 
708 };
709 
716  public:
721 
723  v->Visit("parent", &parent);
724  v->Visit("rebased", &rebased);
725  }
726 
727  static constexpr const char* _type_key = "Rebase";
729 };
730 
735 class Rebase : public IterVarRelation {
736  public:
737  TVM_DLL Rebase(IterVar parent, IterVar rebased);
738 
740 };
741 
746  public:
749 
750  void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }
751 
752  static constexpr const char* _type_key = "Singleton";
754 };
755 
760 class Singleton : public IterVarRelation {
761  public:
762  TVM_DLL explicit Singleton(IterVar iter);
763 
765 };
766 
769  public:
776 
777  void VisitAttrs(AttrVisitor* v) { v->Visit("clauses", &clauses); }
778 
779  static constexpr const char* _type_key = "SpecializedCondition";
781 };
782 
787  public:
792  TVM_DLL SpecializedCondition(Array<PrimExpr> conditions); // NOLINT(*)
793 
798  TVM_DLL static SpecializedCondition Current();
799 
801  class Internal;
802 
803  private:
804  // enable with syntax.
805  friend class Internal;
806  friend class With<SpecializedCondition>;
808  TVM_DLL void EnterWithScope();
810  TVM_DLL void ExitWithScope();
811 };
812 
813 // implementations
814 inline const StageNode* Stage::operator->() const { return static_cast<const StageNode*>(get()); }
815 inline StageNode* Stage::operator->() { return static_cast<StageNode*>(get_mutable()); }
816 
817 inline const ScheduleNode* Schedule::operator->() const {
818  return static_cast<const ScheduleNode*>(get());
819 }
820 inline ScheduleNode* Schedule::operator->() { return static_cast<ScheduleNode*>(get_mutable()); }
821 
823  return static_cast<const IterVarRelationNode*>(get());
824 }
825 
827  return static_cast<const IterVarAttrNode*>(get());
828 }
829 
830 } // namespace te
831 } // namespace tvm
832 #endif // TVM_TE_SCHEDULE_H_
Array< Stage > groups
List of all stage groups.
Definition: schedule.h:539
Stage()
Definition: schedule.h:59
Split the parent domain into the product of the outer and inner domains.
Definition: schedule.h:641
Stage & compute_root()
Compute the function at group root.
Managed reference to RebaseNode.
Definition: schedule.h:735
represents a stage.
Definition: schedule.h:451
A custom smart pointer for Object.
Definition: object.h:356
IterVar bind_thread
The thread this iter Var binds, can be null.
Definition: schedule.h:591
Global schedule container For operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:276
IterVar outer
The outer domain.
Definition: schedule.h:683
Stage operator[](const Tensor &tensor)
Shorthand for getting the stage of a tensor's operation.
Definition: schedule.h:301
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:504
Schedule create_schedule(Array< Operation > ops)
Create a schedule for array of ops(and their dependencies).
Definition: schedule.h:583
Stage(ObjectPtr< Object > n)
Definition: schedule.h:60
Array< Operation > outputs
The output operations in original data flow graph.
Definition: schedule.h:530
Definition: schedule.h:53
const IterVarRelationNode * operator->() const
access the internal node container
Definition: schedule.h:822
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:548
Fuse two domains into one domain.
Definition: schedule.h:680
Stage attach_stage
The stage this node attaches to.
Definition: schedule.h:489
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
IterVar fused
The target domain.
Definition: schedule.h:687
IterVar inner
The inner domain.
Definition: schedule.h:648
Schedule()
Definition: schedule.h:278
Operation that produces tensors.
Definition: tensor.h:47
IterVar inner
The inner domain.
Definition: schedule.h:685
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:722
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:297
PrimExpr nparts
Number of parts, only factor or nparts can be given.
Definition: schedule.h:652
Stage, contains scheduling for a stage of computation.
Definition: schedule.h:57
IterVar attach_ivar
The attach point of this schedule.
Definition: schedule.h:487
Stage & compute_at(Stage parent, IterVar scope)
Specify that the stage is computed at the parent stage's scope.
AttachType
the attachment type
Definition: schedule.h:48
Specialized condition to enable op specialization.
Definition: schedule.h:786
base class of all object containers.
Definition: object.h:165
std::string scope
The thread storage scope level of the stage.
Definition: schedule.h:491
Array< IterVar > all_iter_vars
All the nodes in the iter var.
Definition: schedule.h:465
Stage & set_scope(std::string scope)
Set the memory scope of the stage.
Operation op
The operation of stage, can be different from original op. If it is null, then this stage is a group ...
Definition: schedule.h:457
Definition: schedule.h:50
Managed reference to SplitNode.
Definition: schedule.h:670
Stage & set_store_predicate(PrimExpr predicate)
Set the predicate to determine whether a store to the array should be performed. Use this when there ...
IterVarAttr()
Definition: schedule.h:427
bool is_scheduled() const
whether the stage has been scheduled.
Stage & double_buffer()
Compute current stage with double buffering.
Container for specialization conditions.
Definition: schedule.h:768
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
TensorIntrin tensor_intrin
Tensor intrinsic used in tensorization, when the axis is marked as Tensorized.
Definition: schedule.h:600
Dataflow tensor object.
Stage & unroll(IterVar var)
Unroll iteration.
IterVarType
Type of iteration variable. Each IterVar has a specific type.
Definition: var.h:178
Rebase the iteration so that its minimum is 0. This is useful to normalize the Schedule to make every leaf...
Definition: schedule.h:715
Stage group
The parent group of the current stage. The stage cannot be assigned to stages outside the group...
Definition: schedule.h:500
The schedule relation between IterVars; it can be, for example, a Split or a Fuse.
Definition: schedule.h:411
Managed reference to TensorIntrinNode.
Definition: tensor_intrin.h:93
TIR expressions.
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:689
IterVarRelation(ObjectPtr< Object > n)
Definition: schedule.h:414
Definition: schedule.h:49
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Stage & reorder(const Array< IterVar > &order)
Reorder the iteration.
Array< Tensor > prefetch_data
List of tensors to be prefetched in this loop.
Definition: schedule.h:593
const StageNode * operator->() const
access the internal node container
Definition: schedule.h:814
node container for IterVar attr
Definition: schedule.h:586
Managed reference to SingletonNode.
Definition: schedule.h:760
const IterVarAttrNode * operator->() const
access the internal node container
Definition: schedule.h:826
Stage & compute_inline()
Compute the function inline.
Array< IterVar > env_threads
Specify threads to be launched at the stage. This is only valid for composite ops such as Scan...
Definition: schedule.h:473
Stage & tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor, IterVar *p_x_outer, IterVar *p_y_outer, IterVar *p_x_inner, IterVar *p_y_inner)
Perform tiling on two dimensions. The final loop order from outermost to innermost is [x_outer...
const ScheduleNode * operator->() const
access the internal node container
Definition: schedule.h:817
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:706
RAII wrapper function to enter and exit a context object, similar to Python's with syntax...
Definition: with.h:57
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:614
Data parallel iteration. This normally corresponds to axis of Tensor. Allow all IterVar manipulations...
Definition: var.h:187
Array< IterVar > leaf_iter_vars
The current active leaf iter vars in the stage.
Definition: schedule.h:467
Definition: schedule.h:51
IterVar parent
The parent domain.
Definition: schedule.h:644
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Base class of all object reference.
Definition: object.h:504
Singleton iterator [0, 1)
Definition: schedule.h:745
Stage & parallel(IterVar var)
Parallelize iteration.
Definition: schedule.h:52
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Array< PrimExpr > pragma_keys
Additional pragma keys, array of StringImm.
Definition: schedule.h:608
Managed reference to FuseNode.
Definition: schedule.h:703
PrimExpr factor
The split factor.
Definition: schedule.h:650
Stage & env_threads(Array< IterVar > threads)
Specify environment threads that are launched around the group's scope. This can only be used in group st...
Stage & vectorize(IterVar var)
Vectorize iteration.
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:664
Tensor intrinsic operations.
Schedule(ObjectPtr< Object > n)
Definition: schedule.h:279
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:750
Stage & bind(IterVar ivar, IterVar thread_ivar)
Bind the IterVar to thread index.
Map container of NodeRef->NodeRef in the DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1235
Stage & pragma(IterVar var, const std::string &pragma_type, const PrimExpr &pragma_value=PrimExpr())
Annotate the iteration with pragma.
Map< IterVar, IterVarAttr > iter_var_attrs
additional attributes about iter var.
Definition: schedule.h:483
PrimExpr store_predicate
The predicate under which store can happen Use this when there can be duplicated threads doing the sa...
Definition: schedule.h:479
Array< IterVarRelation > relations
The relations between IterVars.
Definition: schedule.h:481
Stage GetAttachSpec() const
Get the attachment spec of the current stage. If the stage is computed at the group root, this function will traverse...
IterVar rebased
The rebased domain.
Definition: schedule.h:720
IterVarRelation()
Definition: schedule.h:413
Stage & split(IterVar parent, PrimExpr factor, IterVar *p_outer, IterVar *p_inner)
Split the parent by factor, generating an outer and an inner iteration.
IterVar outer
The outer domain.
Definition: schedule.h:646
Operation origin_op
The original operator. The op field can change during schedule to alternate the dataflow, while origin_op remains fixed.
Definition: schedule.h:463
Stage & fuse(IterVar outer, IterVar inner, IterVar *p_target)
Fuse the outer and inner domains into the target.
IterVarAttr(ObjectPtr< Object > n)
Definition: schedule.h:428
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:777
base node of iteration var
Definition: schedule.h:631
Reference to PrimExprNode.
Definition: expr.h:109
std::vector< std::string > Split(const std::string &str, const std::string &sub)
Split str according to substring.
Definition: einsum.h:425
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:654
node container for schedule
Definition: schedule.h:527
std::unordered_map< const Object *, Stage > op2stage_cache_
Internal stage map to map internal ops to stages. This is created on demand and can be invalidated...
Definition: schedule.h:546
Stage & storage_align(IterVar axis, int factor, int offset)
Set alignment requirement for specific dimension.
bool Contain(const Tensor &tensor) const
Check if the schedule contains a Tensor.
Definition: schedule.h:572
Object * get_mutable() const
Definition: object.h:569
Stage & split_by_nparts(IterVar parent, PrimExpr nparts, IterVar *p_outer, IterVar *p_inner)
Split the iteration with given number of parts.
Array< PrimExpr > prefetch_offset
The offset used in each prefetch.
Definition: schedule.h:595
Array< Stage > stages
list of all stages for ops. The stages are sorted in dependency order.
Definition: schedule.h:535
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:641
Additional schedulable attributes of an IterVar.
Definition: schedule.h:425
Stage & tensorize(IterVar var, TensorIntrin f)
Replace computation of the current stage by tensor intrinsic f.
Array< PrimExpr > pragma_values
Additional values of pragma, if any.
Definition: schedule.h:612
Map< Operation, Stage > stage_map
Map from original operations to their stages.
Definition: schedule.h:541
Array< PrimExpr > clauses
List of conditions in conjunctive normal form (CNF). Each condition should be a simple expression...
Definition: schedule.h:775
IterVar iter
The singleton iterator.
Definition: schedule.h:748
Stage & prefetch(const Tensor &domain, IterVar var, PrimExpr offset)
Fetch data in advance.
RAII wrapper function to enter and exit a context object, similar to Python's with syntax...
IterVar parent
The parent domain.
Definition: schedule.h:718