tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
schedule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
20 #define TVM_TIR_SCHEDULE_SCHEDULE_H_
21 
23 #include <tvm/tir/index_map.h>
24 #include <tvm/tir/schedule/state.h>
25 #include <tvm/tir/schedule/trace.h>
26 
27 namespace tvm {
28 namespace tir {
29 
/*! \brief The level of detailed error message rendering. */
enum class ScheduleErrorRenderLevel : int32_t {
  /*! \brief Render a detailed error message. */
  kDetail = 0,
  /*! \brief Render the error in fast mode. */
  kFast = 1,
  /*!
   * \brief The third (least verbose) rendering level.
   * NOTE(review): presumably "no error message at all" judging by the name; confirm upstream.
   */
  kNone = 2,
};
39 
/*! \brief Type of buffer index. */
enum class BufferIndexType : int32_t {
  /*! \brief Index of a read buffer. */
  kRead = 0,
  /*! \brief Index of a written buffer. */
  kWrite = 1,
};
47 
48 /**************** Random variable: BlockRV ****************/
49 
51 class BlockRVNode : public runtime::Object {
52  public:
54  static constexpr const char* _type_key = "tir.BlockRV";
56 };
57 
62 class BlockRV : public runtime::ObjectRef {
63  public:
65  TVM_DLL BlockRV();
67 };
68 
69 /**************** Random variable: LoopRV ****************/
70 
72 class LoopRVNode : public runtime::Object {
73  public:
75  static constexpr const char* _type_key = "tir.LoopRV";
77 };
78 
83 class LoopRV : public runtime::ObjectRef {
84  public:
86  TVM_DLL LoopRV();
88 };
89 
90 /**************** Random variable: ExprRV ****************/
91 
93 using ExprRV = PrimExpr;
94 
96 
/**************** The Schedule class ****************/

// Forward declaration: ScheduleNode befriends Schedule and returns it from Copy(),
// so the name must be visible before the node class is defined.
class Schedule;
100 
// NOTE(review): This listing was extracted from a Doxygen source page. The class header
// (source line 102) and nearly all doc comments were lost in extraction. The page's
// tooltips identify the class as ScheduleNode, "The user-facing schedule class" (defined
// at schedule.h:102). Restore the missing lines from the upstream header before compiling.
103  friend class Schedule;
104 
105  public:
106  virtual ~ScheduleNode() = default;
107 
108  static constexpr const char* _type_key = "tir.Schedule";
// NOTE(review): source line 109 is missing here; presumably it held the object
// type-info macro for this node -- confirm upstream.
110 
111  public:
// Get the IRModule associated with this schedule. The default implementation reads it from state().
113  virtual IRModule mod() const { return state()->mod; }
115  virtual ScheduleState state() const = 0;
117  virtual Optional<Trace> trace() const = 0;
132  virtual void WorkOn(const String& func_name) = 0;
141  virtual Schedule Copy() = 0;
146  virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0;
148  virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;
149 
150  public:
151  /******** Lookup/Remove random variables ********/
157  virtual Block Get(const BlockRV& block_rv) const = 0;
163  virtual For Get(const LoopRV& loop_rv) const = 0;
169  virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
175  virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0;
181  virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
187  virtual bool HasBlock(const BlockRV& block_rv) const = 0;
193  virtual StmtSRef GetSRef(const StmtNode* stmt) const;
// Get the block/loop sref corresponding to the specific statement. This non-virtual
// convenience overload forwards to the StmtNode* overload.
199  StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); }
204  virtual void RemoveRV(const BlockRV& block_rv) = 0;
209  virtual void RemoveRV(const LoopRV& loop_rv) = 0;
214  virtual void RemoveRV(const ExprRV& expr_rv) = 0;
215 
216  public:
217  /******** Schedule: Sampling ********/
225  virtual ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
226  Optional<Integer> decision = NullOpt) = 0;
235  virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
236  Optional<Array<Integer>> decision = NullOpt) = 0;
243  virtual LoopRV SampleComputeLocation(const BlockRV& block_rv,
244  Optional<Integer> decision = NullOpt) = 0;
245 
246  /******** Schedule: Get blocks & loops ********/
261  virtual BlockRV GetBlock(const String& name, const Optional<String>& func_name = NullOpt) = 0;
267  virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
273  virtual Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) = 0;
279  virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
286  virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
293  virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
302  virtual Array<BlockRV> GetOutputBlocks(const BlockRV& scope_block_rv) = 0;
303  /******** Schedule: Transform loops ********/
313  virtual LoopRV Merge(const Array<LoopRV>& loop_rvs) = 0;
324  virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters = true) = 0;
335  virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
336  bool preserve_unit_iters = true) = 0;
349  virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0;
355  virtual void ReorderBlockIterVar(const BlockRV& block_rv, const Array<Integer> new_order) = 0;
361  virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0;
367  virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0;
368  /******** Schedule: Manipulate ForKind ********/
378  virtual void Parallel(const LoopRV& loop_rv) = 0;
388  virtual void Vectorize(const LoopRV& loop_rv) = 0;
400  virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0;
405  virtual void Unroll(const LoopRV& loop_rv) = 0;
406  /******** Schedule: Insert cache stages ********/
417  virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
418  const String& storage_scope,
419  const Array<BlockRV> consumer_blocks = {}) = 0;
430  virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
431  const String& storage_scope,
432  const Array<BlockRV> consumer_blocks = {}) = 0;
445  virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
446  const String& storage_scope, const IndexMap& index_map) = 0;
459  virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
460  const String& storage_scope, const IndexMap& index_map) = 0;
469  virtual Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
470  const String& storage_scope) = 0;
479  virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
480  int cse_thresh) = 0;
492  virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
493  BufferIndexType buffer_index_type) = 0;
494  /******** Schedule: Data movement ********/
// NOTE(review): ReadAt/WriteAt occupy consecutive source lines 494-498. Unlike the other
// primitives, they had no doc comment in the original header; they should be documented.
495  virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
496  const String& storage_scope) = 0;
497  virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index,
498  const String& storage_scope) = 0;
499  /******** Schedule: Compute location ********/
520  virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
521  int index = -1) = 0;
541  virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
542  bool preserve_unit_loops, int index = -1) = 0;
553  virtual void ComputeInline(const BlockRV& block) = 0;
565  virtual void ReverseComputeInline(const BlockRV& block) = 0;
566  /******** Schedule: Reduction ********/
582  virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
600  virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
601  /******** Schedule: Block annotation ********/
614  virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
615  int offset) = 0;
623  virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
633  virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0;
634  /******** Schedule: Blockize & Tensorize ********/
641  virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0;
648  virtual void Tensorize(const LoopRV& loop_rv, const String& intrin,
649  bool preserve_unit_iters = true) = 0;
656  virtual void Tensorize(const BlockRV& block_rv, const String& intrin,
657  bool preserve_unit_iters = true) = 0;
658 
659  /******** Schedule: Annotation ********/
666  virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0;
673  virtual void Annotate(const BlockRV& block_rv, const String& ann_key,
674  const ObjectRef& ann_val) = 0;
680  virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0;
686  virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;
687 
688  /******** Schedule: Layout transformation ********/
720  virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
721  BufferIndexType buffer_index_type, const IndexMap& index_map,
722  const Optional<IndexMap>& pad_value = NullOpt,
723  bool assume_injective_transform = false) = 0;
724 
733  virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0;
734 
743  virtual void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
744  BufferIndexType buffer_index_type,
745  const Array<IntImm>& axis_separators) = 0;
746 
747  /******** Schedule: Padding ********/
755  virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
756 
774  virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0;
775 
776  /******** Schedule: Buffer transformation ********/
791  virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;
792 
793  /******** Schedule: Misc ********/
795  virtual void EnterPostproc() = 0;
796 };
797 
// Managed reference to ScheduleNode (defined at schedule.h:813).
// NOTE(review): extracted from a Doxygen source page. Source lines 815-828 and 848 are
// missing, so this fragment does not compile as-is; restore them from the upstream header.
813 class Schedule : public runtime::ObjectRef {
814  public:
// NOTE(review): the next two lines are the tail of a static factory declaration whose
// first line (source line 828) was lost in extraction.
829  int debug_mask, ScheduleErrorRenderLevel error_render_level,
830  bool enable_check = true);
// Static factory taking the module, an RNG seed, a debug mask, an error-render level and a
// check flag (which defaults to true). NOTE(review): the exact contract of "Traced" is
// documented upstream but not visible in this extraction.
845  TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
846  int debug_mask, ScheduleErrorRenderLevel error_render_level,
847  bool enable_check = true);
// NOTE(review): source line 848 is missing; presumably it held the managed-reference
// methods macro -- confirm upstream.
849 };
850 
851 } // namespace tir
852 } // namespace tvm
853 
854 #endif // TVM_TIR_SCHEDULE_SCHEDULE_H_
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents a thread axis.
ForFrame Unroll(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The unrolled For statement.
Base node of all statements.
Definition: stmt.h:38
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule.h:53
Managed reference to BlockNode.
Definition: stmt.h:1258
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_generator.
Type Bind(const Type &type, const Map< TypeVar, Type > &args_map)
Bind free type variables in the type.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Managed reference to LoopRVNode.
Definition: schedule.h:83
Index of a written buffer.
Managed reference to ForNode.
Definition: stmt.h:962
StmtSRef GetSRef(const Stmt &stmt) const
Get the block/loop sref corresponding to the specific statement.
Definition: schedule.h:199
Managed reference to StmtSRefNode.
Definition: block_scope.h:102
base class of all object containers.
Definition: object.h:167
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:744
ScheduleErrorRenderLevel
The level of detailed error message rendering.
Definition: schedule.h:31
Render a detailed error message.
Managed reference to ScheduleNode.
Definition: schedule.h:813
Defines a remapping of buffer indices.
int64_t TRandState
Definition: random_engine.h:46
Visitor class to get the attributes of an AST/IR node. Its visit methods are called once for each field.
Definition: reflection.h:52
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Definition: index_map.h:177
Container of all statements.
Definition: stmt.h:59
Reference to string objects.
Definition: string.h:98
Render the error in fast mode.
const Object * get() const
Definition: object.h:546
The user-facing schedule class.
Definition: schedule.h:102
Managed reference to ScheduleStateNode.
Definition: state.h:196
ForFrame Parallel(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The parallel For statement.
Base class of all object reference.
Definition: object.h:511
virtual IRModule mod() const
Get the IRModule associated with this schedule.
Definition: schedule.h:113
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1391
This file defines ScheduleState, the core data structure of TensorIR scheduling.
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
Managed reference class to IRModuleNode.
Definition: module.h:348
A random variable that evaluates to a TensorIR for loop.
Definition: schedule.h:72
Map< K, V > Merge(Map< K, V > lhs, const Map< K, V > &rhs)
Merge two Maps.
Definition: map.h:1471
A random variable that evaluates to a TensorIR block.
Definition: schedule.h:51
Optional container that represents a nullable variant of T.
Definition: optional.h:51
BufferIndexType
Type of buffer index.
Definition: schedule.h:41
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule.h:74
Reference to PrimExprNode.
Definition: expr.h:114
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
Index of a read buffer.
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
Managed reference to BlockRVNode.
Definition: schedule.h:62
Base node of all primitive expressions.
Definition: expr.h:85