tvm
schedule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
20 #define TVM_TIR_SCHEDULE_SCHEDULE_H_
21 
23 #include <tvm/tir/index_map.h>
24 #include <tvm/tir/schedule/state.h>
25 #include <tvm/tir/schedule/trace.h>
26 
27 namespace tvm {
28 namespace tir {
29 
/*! \brief The level of detailed error message rendering. */
31 enum class ScheduleErrorRenderLevel : int32_t {
  /*! \brief Render a detailed error message. */
33  kDetail = 0,
  /*! \brief Render the error in fast mode. */
35  kFast = 1,
  /*! \brief No error message at all. */
37  kNone = 2,
38 };
39 
/*! \brief Type of buffer index: whether an index refers to a block's read or write buffers. */
41 enum class BufferIndexType : int32_t {
  /*! \brief Index of a read buffer. */
43  kRead = 0,
  /*! \brief Index of a written buffer. */
45  kWrite = 1,
46 };
47 
48 /**************** Random variable: BlockRV ****************/
49 
/*! \brief A random variable that evaluates to a TensorIR block. */
51 class BlockRVNode : public runtime::Object {
52  public:
  // NOTE(review): listing lines 53 and 55 are not shown. The page's cross-reference entries
  // list `void VisitAttrs(tvm::AttrVisitor* v)` (schedule.h:53) and
  // `TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object)` as members of this class.
  /*! \brief Type key of this node type. */
54  static constexpr const char* _type_key = "tir.BlockRV";
56 };
57 
/*! \brief Managed reference to BlockRVNode. */
62 class BlockRV : public runtime::ObjectRef {
63  public:
  /*! \brief Constructor. */
65  TVM_DLL BlockRV();
  // NOTE(review): listing line 66 is not shown. The page's cross-reference entries list
  // `TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode)`
  // for this class.
67 };
68 
69 /**************** Random variable: LoopRV ****************/
70 
/*! \brief A random variable that evaluates to a TensorIR for loop. */
72 class LoopRVNode : public runtime::Object {
73  public:
  // NOTE(review): listing lines 74 and 76 are not shown. The page's cross-reference entries
  // list `void VisitAttrs(tvm::AttrVisitor* v)` (schedule.h:74) and
  // `TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object)` as members of this class.
  /*! \brief Type key of this node type. */
75  static constexpr const char* _type_key = "tir.LoopRV";
77 };
78 
/*! \brief Managed reference to LoopRVNode. */
83 class LoopRV : public runtime::ObjectRef {
84  public:
  /*! \brief Constructor. */
86  TVM_DLL LoopRV();
  // NOTE(review): listing line 87 is not shown. The page's cross-reference entries list
  // `TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode)`
  // for this class.
88 };
89 
90 /**************** Random variable: ExprRV ****************/
91 
/*! \brief An expression random variable; it is simply an alias of PrimExpr. */
93 using ExprRV = PrimExpr;
94 
96 
97 /**************** The Schedule class ****************/
98 
// Forward declaration: ScheduleNode names Schedule (as a friend and as the return type of
// Copy()) before the Schedule reference class is defined.
99 class Schedule;
100 
/*!
 * \brief The user-facing schedule class (ScheduleNode).
 *
 * NOTE(review): listing lines 101-102 are not shown here. They include the class head, which
 * the page's cross-reference entries place at schedule.h:102. It is presumably
 * `class ScheduleNode : public runtime::Object {`, given the page's
 * `TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object)` entry. Confirm against the
 * real header.
 */
103  friend class Schedule;
104 
105  public:
  /*! \brief Virtual destructor, so schedules can be destroyed through a ScheduleNode pointer. */
106  virtual ~ScheduleNode() = default;
107 
  /*! \brief Type key of this node type. */
108  static constexpr const char* _type_key = "tir.Schedule";
  // NOTE(review): listing line 109 is not shown. The page's cross-reference entries list
  // TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object) for this class.
110 
111  public:
  /*! \brief Get the IRModule associated with this schedule; by default it is read from state()->mod. */
113  virtual IRModule mod() const { return state()->mod; }
  /*! \brief The ScheduleState of this schedule (mod() reads the IRModule from it). */
115  virtual ScheduleState state() const = 0;
  /*! \brief The trace of this schedule, if any (the result is an Optional<Trace>). */
117  virtual Optional<Trace> trace() const = 0;
  /*! \brief Instruct the schedule to work on a function in the IRModule. */
134  virtual void WorkOn(const String& func_name) = 0;
  /*! \brief Returns a copy of the schedule, including both its state and its symbol table. */
143  virtual Schedule Copy() = 0;
  // NOTE(review): the page's cross-reference entries also list these ScheduleNode members,
  // whose declarations do not appear in this listing:
  //   Seed(support::LinearCongruentialEngine::TRandState seed)  -- "Seed the randomness."
  //   ForkSeed()                                                -- "Fork the random state."
  //   func_working_on() const  (returns Optional<GlobalVar>)
151 
152  public:
153  /******** Lookup/Remove random variables ********/
  /*! \brief Get the block corresponding to the specific BlockRV. */
159  virtual Block Get(const BlockRV& block_rv) const = 0;
  /*! \brief Get the for loop corresponding to the specific LoopRV. */
165  virtual For Get(const LoopRV& loop_rv) const = 0;
  /*! \brief Get the expr corresponding to the specific random variable. */
171  virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
  /*! \brief Get the block sref corresponding to the specific BlockRV. */
177  virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0;
  /*! \brief Get the loop sref corresponding to the specific LoopRV. */
183  virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
  /*! \brief Check the existence of a specific BlockRV. */
189  virtual bool HasBlock(const BlockRV& block_rv) const = 0;
  /*! \brief Get the block/loop sref corresponding to the specific statement (declared here, not pure). */
195  virtual StmtSRef GetSRef(const StmtNode* stmt) const;
  /*! \brief Get the block/loop sref corresponding to the specific statement; forwards to GetSRef(stmt.get()). */
201  StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); }
  /*! \brief Remove a block random variable from the symbol table. */
