tvm
schedule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_S_TIR_SCHEDULE_SCHEDULE_H_
20 #define TVM_S_TIR_SCHEDULE_SCHEDULE_H_
21 
25 #include <tvm/tir/index_map.h>
26 
27 namespace tvm {
28 namespace s_tir {
29 using namespace tvm::tir;
30 
/*! \brief The level of detailed error message rendering. */
32 enum class ScheduleErrorRenderLevel : int32_t {
  /*! \brief Render a detailed error message. */
34  kDetail = 0,
  /*! \brief Render the error in fast mode. */
36  kFast = 1,
  /*! \brief No error message at all. */
38  kNone = 2,
39 };
40 
/*! \brief Type of buffer index: selects whether an index refers to a block's read or write buffers. */
42 enum class BufferIndexType : int32_t {
  /*! \brief Index of a read buffer. */
44  kRead = 0,
  /*! \brief Index of a written buffer. */
46  kWrite = 1,
47 };
48 
49 /**************** Random variable: SBlockRV ****************/
50 
/*! \brief A random variable that evaluates to a TensorIR block. */
52 class SBlockRVNode : public runtime::Object {
53  public:
  // Declares this node type to tvm::ffi reflection; no fields are listed.
54  static void RegisterReflection() {
55  namespace refl = tvm::ffi::reflection;
56  refl::ObjectDef<SBlockRVNode>();
57  }
  // Registers the FFI type key "s_tir.SBlockRV"; the node is final.
58  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBlockRV", SBlockRVNode, runtime::Object);
59 };
60 
/*! \brief Managed reference to SBlockRVNode. */
65 class SBlockRV : public runtime::ObjectRef {
66  public:
  /*! \brief Constructor. */
68  TVM_DLL SBlockRV();
69  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockRV, runtime::ObjectRef, SBlockRVNode);
70 };
71 
72 /**************** Random variable: LoopRV ****************/
73 
/*! \brief A random variable that evaluates to a TensorIR for loop. */
75 class LoopRVNode : public runtime::Object {
76  public:
  // Declares this node type to tvm::ffi reflection; no fields are listed.
77  static void RegisterReflection() {
78  namespace refl = tvm::ffi::reflection;
79  refl::ObjectDef<LoopRVNode>();
80  }
  // Registers the FFI type key "s_tir.LoopRV"; the node is final.
81  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.LoopRV", LoopRVNode, runtime::Object);
82 };
83 
/*! \brief Managed reference to LoopRVNode. */
88 class LoopRV : public runtime::ObjectRef {
89  public:
  /*! \brief Constructor. */
91  TVM_DLL LoopRV();
92  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LoopRV, runtime::ObjectRef, LoopRVNode);
93 };
94 
95 /**************** Random variable: ExprRV ****************/
96 
/*! \brief An expression random variable; a plain alias of PrimExpr (no dedicated node type). */
98 using ExprRV = PrimExpr;
99 
101 
102 /**************** The Schedule class ****************/
103 
// Forward declaration: ScheduleNode befriends Schedule and returns it from Copy().
104 class Schedule;
105 
107 class ScheduleNode : public runtime::Object {
108  friend class Schedule;
109 
110  public:
111  virtual ~ScheduleNode() = default;
112 
113  static constexpr const bool _type_mutable = true;
114  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Schedule", ScheduleNode, runtime::Object);
115 
116  public:
118  virtual IRModule mod() const { return state()->mod; }
120  virtual ScheduleState state() const = 0;
122  virtual ffi::Optional<Trace> trace() const = 0;
124  virtual ffi::Optional<GlobalVar> func_working_on() const = 0;
139  virtual void WorkOn(const ffi::String& func_name) = 0;
148  virtual Schedule Copy() = 0;
156 
157  public:
158  /******** Lookup/Remove random variables ********/
164  virtual SBlock Get(const SBlockRV& block_rv) const = 0;
170  virtual For Get(const LoopRV& loop_rv) const = 0;
176  virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
182  virtual StmtSRef GetSRef(const SBlockRV& block_rv) const = 0;
188  virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
194  virtual bool HasBlock(const SBlockRV& block_rv) const = 0;
200  virtual StmtSRef GetSRef(const StmtNode* stmt) const;
206  StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); }
211  virtual void RemoveRV(const SBlockRV& block_rv) = 0;
216  virtual void RemoveRV(const LoopRV& loop_rv) = 0;
221  virtual void RemoveRV(const ExprRV& expr_rv) = 0;
222 
223  public:
224  /******** Schedule: Sampling ********/
232  virtual ExprRV SampleCategorical(const ffi::Array<Integer>& candidates,
233  const ffi::Array<FloatImm>& probs,
234  ffi::Optional<Integer> decision = std::nullopt) = 0;
243  virtual ffi::Array<ExprRV> SamplePerfectTile(
244  const LoopRV& loop_rv, int n, int max_innermost_factor,
245  ffi::Optional<ffi::Array<Integer>> decision = std::nullopt) = 0;
261  virtual ffi::Array<ExprRV> SamplePartitionedTile(
262  const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor,
263  ffi::Optional<ffi::Array<Integer>> decision = std::nullopt) = 0;
270  virtual LoopRV SampleComputeLocation(const SBlockRV& block_rv,
271  ffi::Optional<Integer> decision = std::nullopt) = 0;
272 
273  /******** Schedule: Get blocks & loops ********/
288  virtual SBlockRV GetSBlock(const ffi::String& name,
289  const ffi::Optional<ffi::String>& func_name = std::nullopt) = 0;
295  virtual ffi::Array<LoopRV> GetLoops(const SBlockRV& block_rv) = 0;
301  virtual ffi::Array<SBlockRV> GetChildBlocks(const SBlockRV& block_rv) = 0;
307  virtual ffi::Array<SBlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
314  virtual ffi::Array<SBlockRV> GetProducers(const SBlockRV& block_rv) = 0;
321  virtual ffi::Array<SBlockRV> GetConsumers(const SBlockRV& block_rv) = 0;
330  virtual ffi::Array<SBlockRV> GetOutputBlocks(const SBlockRV& scope_block_rv) = 0;
331  /******** Schedule: Transform loops ********/
341  virtual LoopRV Merge(const ffi::Array<LoopRV>& loop_rvs) = 0;
352  virtual LoopRV Fuse(const ffi::Array<LoopRV>& loop_rvs, bool preserve_unit_iters = true) = 0;
365  virtual ffi::Array<LoopRV> Split(const LoopRV& loop_rv,
366  const ffi::Array<ffi::Optional<ExprRV>>& factors,
367  bool preserve_unit_iters = true,
368  bool disable_predication = false) = 0;
378  virtual ffi::Array<LoopRV> LoopPartition(const LoopRV& loop_rv,
379  const ffi::Array<ffi::Optional<ExprRV>>& factors,
380  bool preserve_unit_iters = true) = 0;
393  virtual void Reorder(const ffi::Array<LoopRV>& ordered_loop_rvs) = 0;
399  virtual void ReorderBlockIterVar(const SBlockRV& block_rv,
400  const ffi::Array<Integer> new_order) = 0;
406  virtual LoopRV AddUnitLoop(const SBlockRV& block_rv) = 0;
412  virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0;
413  /******** Schedule: Manipulate ForKind ********/
423  virtual void Parallel(const LoopRV& loop_rv) = 0;
433  virtual void Vectorize(const LoopRV& loop_rv) = 0;
445  virtual void Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) = 0;
450  virtual void Unroll(const LoopRV& loop_rv) = 0;
451  /******** Schedule: Insert cache stages ********/
462  virtual SBlockRV CacheRead(const SBlockRV& block_rv, int read_buffer_index,
463  const ffi::String& storage_scope,
464  const ffi::Array<SBlockRV> consumer_blocks = {}) = 0;
475  virtual SBlockRV CacheWrite(const SBlockRV& block_rv, int write_buffer_index,
476  const ffi::String& storage_scope,
477  const ffi::Array<SBlockRV> consumer_blocks = {}) = 0;
490  virtual SBlockRV ReindexCacheRead(const SBlockRV& block_rv, int read_buffer_index,
491  const ffi::String& storage_scope,
492  const IndexMap& index_map) = 0;
505  virtual SBlockRV ReindexCacheWrite(const SBlockRV& block_rv, int write_buffer_index,
506  const ffi::String& storage_scope,
507  const IndexMap& index_map) = 0;
516  virtual ffi::Array<SBlockRV> CacheInplace(const SBlockRV& block_rv, int read_buffer_index,
517  const ffi::String& storage_scope) = 0;
526  virtual ffi::Array<SBlockRV> CacheIndex(const SBlockRV& block_rv,
527  const ffi::String& storage_scope, int cse_thresh) = 0;
539  virtual SBlockRV ReIndex(const SBlockRV& block_rv, int buffer_index,
540  BufferIndexType buffer_index_type) = 0;
541  /******** Schedule: Data movement ********/
542  virtual SBlockRV ReadAt(const LoopRV& loop_rv, const SBlockRV& block_rv, int read_buffer_index,
543  const ffi::String& storage_scope) = 0;
544  virtual SBlockRV WriteAt(const LoopRV& loop_rv, const SBlockRV& block_rv, int write_buffer_index,
545  const ffi::String& storage_scope) = 0;
546  /******** Schedule: Compute location ********/
567  virtual void ComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
568  int index = -1) = 0;
588  virtual void ReverseComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv,
589  bool preserve_unit_loops, int index = -1) = 0;
600  virtual void ComputeInline(const SBlockRV& block) = 0;
612  virtual void ReverseComputeInline(const SBlockRV& block) = 0;
618  virtual void FuseReductionEpilogue(const SBlockRV& reduction_block,
619  const SBlockRV& epilogue_block) = 0;
620  /******** Schedule: Reduction ********/
636  virtual SBlockRV DecomposeReduction(const SBlockRV& block_rv, const LoopRV& loop_rv) = 0;
654  virtual SBlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
655  /******** Schedule: SBlock annotation ********/
668  virtual void StorageAlign(const SBlockRV& block_rv, int buffer_index, int axis, int factor,
669  int offset) = 0;
677  virtual void SetScope(const SBlockRV& block_rv, int buffer_index,
678  const ffi::String& storage_scope) = 0;
688  virtual void UnsafeSetDType(const SBlockRV& block_rv, int buffer_index,
689  const ffi::String& dtype) = 0;
690  /******** Schedule: Blockize & Tensorize ********/
697  virtual SBlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0;
704  virtual SBlockRV Blockize(const ffi::Array<SBlockRV>& blocks,
705  bool preserve_unit_iters = true) = 0;
712  virtual void Tensorize(const LoopRV& loop_rv, const ffi::String& intrin,
713  bool preserve_unit_iters = true) = 0;
720  virtual void Tensorize(const SBlockRV& block_rv, const ffi::String& intrin,
721  bool preserve_unit_iters = true) = 0;
722 
723  /******** Schedule: Annotation ********/
730  virtual void Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) = 0;
737  virtual void Annotate(const SBlockRV& block_rv, const ffi::String& ann_key,
738  const Any& ann_val) = 0;
744  virtual void Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) = 0;
750  virtual void Unannotate(const SBlockRV& block_rv, const ffi::String& ann_key) = 0;
751 
752  /******** Schedule: Layout transformation ********/
784  virtual void TransformLayout(const SBlockRV& block_rv, int buffer_index,
785  BufferIndexType buffer_index_type, const IndexMap& index_map,
786  const ffi::Optional<IndexMap>& pad_value = std::nullopt,
787  bool assume_injective_transform = false) = 0;
788 
797  virtual void TransformBlockLayout(const SBlockRV& block_rv, const IndexMap& index_map) = 0;
798 
807  virtual void SetAxisSeparator(const SBlockRV& block_rv, int buffer_index,
808  BufferIndexType buffer_index_type,
809  const ffi::Array<IntImm>& axis_separators) = 0;
810 
811  /******** Schedule: Padding ********/
819  virtual SBlockRV DecomposePadding(const SBlockRV& block_rv, const LoopRV& loop_rv) = 0;
820 
838  virtual void PadEinsum(const SBlockRV& block_rv, const ffi::Array<Integer>& padding) = 0;
839 
840  /******** Schedule: Buffer transformation ********/
855  virtual void RollingBuffer(const SBlockRV& block_rv, int write_buffer_index) = 0;
856 
864  virtual void AnnotateBufferAccess(const SBlockRV& block_rv, int buffer_index,
865  BufferIndexType buffer_index_type,
866  const IndexMap& index_map) = 0;
867 
868  /******** Schedule: Misc ********/
870  virtual void EnterPostproc() = 0;
871 
878  virtual void UnsafeHideBufferAccess(const SBlockRV& block_rv, const ffi::String& buf_type,
879  const ffi::Array<IntImm>& buf_index_array) = 0;
880 };
881 
/*! \brief Managed reference to ScheduleNode. */
897 class Schedule : public runtime::ObjectRef {
898  public:
  /*! \brief Construct a concrete TensorIR schedule from an IRModule. */
912  TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
913  int debug_mask, ScheduleErrorRenderLevel error_render_level,
914  bool enable_check = true);
  /*! \brief Construct a traced concrete TensorIR schedule from an IRModule. */
