19 #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
20 #define TVM_TIR_SCHEDULE_SCHEDULE_H_
63 class BlockRV :
public runtime::ObjectRef {
85 class LoopRV :
public runtime::ObjectRef {
119 virtual ffi::Optional<Trace>
trace()
const = 0;
136 virtual void WorkOn(
const ffi::String& func_name) = 0;
230 const ffi::Array<FloatImm>& probs,
231 ffi::Optional<Integer> decision = std::nullopt) = 0;
241 const LoopRV& loop_rv,
int n,
int max_innermost_factor,
242 ffi::Optional<ffi::Array<Integer>> decision = std::nullopt) = 0;
259 const LoopRV& loop_rv,
int n,
int partition_pos,
int innerpart_factor,
260 ffi::Optional<ffi::Array<Integer>> decision = std::nullopt) = 0;
268 ffi::Optional<Integer> decision = std::nullopt) = 0;
286 const ffi::Optional<ffi::String>& func_name = std::nullopt) = 0;
349 virtual LoopRV Fuse(
const ffi::Array<LoopRV>& loop_rvs,
bool preserve_unit_iters =
true) = 0;
363 const ffi::Array<ffi::Optional<ExprRV>>& factors,
364 bool preserve_unit_iters =
true,
365 bool disable_predication =
false) = 0;
376 const ffi::Array<ffi::Optional<ExprRV>>& factors,
377 bool preserve_unit_iters =
true) = 0;
390 virtual void Reorder(
const ffi::Array<LoopRV>& ordered_loop_rvs) = 0;
397 const ffi::Array<Integer> new_order) = 0;
460 const ffi::String& storage_scope,
461 const ffi::Array<BlockRV> consumer_blocks = {}) = 0;
473 const ffi::String& storage_scope,
474 const ffi::Array<BlockRV> consumer_blocks = {}) = 0;
488 const ffi::String& storage_scope,
const IndexMap& index_map) = 0;
502 const ffi::String& storage_scope,
513 const ffi::String& storage_scope) = 0;
522 virtual ffi::Array<BlockRV>
CacheIndex(
const BlockRV& block_rv,
const ffi::String& storage_scope,
539 const ffi::String& storage_scope) = 0;
541 const ffi::String& storage_scope) = 0;
585 bool preserve_unit_loops,
int index = -1) = 0;
667 const ffi::String& storage_scope) = 0;
678 const ffi::String& dtype) = 0;
693 virtual BlockRV Blockize(
const ffi::Array<BlockRV>& blocks,
bool preserve_unit_iters =
true) = 0;
701 bool preserve_unit_iters =
true) = 0;
709 bool preserve_unit_iters =
true) = 0;
718 virtual void Annotate(
const LoopRV& loop_rv,
const ffi::String& ann_key,
const Any& ann_val) = 0;
726 const Any& ann_val) = 0;
774 const ffi::Optional<IndexMap>& pad_value = std::nullopt,
775 bool assume_injective_transform =
false) = 0;
867 const ffi::Array<IntImm>& buf_index_array) = 0;
902 bool enable_check =
true);
919 bool enable_check =
true);
Managed reference class to IRModuleNode.
Definition: module.h:256
Base node of all primitive expressions.
Definition: expr.h:91
Reference to PrimExprNode.
Definition: expr.h:124
int64_t TRandState
Definition: random_engine.h:46
A random variable that evaluates to a TensorIR block.
Definition: schedule.h:51
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockRV", BlockRVNode, runtime::Object)
static void RegisterReflection()
Definition: schedule.h:53
Managed reference to BlockRVNode.
Definition: schedule.h:63
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockRV, runtime::ObjectRef, BlockRVNode)
Managed reference to BlockNode.
Definition: stmt.h:976
Managed reference to ForNode.
Definition: stmt.h:770
Definition: index_map.h:169
A random variable that evaluates to a TensorIR for loop.
Definition: schedule.h:73
static void RegisterReflection()
Definition: schedule.h:75
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.LoopRV", LoopRVNode, runtime::Object)
Managed reference to LoopRVNode.
Definition: schedule.h:85
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LoopRV, runtime::ObjectRef, LoopRVNode)
The user-facing schedule class.
Definition: schedule.h:104
virtual void RemoveRV(const ExprRV &expr_rv)=0
Remove an integer random variable from the symbol table.
virtual LoopRV Merge(const ffi::Array< LoopRV > &loop_rvs)=0
Merge a list of loops into one. The loops under their LCA require: 1) Under the same scope 2) Can't ...
virtual void Annotate(const LoopRV &loop_rv, const ffi::String &ann_key, const Any &ann_val)=0
Annotate a loop with a key value pair.
virtual StmtSRef GetSRef(const LoopRV &loop_rv) const =0
Get the loop sref corresponding to the specific LoopRV.
virtual BlockRV DecomposeReduction(const BlockRV &block_rv, const LoopRV &loop_rv)=0
Decompose a reduction block into two separate blocks. a) The init block, which is translated from the...
virtual BlockRV CacheWrite(const BlockRV &block_rv, int write_buffer_index, const ffi::String &storage_scope, const ffi::Array< BlockRV > consumer_blocks={})=0
Create a block that writes a buffer region into a write cache. It requires: 1) There is only one bloc...
virtual LoopRV AddUnitLoop(const BlockRV &block_rv)=0
Create a new unit loop on top of the specific block.
virtual void PadEinsum(const BlockRV &block_rv, const ffi::Array< Integer > &padding)=0
Pad the computation of Einsum.
virtual PrimExpr Get(const ExprRV &expr_rv) const =0
Get the expr corresponding to the specific random variable.
virtual void EnterPostproc()=0
A no-op that marks the start of postprocessing phase of scheduling.
virtual ffi::Optional< Trace > trace() const =0
StmtSRef GetSRef(const Stmt &stmt) const
Get the block/loop sref corresponding to the specific statement.
Definition: schedule.h:203
virtual LoopRV AddUnitLoop(const LoopRV &loop_rv)=0
Create a new unit loop on top of the specific loop.
virtual BlockRV ReindexCacheRead(const BlockRV &block_rv, int read_buffer_index, const ffi::String &storage_scope, const IndexMap &index_map)=0
Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one blo...
virtual StmtSRef GetSRef(const StmtNode *stmt) const
Get the block/loop sref corresponding to the specific statement.
virtual StmtSRef GetSRef(const BlockRV &block_rv) const =0
Get the block sref corresponding to the specific BlockRV.
virtual void Seed(support::LinearCongruentialEngine::TRandState seed)=0
Seed the randomness.
virtual void TransformLayout(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap &index_map, const ffi::Optional< IndexMap > &pad_value=std::nullopt, bool assume_injective_transform=false)=0
Apply a transformation represented by IndexMap to buffer.
virtual ffi::Array< LoopRV > GetLoops(const BlockRV &block_rv)=0
Get the parent loops of the block in its scope, from outer to inner.
virtual ffi::Array< BlockRV > GetConsumers(const BlockRV &block_rv)=0
Get the consumers of a specific block, under the same block scope.
virtual void Parallel(const LoopRV &loop_rv)=0
Parallelize the input loop. It requires: 1) The scope block that the loop is in should have stage-pip...
virtual ffi::Array< LoopRV > Split(const LoopRV &loop_rv, const ffi::Array< ffi::Optional< ExprRV >> &factors, bool preserve_unit_iters=true, bool disable_predication=false)=0
Split a loop into a list of consecutive loops. It requires: 1) The loop can't have annotation or thre...
virtual void WorkOn(const ffi::String &func_name)=0
Instruct the schedule to work on a function in the IRModule.
virtual BlockRV ReadAt(const LoopRV &loop_rv, const BlockRV &block_rv, int read_buffer_index, const ffi::String &storage_scope)=0
virtual void SetAxisSeparator(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const ffi::Array< IntImm > &axis_separators)=0
Set the axis separator of a buffer, where the buffer is specified by a block and a read or write inde...
virtual void ComputeInline(const BlockRV &block)=0
Inline a block into its consumer(s). It requires: 1) The block is a complete non-root block,...
virtual void Bind(const LoopRV &loop_rv, const ffi::String &thread_axis)=0
Bind the input loop to the given thread axis. It requires: 1) The scope block that the loop is in sho...
virtual ExprRV SampleCategorical(const ffi::Array< Integer > &candidates, const ffi::Array< FloatImm > &probs, ffi::Optional< Integer > decision=std::nullopt)=0
Sample an integer given the probability distribution.
virtual IRModule mod() const
Get the IRModule associated with this schedule.
Definition: schedule.h:115
virtual void RemoveRV(const BlockRV &block_rv)=0
Remove a block random variable from the symbol table.
virtual Schedule Copy()=0
Returns a copy of the schedule, including both its state and its symbol table, guaranteeing that 1) S...
virtual ffi::Array< ExprRV > SamplePartitionedTile(const LoopRV &loop_rv, int n, int partition_pos, int innerpart_factor, ffi::Optional< ffi::Array< Integer >> decision=std::nullopt)=0
Sample the factors to a partitioned tile for a specific loop.
virtual void RemoveRV(const LoopRV &loop_rv)=0
Remove a loop random variable from the symbol table.
virtual BlockRV CacheRead(const BlockRV &block_rv, int read_buffer_index, const ffi::String &storage_scope, const ffi::Array< BlockRV > consumer_blocks={})=0
Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one blo...
virtual LoopRV Fuse(const ffi::Array< LoopRV > &loop_rvs, bool preserve_unit_iters=true)=0
Fuse a list of consecutive loops into one. It requires: 1) The loops can't have annotations or thread...
virtual void Tensorize(const BlockRV &block_rv, const ffi::String &intrin, bool preserve_unit_iters=true)=0
Tensorize the computation enclosed by loop with the tensor intrin.
virtual void Unroll(const LoopRV &loop_rv)=0
Unroll the input loop. It requires nothing.
virtual void Reorder(const ffi::Array< LoopRV > &ordered_loop_rvs)=0
Reorder a list of loops. It doesn't require the loops to be consecutive. It requires: 1) The loops ar...
virtual LoopRV SampleComputeLocation(const BlockRV &block_rv, ffi::Optional< Integer > decision=std::nullopt)=0
Sample a compute-at location of the given block.
virtual void AnnotateBufferAccess(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap &index_map)=0
Annotate the buffer access of a block.
virtual void StorageAlign(const BlockRV &block_rv, int buffer_index, int axis, int factor, int offset)=0
Set alignment requirement for specific dimension such that stride[axis] == k * factor + offset for so...
virtual support::LinearCongruentialEngine::TRandState ForkSeed()=0
Fork the random state.
virtual void TransformBlockLayout(const BlockRV &block_rv, const IndexMap &index_map)=0
Apply a transformation represented by IndexMap to block.
virtual void ReverseComputeInline(const BlockRV &block)=0
Inline a block into its only producer. It requires: 1) The block is a complete non-root block,...
virtual ffi::Array< BlockRV > CacheIndex(const BlockRV &block_rv, const ffi::String &storage_scope, int cse_thresh)=0
Create a block to cache precomputed indices for later use. If there is no index computation,...
virtual BlockRV Blockize(const ffi::Array< BlockRV > &blocks, bool preserve_unit_iters=true)=0
Convert specified blocks into a nested block.
virtual void Tensorize(const LoopRV &loop_rv, const ffi::String &intrin, bool preserve_unit_iters=true)=0
Tensorize the computation enclosed by loop with the tensor intrin.
virtual BlockRV ReIndex(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type)=0
Create a block that reads/writes a buffer region into a read/write cache with reindexing....
virtual BlockRV RFactor(const LoopRV &loop_rv, int factor_axis)=0
Factorize an associative reduction block by the specified loop.
virtual void RollingBuffer(const BlockRV &block_rv, int write_buffer_index)=0
Compute the target buffer via rolling buffering.
virtual BlockRV GetBlock(const ffi::String &name, const ffi::Optional< ffi::String > &func_name=std::nullopt)=0
Retrieve a block in a specific function with its name.
virtual void Vectorize(const LoopRV &loop_rv)=0
Vectorize the input loop. It requires: 1) The scope block that the loop is in should have stage-pipel...
virtual void ReorderBlockIterVar(const BlockRV &block_rv, const ffi::Array< Integer > new_order)=0
Reorder the itervars inside a block.
virtual ffi::Array< BlockRV > GetChildBlocks(const BlockRV &block_rv)=0
Get the leaf blocks of a specific scope.
virtual void SetScope(const BlockRV &block_rv, int buffer_index, const ffi::String &storage_scope)=0
Set the storage scope of a buffer, where the buffer is specified by a block and a write-index.
virtual ScheduleState state() const =0
virtual ffi::Array< BlockRV > CacheInplace(const BlockRV &block_rv, int read_buffer_index, const ffi::String &storage_scope)=0
Create two blocks that read and write a buffer region into a read/write cache. It requires the target block...
virtual void UnsafeSetDType(const BlockRV &block_rv, int buffer_index, const ffi::String &dtype)=0
Set the data type of a buffer, where the buffer is specified by a block and a write-index.
virtual ffi::Array< BlockRV > GetOutputBlocks(const BlockRV &scope_block_rv)=0
Get the list of output blocks within the given scope. An output block is a block which has at least one...
virtual bool HasBlock(const BlockRV &block_rv) const =0
Check the existence of a specific BlockRV.
virtual void Annotate(const BlockRV &block_rv, const ffi::String &ann_key, const Any &ann_val)=0
Annotate a block with a key value pair.
virtual ffi::Array< ExprRV > SamplePerfectTile(const LoopRV &loop_rv, int n, int max_innermost_factor, ffi::Optional< ffi::Array< Integer >> decision=std::nullopt)=0
Sample the factors to perfect tile a specific loop.
virtual BlockRV WriteAt(const LoopRV &loop_rv, const BlockRV &block_rv, int write_buffer_index, const ffi::String &storage_scope)=0
virtual void Unannotate(const BlockRV &block_rv, const ffi::String &ann_key)=0
Unannotate a block's annotation with key ann_key.
virtual BlockRV Blockize(const LoopRV &loop_rv, bool preserve_unit_iters=true)=0
Convert the subtree rooted at a specific loop into a block.
virtual ffi::Array< LoopRV > LoopPartition(const LoopRV &loop_rv, const ffi::Array< ffi::Optional< ExprRV >> &factors, bool preserve_unit_iters=true)=0
Partition the loops into sequence of multiple loops 1) The loop can't have annotation or thread bindi...
virtual void ReverseComputeAt(const BlockRV &block_rv, const LoopRV &loop_rv, bool preserve_unit_loops, int index=-1)=0
Move a consumer block under the specific loop, and regenerate the loops induced by the block so that ...
virtual ffi::Optional< GlobalVar > func_working_on() const =0
virtual For Get(const LoopRV &loop_rv) const =0
Get the for loop corresponding to the specific LoopRV.
virtual ~ScheduleNode()=default
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Schedule", ScheduleNode, runtime::Object)
virtual Block Get(const BlockRV &block_rv) const =0
Get the block corresponding to the specific BlockRV.
virtual BlockRV ReindexCacheWrite(const BlockRV &block_rv, int write_buffer_index, const ffi::String &storage_scope, const IndexMap &index_map)=0
Create a block that writes a buffer region into a write cache. It requires: 1) There is only one bloc...
virtual void UnsafeHideBufferAccess(const BlockRV &block_rv, const ffi::String &buf_type, const ffi::Array< IntImm > &buf_index_array)=0
Hide some buffer access in the given block.
static constexpr const bool _type_mutable
Definition: schedule.h:110
virtual BlockRV DecomposePadding(const BlockRV &block_rv, const LoopRV &loop_rv)=0
Decompose a padding block into a block filling const pad values and a block writing in-bound values.
virtual void Unannotate(const LoopRV &loop_rv, const ffi::String &ann_key)=0
Unannotate a loop's annotation with key ann_key.
virtual void ComputeAt(const BlockRV &block_rv, const LoopRV &loop_rv, bool preserve_unit_loops, int index=-1)=0
Move a producer block under the specific loop, and regenerate the loops induced by the block so that ...
virtual ffi::Array< BlockRV > GetChildBlocks(const LoopRV &loop_rv)=0
Get the leaf blocks under a specific loop.
virtual ffi::Array< BlockRV > GetProducers(const BlockRV &block_rv)=0
Get the producers of a specific block, under the same block scope.
Managed reference to ScheduleStateNode.
Definition: state.h:210
Managed reference to ScheduleNode.
Definition: schedule.h:885
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Schedule, runtime::ObjectRef, ScheduleNode)
static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check=true)
Construct a concrete TensorIR schedule from an IRModule.
static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check=true)
Construct a traced concrete TensorIR schedule from an IRModule.
Base node of all statements.
Definition: stmt.h:38
Managed reference to StmtSRefNode.
Definition: block_scope.h:106
Container of all statements.
Definition: stmt.h:63
Defines a remapping of buffer indices.
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1094
BufferIndexType
Type of buffer index.
Definition: schedule.h:41
@ kRead
Index of a read buffer.
@ kWrite
Index of a written buffer.
ScheduleErrorRenderLevel
The level of detailed error message rendering.
Definition: schedule.h:31
@ kNone
No error message at all.
@ kDetail
Render a detailed error message.
@ kFast
Render the error in fast mode.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_gene...
This file defines ScheduleState, the core data structure of TensorIR scheduling.