206  virtual void RemoveRV(const BlockRV& block_rv) = 0;
  /*! \brief Remove a loop random variable from the symbol table. */
211  virtual void RemoveRV(const LoopRV& loop_rv) = 0;
  /*! \brief Remove an integer random variable from the symbol table. */
216  virtual void RemoveRV(const ExprRV& expr_rv) = 0;
217 
218  public:
219  /******** Schedule: Sampling ********/
  /*! \brief Sample an integer given the probability distribution. */
227  virtual ExprRV SampleCategorical(const Array<runtime::Int>& candidates,
228  const Array<runtime::Float>& probs,
229  Optional<runtime::Int> decision = NullOpt) = 0;
  /*! \brief Sample the factors to perfect tile a specific loop. */
238  virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
239  Optional<Array<Integer>> decision = NullOpt) = 0;
  /*! \brief Sample the factors to a partitioned tile for a specific loop. */
255  virtual Array<ExprRV> SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos,
256  int innerpart_factor,
257  Optional<Array<Integer>> decision = NullOpt) = 0;
  /*! \brief Sample a compute-at location of the given block. */
264  virtual LoopRV SampleComputeLocation(const BlockRV& block_rv,
265  Optional<Integer> decision = NullOpt) = 0;
266 
267  /******** Schedule: Get blocks & loops ********/
  /*! \brief Retrieve a block in a specific function with its name. */
282  virtual BlockRV GetBlock(const String& name, const Optional<String>& func_name = NullOpt) = 0;
  /*! \brief Get the parent loops of the block in its scope, from outer to inner. */
288  virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
  /*! \brief Get the leaf blocks of a specific scope. */
294  virtual Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) = 0;
  /*! \brief Get the leaf blocks under a specific loop. */
300  virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
  /*! \brief Get the producers of a specific block, under the same block scope. */
307  virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
  /*! \brief Get the consumers of a specific block, under the same block scope. */
314  virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
  /*! \brief Get the list of output blocks within the given scope. */
323  virtual Array<BlockRV> GetOutputBlocks(const BlockRV& scope_block_rv) = 0;
324  /******** Schedule: Transform loops ********/
  /*! \brief Merge a list of loops into one. */
334  virtual LoopRV Merge(const Array<LoopRV>& loop_rvs) = 0;
  /*! \brief Fuse a list of consecutive loops into one. */
345  virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters = true) = 0;
  /*! \brief Split a loop into a list of consecutive loops. */
360  virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
361  bool preserve_unit_iters = true,
362  bool disable_predication = false) = 0;
  /*! \brief Partition the loop into a sequence of multiple loops. */