929  TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
930  int debug_mask, ScheduleErrorRenderLevel error_render_level,
931  bool enable_check = true);
932  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Schedule, runtime::ObjectRef, ScheduleNode);
933 };
934 
935 } // namespace s_tir
936 } // namespace tvm
937 
938 #endif // TVM_S_TIR_SCHEDULE_SCHEDULE_H_
Managed reference class to IRModuleNode.
Definition: module.h:256
Base node of all primitive expressions.
Definition: expr.h:91
Reference to PrimExprNode.
Definition: expr.h:124
A random variable that evaluates to a TensorIR for loop.
Definition: schedule.h:75
static void RegisterReflection()
Definition: schedule.h:77
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.LoopRV", LoopRVNode, runtime::Object)
Managed reference to LoopRVNode.
Definition: schedule.h:88
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LoopRV, runtime::ObjectRef, LoopRVNode)
LoopRV()
Constructor.
A random variable that evaluates to a TensorIR block.
Definition: schedule.h:52
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBlockRV", SBlockRVNode, runtime::Object)
static void RegisterReflection()
Definition: schedule.h:54
Managed reference to SBlockRVNode.
Definition: schedule.h:65
SBlockRV()
Constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockRV, runtime::ObjectRef, SBlockRVNode)
The user-facing schedule class.
Definition: schedule.h:107
virtual ~ScheduleNode()=default
virtual SBlockRV RFactor(const LoopRV &loop_rv, int factor_axis)=0
Factorize an associative reduction block by the specified loop.
virtual void Unroll(const LoopRV &loop_rv)=0
Unroll the input loop. It requires nothing.
virtual void Bind(const LoopRV &loop_rv, const ffi::String &thread_axis)=0
Bind the input loop to the given thread axis. It requires: 1) The scope block that the loop is in sho...
virtual ffi::Array< LoopRV > Split(const LoopRV &loop_rv, const ffi::Array< ffi::Optional< ExprRV >> &factors, bool preserve_unit_iters=true, bool disable_predication=false)=0
Split a loop into a list of consecutive loops. It requires: 1) The loop can't have annotation or thre...
virtual ffi::Array< LoopRV > LoopPartition(const LoopRV &loop_rv, const ffi::Array< ffi::Optional< ExprRV >> &factors, bool preserve_unit_iters=true)=0
Partition a loop into a sequence of multiple loops. It requires: 1) The loop can't have annotation or thread bindi...
virtual ScheduleState state() const =0
virtual void ReverseComputeAt(const SBlockRV &block_rv, const LoopRV &loop_rv, bool preserve_unit_loops, int index=-1)=0
Move a consumer block under the specific loop, and regenerate the loops induced by the block so that ...
virtual ffi::Optional< Trace > trace() const =0
virtual ffi::Array< SBlockRV > GetChildBlocks(const SBlockRV &block_rv)=0
Get the leaf blocks of a specific scope.
virtual void ReorderBlockIterVar(const SBlockRV &block_rv, const ffi::Array< Integer > new_order)=0
Reorder the itervars inside a block.
StmtSRef GetSRef(const Stmt &stmt) const
Get the block/loop sref corresponding to the specific statement.
Definition: schedule.h:206
virtual void StorageAlign(const SBlockRV &block_rv, int buffer_index, int axis, int factor, int offset)=0
Set an alignment requirement for a specific dimension such that stride[axis] == k * factor + offset for so...
virtual void RemoveRV(const SBlockRV &block_rv)=0
Remove a block random variable from the symbol table.
virtual SBlockRV CacheRead(const SBlockRV &block_rv, int read_buffer_index, const ffi::String &storage_scope, const ffi::Array< SBlockRV > consumer_blocks={})=0
Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one blo...
virtual ffi::Optional< GlobalVar > func_working_on() const =0
virtual SBlockRV Blockize(const LoopRV &loop_rv, bool preserve_unit_iters=true)=0
Convert the subtree rooted at a specific loop into a block.
virtual void UnsafeSetDType(const SBlockRV &block_rv, int buffer_index, const ffi::String &dtype)=0
Set the data type of a buffer, where the buffer is specified by a block and a write-index.
virtual void Parallel(const LoopRV &loop_rv)=0
Parallelize the input loop. It requires: 1) The scope block that the loop is in should have stage-pip...
virtual PrimExpr Get(const ExprRV &expr_rv) const =0
Get the expr corresponding to the specific random variable.
virtual ffi::Array< SBlockRV > CacheIndex(const SBlockRV &block_rv, const ffi::String &storage_scope, int cse_thresh)=0
Create a block to cache precomputed indices for later use. If there is no index computation,...
virtual void Vectorize(const LoopRV &loop_rv)=0
Vectorize the input loop. It requires: 1) The scope block that the loop is in should have stage-pipel...
virtual SBlockRV ReIndex(const SBlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type)=0
Create a block that reads/writes a buffer region into a read/write cache with reindexing....
virtual StmtSRef GetSRef(const SBlockRV &block_rv) const =0
Get the block sref corresponding to the specific SBlockRV.
virtual ffi::Array< ExprRV > SamplePartitionedTile(const LoopRV &loop_rv, int n, int partition_pos, int innerpart_factor, ffi::Optional< ffi::Array< Integer >> decision=std::nullopt)=0
Sample the factors to a partitioned tile for a specific loop.
virtual ffi::Array< SBlockRV > GetConsumers(const SBlockRV &block_rv)=0
Get the consumers of a specific block, under the same block scope.
virtual void AnnotateBufferAccess(const SBlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap &index_map)=0
Annotate the buffer access of a block.
virtual void UnsafeHideBufferAccess(const SBlockRV &block_rv, const ffi::String &buf_type, const ffi::Array< IntImm > &buf_index_array)=0
Hide some buffer access in the given block.
virtual ffi::Array< SBlockRV > GetProducers(const SBlockRV &block_rv)=0
Get the producers of a specific block, under the same block scope.
virtual bool HasBlock(const SBlockRV &block_rv) const =0
Check the existence of a specific SBlockRV.
virtual void SetScope(const SBlockRV &block_rv, int buffer_index, const ffi::String &storage_scope)=0
Set the storage scope of a buffer, where the buffer is specified by a block and a write-index.
virtual void RemoveRV(const ExprRV &expr_rv)=0
Remove an integer random variable from the symbol table.
virtual ffi::Array< SBlockRV > GetChildBlocks(const LoopRV &loop_rv)=0
Get the leaf blocks under a specific loop.
virtual LoopRV AddUnitLoop(const SBlockRV &block_rv)=0
Create a new unit loop on top of the specific block.
virtual void ReverseComputeInline(const SBlockRV &block)=0
Inline a block into its only producer. It requires: 1) The block is a complete non-root block,...
virtual void ComputeInline(const SBlockRV &block)=0
Inline a block into its consumer(s). It requires: 1) The block is a complete non-root block,...
virtual void Unannotate(const LoopRV &loop_rv, const ffi::String &ann_key)=0
Unannotate a loop's annotation with key ann_key.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Schedule", ScheduleNode, runtime::Object)
virtual LoopRV Fuse(const ffi::Array< LoopRV > &loop_rvs, bool preserve_unit_iters=true)=0
Fuse a list of consecutive loops into one. It requires: 1) The loops can't have annotations or thread...
virtual void SetAxisSeparator(const SBlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const ffi::Array< IntImm > &axis_separators)=0
Set the axis separator of a buffer, where the buffer is specified by a block and a read or write inde...
virtual SBlockRV Blockize(const ffi::Array< SBlockRV > &blocks, bool preserve_unit_iters=true)=0
Convert specified blocks into a nested block.
virtual void ComputeAt(const SBlockRV &block_rv, const LoopRV &loop_rv, bool preserve_unit_loops, int index=-1)=0
Move a producer block under the specific loop, and regenerate the loops induced by the block so that ...
virtual ffi::Array< SBlockRV > GetOutputBlocks(const SBlockRV &scope_block_rv)=0
Get the list of output blocks within the given scope. An output block is a block which has at least one...
virtual IRModule mod() const
Get the IRModule associated with this schedule.
Definition: schedule.h:118
virtual SBlockRV ReadAt(const LoopRV &loop_rv, const SBlockRV &block_rv, int read_buffer_index, const ffi::String &storage_scope)=0
virtual StmtSRef GetSRef(const StmtNode *stmt) const
Get the block/loop sref corresponding to the specific statement.
virtual void TransformBlockLayout(const SBlockRV &block_rv, const IndexMap &index_map)=0
Apply a transformation represented by IndexMap to block.
virtual ffi::Array< ExprRV > SamplePerfectTile(const LoopRV &loop_rv, int n, int max_innermost_factor, ffi::Optional< ffi::Array< Integer >> decision=std::nullopt)=0
Sample the factors to perfect tile a specific loop.
virtual void EnterPostproc()=0
A no-op that marks the start of the postprocessing phase of scheduling.
virtual ffi::Array< SBlockRV > CacheInplace(const SBlockRV &block_rv, int read_buffer_index, const ffi::String &storage_scope)=0
Create 2 blocks that read&write a buffer region into a read/write cache. It requires the target block...
virtual LoopRV SampleComputeLocation(const SBlockRV &block_rv, ffi::Optional< Integer > decision=std::nullopt)=0
Sample a compute-at location of the given block.
virtual SBlockRV GetSBlock(const ffi::String &name, const ffi::Optional< ffi::String > &func_name=std::nullopt)=0
Retrieve a block in a specific function with its name.
virtual void WorkOn(const ffi::String &func_name)=0
Instruct the schedule to work on a function in the IRModule.
virtual SBlock Get(const SBlockRV &block_rv) const =0
Get the block corresponding to the specific SBlockRV.
virtual Schedule Copy()=0
Returns a copy of the schedule, including both its state and its symbol table, guaranteeing that 1) S...
virtual void FuseReductionEpilogue(const SBlockRV &reduction_block, const SBlockRV &epilogue_block)=0
Fuse an epilogue block into a reduction block.
virtual void Reorder(const ffi::Array< LoopRV > &ordered_loop_rvs)=0
Reorder a list of loops. It doesn't require the loops to be consecutive. It requires: 1) The loops ar...
virtual void RemoveRV(const LoopRV &loop_rv)=0
Remove a loop random variable from the symbol table.
virtual void Tensorize(const LoopRV &loop_rv, const ffi::String &intrin, bool preserve_unit_iters=true)=0
Tensorize the computation enclosed by loop with the tensor intrin.
virtual SBlockRV DecomposePadding(const SBlockRV &block_rv, const LoopRV &loop_rv)=0
Decompose a padding block into a block filling const pad values and a block writing in-bound values.
virtual For Get(const LoopRV &loop_rv) const =0
Get the for loop corresponding to the specific LoopRV.
virtual SBlockRV ReindexCacheWrite(const SBlockRV &block_rv, int write_buffer_index, const ffi::String &storage_scope, const IndexMap &index_map)=0
Create a block that writes a buffer region into a write cache. It requires: 1) There is only one bloc...
virtual void Annotate(const LoopRV &loop_rv, const ffi::String &ann_key, const Any &ann_val)=0
Annotate a loop with a key value pair.
virtual SBlockRV WriteAt(const LoopRV &loop_rv, const SBlockRV &block_rv, int write_buffer_index, const ffi::String &storage_scope)=0
virtual void Unannotate(const SBlockRV &block_rv, const ffi::String &ann_key)=0
Unannotate a block's annotation with key ann_key.
virtual LoopRV AddUnitLoop(const LoopRV &loop_rv)=0
Create a new unit loop on top of the specific loop.
virtual void Seed(support::LinearCongruentialEngine::TRandState seed)=0
Seed the randomness.
virtual ffi::Array< LoopRV > GetLoops(const SBlockRV &block_rv)=0
Get the parent loops of the block in its scope, from outer to inner.
virtual void PadEinsum(const SBlockRV &block_rv, const ffi::Array< Integer > &padding)=0
Pad the computation of Einsum.
virtual SBlockRV CacheWrite(const SBlockRV &block_rv, int write_buffer_index, const ffi::String &storage_scope, const ffi::Array< SBlockRV > consumer_blocks={})=0
Create a block that writes a buffer region into a write cache. It requires: 1) There is only one bloc...
virtual void TransformLayout(const SBlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap &index_map, const ffi::Optional< IndexMap > &pad_value=std::nullopt, bool assume_injective_transform=false)=0
Apply a transformation represented by IndexMap to buffer.
virtual void Annotate(const SBlockRV &block_rv, const ffi::String &ann_key, const Any &ann_val)=0
Annotate a block with a key value pair.
virtual LoopRV Merge(const ffi::Array< LoopRV > &loop_rvs)=0
Merge a list of loops into one. The loops under their LCA require: 1) Under the same scope 2) Can't ...
virtual void RollingBuffer(const SBlockRV &block_rv, int write_buffer_index)=0
Compute the target buffer via rolling buffering.
virtual void Tensorize(const SBlockRV &block_rv, const ffi::String &intrin, bool preserve_unit_iters=true)=0
Tensorize the computation of the given block with the tensor intrin.
virtual SBlockRV ReindexCacheRead(const SBlockRV &block_rv, int read_buffer_index, const ffi::String &storage_scope, const IndexMap &index_map)=0
Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one blo...
virtual support::LinearCongruentialEngine::TRandState ForkSeed()=0
Fork the random state.
virtual SBlockRV DecomposeReduction(const SBlockRV &block_rv, const LoopRV &loop_rv)=0
Decompose a reduction block into two separate blocks. a) The init block, which is translated from the...
virtual StmtSRef GetSRef(const LoopRV &loop_rv) const =0
Get the loop sref corresponding to the specific LoopRV.
virtual ExprRV SampleCategorical(const ffi::Array< Integer > &candidates, const ffi::Array< FloatImm > &probs, ffi::Optional< Integer > decision=std::nullopt)=0
Sample an integer given the probability distribution.
Managed reference to ScheduleStateNode.
Definition: state.h:211
Managed reference to ScheduleNode.
Definition: schedule.h:897
static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check=true)
Construct a traced concrete TensorIR schedule from an IRModule.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Schedule, runtime::ObjectRef, ScheduleNode)
static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check=true)
Construct a concrete TensorIR schedule from an IRModule.
int64_t TRandState
Definition: random_engine.h:46
Managed reference to ForNode.
Definition: stmt.h:779
Definition: index_map.h:169
Managed reference to SBlockNode.
Definition: stmt.h:985
Base node of all statements.
Definition: stmt.h:38
Managed reference to StmtSRefNode.
Definition: block_scope.h:106
Container of all statements.
Definition: stmt.h:63
Defines a remapping of buffer indices.
Definition: repr_printer.h:91
ScheduleErrorRenderLevel
The level of detailed error message rendering.
Definition: schedule.h:32
@ kNone
No error message at all.
@ kDetail
Render a detailed error message.
@ kFast
Render the error in fast mode.
BufferIndexType
Type of buffer index.
Definition: schedule.h:42
@ kRead
Index of a read buffer.
@ kWrite
Index of a written buffer.
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1103
Definition: extracted_task.h:30
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_gene...
This file defines ScheduleState, the core data structure of TensorIR scheduling.