372  virtual Array<LoopRV> LoopPartition(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
373  bool preserve_unit_iters = true) = 0;
  /*! \brief Reorder a list of loops; the loops are not required to be consecutive. */
386  virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0;
  /*! \brief Reorder the itervars inside a block. */
392  virtual void ReorderBlockIterVar(const BlockRV& block_rv, const Array<Integer> new_order) = 0;
  /*! \brief Create a new unit loop on top of the specific block. */
398  virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0;
  /*! \brief Create a new unit loop on top of the specific loop. */
404  virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0;
405  /******** Schedule: Manipulate ForKind ********/
  /*! \brief Parallelize the input loop. */
415  virtual void Parallel(const LoopRV& loop_rv) = 0;
  /*! \brief Vectorize the input loop. */
425  virtual void Vectorize(const LoopRV& loop_rv) = 0;
  /*! \brief Bind the input loop to the given thread axis. */
437  virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0;
  /*! \brief Unroll the input loop. */
442  virtual void Unroll(const LoopRV& loop_rv) = 0;
443  /******** Schedule: Insert cache stages ********/
  /*! \brief Create a block that reads a buffer region into a read cache. */
454  virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
455  const String& storage_scope,
456  const Array<BlockRV> consumer_blocks = {}) = 0;
  /*! \brief Create a block that writes a buffer region into a write cache. */
467  virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
468  const String& storage_scope,
469  const Array<BlockRV> consumer_blocks = {}) = 0;
  /*! \brief Create a block that reads a buffer region into a read cache (takes an IndexMap). */
482  virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
483  const String& storage_scope, const IndexMap& index_map) = 0;
  /*! \brief Create a block that writes a buffer region into a write cache (takes an IndexMap). */
496  virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
497  const String& storage_scope, const IndexMap& index_map) = 0;
  /*! \brief Create two blocks that read and write a buffer region into a read/write cache. */
506  virtual Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
507  const String& storage_scope) = 0;
  /*! \brief Create a block to cache precomputed indices for later use. */
516  virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
517  int cse_thresh) = 0;
  /*! \brief Create a block that reads/writes a buffer region into a read/write cache with reindexing. */
529  virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
530  BufferIndexType buffer_index_type) = 0;
531  /******** Schedule: Data movement ********/
  /*! \brief Data-movement primitive over a read buffer of a block at a loop (no description in this listing). */
532  virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
533  const String& storage_scope) = 0;
  /*! \brief Data-movement primitive over a write buffer of a block at a loop (no description in this listing). */
534  virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index,
535  const String& storage_scope) = 0;
536  /******** Schedule: Compute location ********/
  /*! \brief Move a producer block under the specific loop, and regenerate the loops induced by the block. */
557  virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
558  int index = -1) = 0;
  /*! \brief Move a consumer block under the specific loop, and regenerate the loops induced by the block. */
578  virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
579  bool preserve_unit_loops, int index = -1) = 0;
  /*! \brief Inline a block into its consumer(s). */
590  virtual void ComputeInline(const BlockRV& block) = 0;
  /*! \brief Inline a block into its only producer. */
602  virtual void ReverseComputeInline(const BlockRV& block) = 0;
603  /******** Schedule: Reduction ********/
  /*! \brief Decompose a reduction block into two separate blocks (an init block and an update block). */
619  virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
  /*! \brief Factorize an associative reduction block by the specified loop. */
637  virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
638  /******** Schedule: Block annotation ********/
  /*! \brief Set an alignment requirement for a specific dimension such that stride[axis] == k * factor + offset. */
651  virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
652  int offset) = 0;
  /*! \brief Set the storage scope of a buffer, where the buffer is specified by a block and a write-index. */
660  virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
  /*! \brief Set the data type of a buffer, where the buffer is specified by a block and a write-index. */
670  virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0;
671  /******** Schedule: Blockize & Tensorize ********/
  /*! \brief Convert the subtree rooted at a specific loop into a block. */
678  virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0;
  /*! \brief Convert specified blocks into a nested block. */
685  virtual BlockRV Blockize(const Array<BlockRV>& blocks, bool preserve_unit_iters = true) = 0;
  /*! \brief Tensorize the computation enclosed by the loop with the tensor intrin. */
692  virtual void Tensorize(const LoopRV& loop_rv, const String& intrin,
693  bool preserve_unit_iters = true) = 0;
  /*! \brief Tensorize the computation of the given block with the tensor intrin. */
700  virtual void Tensorize(const BlockRV& block_rv, const String& intrin,
701  bool preserve_unit_iters = true) = 0;
702 
703  /******** Schedule: Annotation ********/
  /*! \brief Annotate a loop with a key value pair. */
710  virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0;
  /*! \brief Annotate a block with a key value pair. */
717  virtual void Annotate(const BlockRV& block_rv, const String& ann_key,
718  const ObjectRef& ann_val) = 0;
  /*! \brief Unannotate a loop's annotation with key ann_key. */
724  virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0;
  /*! \brief Unannotate a block's annotation with key ann_key. */
730  virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;
731 
732  /******** Schedule: Layout transformation ********/
  /*! \brief Apply a transformation represented by IndexMap to a buffer. */
764  virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
765  BufferIndexType buffer_index_type, const IndexMap& index_map,
766  const Optional<IndexMap>& pad_value = NullOpt,
767  bool assume_injective_transform = false) = 0;
768 
  /*! \brief Apply a transformation represented by IndexMap to a block. */
777  virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0;
778 
  /*! \brief Set the axis separator of a buffer, specified by a block and a read or write index. */
787  virtual void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
788  BufferIndexType buffer_index_type,
789  const Array<IntImm>& axis_separators) = 0;
790 
791  /******** Schedule: Padding ********/
  /*! \brief Decompose a padding block into a block filling const pad values and a block writing in-bound values. */
799  virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
800 
  /*! \brief Pad the computation of Einsum. */
818  virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0;
819 
820  /******** Schedule: Buffer transformation ********/
  /*! \brief Compute the target buffer via rolling buffering. */
835  virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;
836 
  /*! \brief Annotate the buffer access of a block. */
844  virtual void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
845  BufferIndexType buffer_index_type,
846  const IndexMap& index_map) = 0;
847 
848  /******** Schedule: Misc ********/
  /*! \brief A no-op that marks the start of the postprocessing phase of scheduling. */
850  virtual void EnterPostproc() = 0;
851 
  /*! \brief Hide some buffer access in the given block. */
858  virtual void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type,
859  const Array<IntImm>& buf_index_array) = 0;
860 };
861 
/*! \brief Managed reference to ScheduleNode. */
877 class Schedule : public runtime::ObjectRef {
878  public:
  // NOTE(review): the first line(s) of the next declaration are omitted from this listing.
  // It is presumably (per the page's cross-reference entries):
  //   static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
  //                            int debug_mask, ScheduleErrorRenderLevel error_render_level,
  //                            bool enable_check = true);
  //   -- "Construct a concrete TensorIR schedule from an IRModule."
893  int debug_mask, ScheduleErrorRenderLevel error_render_level,
894  bool enable_check = true);
  // NOTE(review): likewise, the head of this declaration is omitted. It is presumably
  //   static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
  //                          int debug_mask, ScheduleErrorRenderLevel error_render_level,
  //                          bool enable_check = true);
  //   -- "Construct a traced concrete TensorIR schedule from an IRModule."
910  int debug_mask, ScheduleErrorRenderLevel error_render_level,
911  bool enable_check = true);
  // NOTE(review): listing line 912 is not shown. The page's cross-reference entries list
  // TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode).
913 };
914 
915 } // namespace tir
916 } // namespace tvm
917 
918 #endif // TVM_TIR_SCHEDULE_SCHEDULE_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference class to IRModuleNode.
Definition: module.h:366
Base node of all primitive expressions.
Definition: expr.h:86
Reference to PrimExprNode.
Definition: expr.h:115
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
const Object * get() const
Definition: object.h:554
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
int64_t TRandState
Definition: random_engine.h:46
A random variable that evaluates to a TensorIR block.
Definition: schedule.h:51
static constexpr const char * _type_key
Definition: schedule.h:54
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule.h:53
Managed reference to BlockRVNode.
Definition: schedule.h:62
BlockRV()
Constructor.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode)
Managed reference to BlockNode.
Definition: stmt.h:1325
Managed reference to ForNode.
Definition: stmt.h:1029
Definition: index_map.h:176
A random variable that evaluates to a TensorIR for loop.
Definition: schedule.h:72
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule.h:74
static constexpr const char * _type_key
Definition: schedule.h:75
TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object)
Managed reference to LoopRVNode.
Definition: schedule.h:83
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode)
LoopRV()
Constructor.
The user-facing schedule class.
Definition: schedule.h:102
virtual void RemoveRV(const ExprRV &expr_rv)=0
Remove an integer random variable from the symbol table.
virtual void SetAxisSeparator(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const Array< IntImm > &axis_separators)=0
Set the axis separator of a buffer, where the buffer is specified by a block and a read or write inde...
virtual void Reorder(const Array< LoopRV > &ordered_loop_rvs)=0
Reorder a list of loops. It doesn't require the loops to be consecutive. It requires: 1) The loops ar...
virtual Array< BlockRV > GetChildBlocks(const LoopRV &loop_rv)=0
Get the leaf blocks under a specific loop.
virtual StmtSRef GetSRef(const LoopRV &loop_rv) const =0
Get the loop sref corresponding to the specific LoopRV.
virtual BlockRV DecomposeReduction(const BlockRV &block_rv, const LoopRV &loop_rv)=0
Decompose a reduction block into two separate blocks. a) The init block, which is translated from the...
virtual LoopRV AddUnitLoop(const BlockRV &block_rv)=0
Create a new unit loop on top of the specific block.
virtual LoopRV Merge(const Array< LoopRV > &loop_rvs)=0
Merge a list of loops into one. The loops under their LCA require: 1) Under the same scope 2) Can't ...
virtual void PadEinsum(const BlockRV &block_rv, const Array< Integer > &padding)=0
Pad the computation of Einsum.
virtual PrimExpr Get(const ExprRV &expr_rv) const =0
Get the expr corresponding to the specific random variable.
virtual Array< BlockRV > CacheIndex(const BlockRV &block_rv, const String &storage_scope, int cse_thresh)=0
Create a block to cache precomputed indices for later use. If there is no index computation,...
virtual Array< LoopRV > GetLoops(const BlockRV &block_rv)=0
Get the parent loops of the block in its scope, from outer to inner.
virtual void EnterPostproc()=0
A no-op that marks the start of postprocessing phase of scheduling.
virtual BlockRV ReindexCacheRead(const BlockRV &block_rv, int read_buffer_index, const String &storage_scope, const IndexMap &index_map)=0
Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one blo...
StmtSRef GetSRef(const Stmt &stmt) const
Get the block/loop sref corresponding to the specific statement.
Definition: schedule.h:201
virtual Array< ExprRV > SamplePerfectTile(const LoopRV &loop_rv, int n, int max_innermost_factor, Optional< Array< Integer >> decision=NullOpt)=0
Sample the factors to perfect tile a specific loop.
virtual BlockRV ReadAt(const LoopRV &loop_rv, const BlockRV &block_rv, int read_buffer_index, const String &storage_scope)=0
virtual LoopRV AddUnitLoop(const LoopRV &loop_rv)=0
Create a new unit loop on top of the specific loop.
virtual StmtSRef GetSRef(const StmtNode *stmt) const
Get the block/loop sref corresponding to the specific statement.
virtual Array< LoopRV > LoopPartition(const LoopRV &loop_rv, const Array< Optional< ExprRV >> &factors, bool preserve_unit_iters=true)=0
Partition the loop into a sequence of multiple loops. 1) The loop can't have annotation or thread bindi...
virtual StmtSRef GetSRef(const BlockRV &block_rv) const =0
Get the block sref corresponding to the specific BlockRV.
virtual void ReorderBlockIterVar(const BlockRV &block_rv, const Array< Integer > new_order)=0
Reorder the itervars inside a block.
virtual void Seed(support::LinearCongruentialEngine::TRandState seed)=0
Seed the randomness.
virtual Optional< GlobalVar > func_working_on() const =0
virtual Array< BlockRV > GetProducers(const BlockRV &block_rv)=0
Get the producers of a specific block, under the same block scope.
virtual void Unannotate(const LoopRV &loop_rv, const String &ann_key)=0
Unannotate a loop's annotation with key ann_key.
virtual void Parallel(const LoopRV &loop_rv)=0
Parallelize the input loop. It requires: 1) The scope block that the loop is in should have stage-pip...
virtual BlockRV CacheRead(const BlockRV &block_rv, int read_buffer_index, const String &storage_scope, const Array< BlockRV > consumer_blocks={})=0
Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one blo...
virtual LoopRV Fuse(const Array< LoopRV > &loop_rvs, bool preserve_unit_iters=true)=0
Fuse a list of consecutive loops into one. It requires: 1) The loops can't have annotations or thread...
virtual void ComputeInline(const BlockRV &block)=0
Inline a block into its consumer(s). It requires: 1) The block is a complete non-root block,...
virtual void Tensorize(const LoopRV &loop_rv, const String &intrin, bool preserve_unit_iters=true)=0
Tensorize the computation enclosed by loop with the tensor intrin.
virtual Array< ExprRV > SamplePartitionedTile(const LoopRV &loop_rv, int n, int partition_pos, int innerpart_factor, Optional< Array< Integer >> decision=NullOpt)=0
Sample the factors to a partitioned tile for a specific loop.
virtual IRModule mod() const
Get the IRModule associated with this schedule.
Definition: schedule.h:113
virtual void RemoveRV(const BlockRV &block_rv)=0
Remove a block random variable from the symbol table.
virtual Schedule Copy()=0
Returns a copy of the schedule, including both its state and its symbol table, guaranteeing that 1) S...
virtual void Unannotate(const BlockRV &block_rv, const String &ann_key)=0
Unannotate a block's annotation with key ann_key.
virtual void RemoveRV(const LoopRV &loop_rv)=0
Remove a loop random variable from the symbol table.
static constexpr const char * _type_key
Definition: schedule.h:108
virtual void Bind(const LoopRV &loop_rv, const String &thread_axis)=0
Bind the input loop to the given thread axis. It requires: 1) The scope block that the loop is in sho...
virtual void Annotate(const BlockRV &block_rv, const String &ann_key, const ObjectRef &ann_val)=0
Annotate a block with a key value pair.
virtual void Unroll(const LoopRV &loop_rv)=0
Unroll the input loop. It requires nothing.
virtual void AnnotateBufferAccess(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap &index_map)=0
Annotate the buffer access of a block.
virtual void StorageAlign(const BlockRV &block_rv, int buffer_index, int axis, int factor, int offset)=0
Set alignment requirement for specific dimension such that stride[axis] == k * factor + offset for so...
virtual Optional< Trace > trace() const =0
virtual support::LinearCongruentialEngine::TRandState ForkSeed()=0
Fork the random state.
virtual Array< BlockRV > GetOutputBlocks(const BlockRV &scope_block_rv)=0
Get the list of output blocks within the given scope. An output block is a block which has at least one...
virtual void TransformBlockLayout(const BlockRV &block_rv, const IndexMap &index_map)=0
Apply a transformation represented by IndexMap to block.
virtual void ReverseComputeInline(const BlockRV &block)=0
Inline a block into its only producer. It requires: 1) The block is a complete non-root block,...
virtual BlockRV ReIndex(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type)=0
Create a block that reads/writes a buffer region into a read/write cache with reindexing....
virtual BlockRV CacheWrite(const BlockRV &block_rv, int write_buffer_index, const String &storage_scope, const Array< BlockRV > consumer_blocks={})=0
Create a block that writes a buffer region into a write cache. It requires: 1) There is only one bloc...
virtual void SetScope(const BlockRV &block_rv, int buffer_index, const String &storage_scope)=0
Set the storage scope of a buffer, where the buffer is specified by a block and a write-index.
virtual void Annotate(const LoopRV &loop_rv, const String &ann_key, const ObjectRef &ann_val)=0
Annotate a loop with a key value pair.
virtual BlockRV RFactor(const LoopRV &loop_rv, int factor_axis)=0
Factorize an associative reduction block by the specified loop.
virtual void RollingBuffer(const BlockRV &block_rv, int write_buffer_index)=0
Compute the target buffer via rolling buffering.
virtual Array< LoopRV > Split(const LoopRV &loop_rv, const Array< Optional< ExprRV >> &factors, bool preserve_unit_iters=true, bool disable_predication=false)=0
Split a loop into a list of consecutive loops. It requires: 1) The loop can't have annotation or thre...
virtual Array< BlockRV > GetChildBlocks(const BlockRV &block_rv)=0
Get the leaf blocks of a specific scope.
virtual void Vectorize(const LoopRV &loop_rv)=0
Vectorize the input loop. It requires: 1) The scope block that the loop is in should have stage-pipel...
virtual ScheduleState state() const =0
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object)
virtual Array< BlockRV > CacheInplace(const BlockRV &block_rv, int read_buffer_index, const String &storage_scope)=0
Create two blocks that read and write a buffer region into a read/write cache. It requires the target block...
virtual LoopRV SampleComputeLocation(const BlockRV &block_rv, Optional< Integer > decision=NullOpt)=0
Sample a compute-at location of the given block.
virtual Array< BlockRV > GetConsumers(const BlockRV &block_rv)=0
Get the consumers of a specific block, under the same block scope.
virtual BlockRV GetBlock(const String &name, const Optional< String > &func_name=NullOpt)=0
Retrieve a block in a specific function with its name.
virtual void UnsafeSetDType(const BlockRV &block_rv, int buffer_index, const String &dtype)=0
Set the data type of a buffer, where the buffer is specified by a block and a write-index.
virtual bool HasBlock(const BlockRV &block_rv) const =0
Check the existence of a specific BlockRV.
virtual void WorkOn(const String &func_name)=0
Instruct the schedule to work on a function in the IRModule.
virtual ExprRV SampleCategorical(const Array< runtime::Int > &candidates, const Array< runtime::Float > &probs, Optional< runtime::Int > decision=NullOpt)=0
Sample an integer given the probability distribution.
virtual BlockRV Blockize(const LoopRV &loop_rv, bool preserve_unit_iters=true)=0
Convert the subtree rooted at a specific loop into a block.
virtual BlockRV WriteAt(const LoopRV &loop_rv, const BlockRV &block_rv, int write_buffer_index, const String &storage_scope)=0
virtual void ReverseComputeAt(const BlockRV &block_rv, const LoopRV &loop_rv, bool preserve_unit_loops, int index=-1)=0
Move a consumer block under the specific loop, and regenerate the loops induced by the block so that ...
virtual BlockRV Blockize(const Array< BlockRV > &blocks, bool preserve_unit_iters=true)=0
Convert specified blocks into a nested block.
virtual For Get(const LoopRV &loop_rv) const =0
Get the for loop corresponding to the specific LoopRV.
virtual ~ScheduleNode()=default
virtual void UnsafeHideBufferAccess(const BlockRV &block_rv, const String &buf_type, const Array< IntImm > &buf_index_array)=0
Hide some buffer access in the given block.
virtual Block Get(const BlockRV &block_rv) const =0
Get the block corresponding to the specific BlockRV.
virtual BlockRV ReindexCacheWrite(const BlockRV &block_rv, int write_buffer_index, const String &storage_scope, const IndexMap &index_map)=0
Create a block that writes a buffer region into a write cache. It requires: 1) There is only one bloc...
virtual void TransformLayout(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap &index_map, const Optional< IndexMap > &pad_value=NullOpt, bool assume_injective_transform=false)=0
Apply a transformation represented by IndexMap to buffer.
virtual BlockRV DecomposePadding(const BlockRV &block_rv, const LoopRV &loop_rv)=0
Decompose a padding block into a block filling const pad values and a block writing in-bound values.
virtual void ComputeAt(const BlockRV &block_rv, const LoopRV &loop_rv, bool preserve_unit_loops, int index=-1)=0
Move a producer block under the specific loop, and regenerate the loops induced by the block so that ...
virtual void Tensorize(const BlockRV &block_rv, const String &intrin, bool preserve_unit_iters=true)=0
Tensorize the computation enclosed by the block with the tensor intrin.
Managed reference to ScheduleStateNode.
Definition: state.h:208
Managed reference to ScheduleNode.
Definition: schedule.h:877
static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check=true)
Construct a concrete TensorIR schedule from an IRModule.
static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check=true)
Construct a traced concrete TensorIR schedule from an IRModule.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode)
Base node of all statements.
Definition: stmt.h:38
Managed reference to StmtSRefNode.
Definition: block_scope.h:107
Container of all statements.
Definition: stmt.h:59
Defines a remapping of buffer indices.
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1458
BufferIndexType
Type of buffer index.
Definition: schedule.h:41
@ kRead
Index of a read buffer.
@ kWrite
Index of a written buffer.
ScheduleErrorRenderLevel
The level of detailed error message rendering.
Definition: schedule.h:31
@ kNone
No error message at all.
@ kDetail
Render a detailed error message.
@ kFast
Render the error in fast mode.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_gene...
This file defines ScheduleState, the core data structure of TensorIR scheduling.