19 #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
20 #define TVM_TIR_SCHEDULE_SCHEDULE_H_
54 static constexpr
const char*
_type_key =
"tir.BlockRV";
75 static constexpr
const char*
_type_key =
"tir.LoopRV";
108 static constexpr
const char*
_type_key =
"tir.Schedule";
256 int innerpart_factor,
361 bool preserve_unit_iters =
true,
362 bool disable_predication =
false) = 0;
373 bool preserve_unit_iters =
true) = 0;
455 const String& storage_scope,
468 const String& storage_scope,
507 const String& storage_scope) = 0;
533 const String& storage_scope) = 0;
535 const String& storage_scope) = 0;
579 bool preserve_unit_loops,
int index = -1) = 0;
693 bool preserve_unit_iters =
true) = 0;
701 bool preserve_unit_iters =
true) = 0;
767 bool assume_injective_transform =
false) = 0;
894 bool enable_check =
true);
911 bool enable_check =
true);
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference class to IRModuleNode.
Definition: module.h:366
Base node of all primitive expressions.
Definition: expr.h:86
Reference to PrimExprNode.
Definition: expr.h:115
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
const Object * get() const
Definition: object.h:554
base class of all object containers.
Definition: object.h:171
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
int64_t TRandState
Definition: random_engine.h:46
A random variable that evaluates to a TensorIR block.
Definition: schedule.h:51
static constexpr const char * _type_key
Definition: schedule.h:54
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule.h:53
Managed reference to BlockRVNode.
Definition: schedule.h:62
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode)
Managed reference to BlockNode.
Definition: stmt.h:1325
Managed reference to ForNode.
Definition: stmt.h:1029
Definition: index_map.h:176
A random variable that evaluates to a TensorIR for loop.
Definition: schedule.h:72
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule.h:74
static constexpr const char * _type_key
Definition: schedule.h:75
TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object)
Managed reference to LoopRVNode.
Definition: schedule.h:83
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode)
The user-facing schedule class.
Definition: schedule.h:102
virtual void RemoveRV(const ExprRV &expr_rv)=0
Remove an integer random variable from the symbol table.
virtual void SetAxisSeparator(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const Array< IntImm > &axis_separators)=0
Set the axis separator of a buffer, where the buffer is specified by a block and a read or write inde...
virtual void Reorder(const Array< LoopRV > &ordered_loop_rvs)=0
Reorder a list of loops. It doesn't require the loops to be consecutive. It requires: 1) The loops ar...
virtual Array< BlockRV > GetChildBlocks(const LoopRV &loop_rv)=0
Get the leaf blocks of under a specific loop.
virtual StmtSRef GetSRef(const LoopRV &loop_rv) const =0
Get the loop sref corresponding to the specific LoopRV.
virtual BlockRV DecomposeReduction(const BlockRV &block_rv, const LoopRV &loop_rv)=0
Decompose a reduction block into two separate blocks. a) The init block, which is translated from the...
virtual LoopRV AddUnitLoop(const BlockRV &block_rv)=0
Create a new unit loop on top of the specific block.
virtual LoopRV Merge(const Array< LoopRV > &loop_rvs)=0
Merge a list of loops into one. The loops under their LCA requires: 1) Under the same scope 2) Can't ...
virtual void PadEinsum(const BlockRV &block_rv, const Array< Integer > &padding)=0
Pad the computation of Einsum.
virtual PrimExpr Get(const ExprRV &expr_rv) const =0
Get the expr corresponding to the specific random variable.
virtual Array< BlockRV > CacheIndex(const BlockRV &block_rv, const String &storage_scope, int cse_thresh)=0
Create a block to cache precomputed index for later use. if there is no index computation,...
virtual Array< LoopRV > GetLoops(const BlockRV &block_rv)=0
Get the parent loops of the block in its scope, from outer to inner.
virtual void EnterPostproc()=0
A no-op that marks the start of postprocessing phase of scheduling.
virtual BlockRV ReindexCacheRead(const BlockRV &block_rv, int read_buffer_index, const String &storage_scope, const IndexMap &index_map)=0
Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one blo...
StmtSRef GetSRef(const Stmt &stmt) const
Get the block/loop sref corresponding to the specific statement.
Definition: schedule.h:201
virtual Array< ExprRV > SamplePerfectTile(const LoopRV &loop_rv, int n, int max_innermost_factor, Optional< Array< Integer >> decision=NullOpt)=0
Sample the factors to perfect tile a specific loop.
virtual BlockRV ReadAt(const LoopRV &loop_rv, const BlockRV &block_rv, int read_buffer_index, const String &storage_scope)=0
virtual LoopRV AddUnitLoop(const LoopRV &loop_rv)=0
Create a new unit loop on top of the specific loop.
virtual StmtSRef GetSRef(const StmtNode *stmt) const
Get the block/loop sref corresponding to the specific statement.
virtual Array< LoopRV > LoopPartition(const LoopRV &loop_rv, const Array< Optional< ExprRV >> &factors, bool preserve_unit_iters=true)=0
Partition the loops into sequence of multiple loops 1) The loop can't have annotation or thread bindi...
virtual StmtSRef GetSRef(const BlockRV &block_rv) const =0
Get the block sref corresponding to the specific BlockRV.
virtual void ReorderBlockIterVar(const BlockRV &block_rv, const Array< Integer > new_order)=0
Reorder the itervars inside a block.
virtual void Seed(support::LinearCongruentialEngine::TRandState seed)=0
Seed the randomness.
virtual Optional< GlobalVar > func_working_on() const =0
virtual Array< BlockRV > GetProducers(const BlockRV &block_rv)=0
Get the producer of a specific block, under the same block scope.
virtual void Unannotate(const LoopRV &loop_rv, const String &ann_key)=0
Unannotate a loop's annotation with key ann_key.
virtual void Parallel(const LoopRV &loop_rv)=0
Parallelize the input loop. It requires: 1) The scope block that the loop is in should have stage-pip...
virtual BlockRV CacheRead(const BlockRV &block_rv, int read_buffer_index, const String &storage_scope, const Array< BlockRV > consumer_blocks={})=0
Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one blo...
virtual LoopRV Fuse(const Array< LoopRV > &loop_rvs, bool preserve_unit_iters=true)=0
Fuse a list of consecutive loops into one. It requires: 1) The loops can't have annotations or thread...
virtual void ComputeInline(const BlockRV &block)=0
Inline a block into its consumer(s). It requires: 1) The block is a complete non-root block,...
virtual void Tensorize(const LoopRV &loop_rv, const String &intrin, bool preserve_unit_iters=true)=0
Tensorize the computation enclosed by loop with the tensor intrin.
virtual Array< ExprRV > SamplePartitionedTile(const LoopRV &loop_rv, int n, int partition_pos, int innerpart_factor, Optional< Array< Integer >> decision=NullOpt)=0
Sample the factors to a partitioned tile for a specific loop.
virtual IRModule mod() const
Get the IRModule associated with this schedule.
Definition: schedule.h:113
virtual void RemoveRV(const BlockRV &block_rv)=0
Remove a block random variable from the symbol table.
virtual Schedule Copy()=0
Returns a copy of the schedule, including both its state and its symbol table, guaranteeing that 1) S...
virtual void Unannotate(const BlockRV &block_rv, const String &ann_key)=0
Unannotate a block's annotation with key ann_key.
virtual void RemoveRV(const LoopRV &loop_rv)=0
Remove a loop random variable from the symbol table.
static constexpr const char * _type_key
Definition: schedule.h:108
virtual void Bind(const LoopRV &loop_rv, const String &thread_axis)=0
Bind the input loop to the given thread axis. It requires: 1) The scope block that the loop is in sho...
virtual void Annotate(const BlockRV &block_rv, const String &ann_key, const ObjectRef &ann_val)=0
Annotate a block with a key value pair.
virtual void Unroll(const LoopRV &loop_rv)=0
Unroll the input loop. It requires nothing.
virtual void AnnotateBufferAccess(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap &index_map)=0
Annotate the buffer access of a block.
virtual void StorageAlign(const BlockRV &block_rv, int buffer_index, int axis, int factor, int offset)=0
Set alignment requirement for specific dimension such that stride[axis] == k * factor + offset for so...
virtual Optional< Trace > trace() const =0
virtual support::LinearCongruentialEngine::TRandState ForkSeed()=0
Fork the random state.
virtual Array< BlockRV > GetOutputBlocks(const BlockRV &scope_block_rv)=0
Get the list of output blocks within the given scope An output block is a block which has atleast one...
virtual void TransformBlockLayout(const BlockRV &block_rv, const IndexMap &index_map)=0
Apply a transformation represented by IndexMap to block.
virtual void ReverseComputeInline(const BlockRV &block)=0
Inline a block into its only producer. It requires: 1) The block is a complete non-root block,...
virtual BlockRV ReIndex(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type)=0
Create a block that read/write a buffer region into a read/write cache with reindexing....
virtual BlockRV CacheWrite(const BlockRV &block_rv, int write_buffer_index, const String &storage_scope, const Array< BlockRV > consumer_blocks={})=0
Create a block that writes a buffer region into a write cache. It requires: 1) There is only one bloc...
virtual void SetScope(const BlockRV &block_rv, int buffer_index, const String &storage_scope)=0
Set the storage scope of a buffer, where the buffer is specified by a block and a write-index.
virtual void Annotate(const LoopRV &loop_rv, const String &ann_key, const ObjectRef &ann_val)=0
Annotate a loop with a key value pair.
virtual BlockRV RFactor(const LoopRV &loop_rv, int factor_axis)=0
Factorize an associative reduction block by the specified loop.
virtual void RollingBuffer(const BlockRV &block_rv, int write_buffer_index)=0
Compute the target buffer via rolling buffering.
virtual Array< LoopRV > Split(const LoopRV &loop_rv, const Array< Optional< ExprRV >> &factors, bool preserve_unit_iters=true, bool disable_predication=false)=0
Split a loop into a list of consecutive loops. It requires: 1) The loop can't have annotation or thre...
virtual Array< BlockRV > GetChildBlocks(const BlockRV &block_rv)=0
Get the leaf blocks of a specific scope.
virtual void Vectorize(const LoopRV &loop_rv)=0
Vectorize the input loop. It requires: 1) The scope block that the loop is in should have stage-pipel...
virtual ScheduleState state() const =0
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object)
virtual Array< BlockRV > CacheInplace(const BlockRV &block_rv, int read_buffer_index, const String &storage_scope)=0
Create 2 blocks that read&write a buffer region into a read/write cache. It requires the target block...
virtual LoopRV SampleComputeLocation(const BlockRV &block_rv, Optional< Integer > decision=NullOpt)=0
Sample a compute-at location of the given block.
virtual Array< BlockRV > GetConsumers(const BlockRV &block_rv)=0
Get the consumers of a specific block, under the same block scope.
virtual BlockRV GetBlock(const String &name, const Optional< String > &func_name=NullOpt)=0
Retrieve a block in a specific function with its name.
virtual void UnsafeSetDType(const BlockRV &block_rv, int buffer_index, const String &dtype)=0
Set the data type of a buffer, where the buffer is specified by a block and a write-index.
virtual bool HasBlock(const BlockRV &block_rv) const =0
Check the existance of a specific BlockRV.
virtual void WorkOn(const String &func_name)=0
Instruct the schedule to work on a function in the IRModule.
virtual ExprRV SampleCategorical(const Array< runtime::Int > &candidates, const Array< runtime::Float > &probs, Optional< runtime::Int > decision=NullOpt)=0
Sample an integer given the probability distribution.
virtual BlockRV Blockize(const LoopRV &loop_rv, bool preserve_unit_iters=true)=0
Convert the subtree rooted at a specific loop into a block.
virtual BlockRV WriteAt(const LoopRV &loop_rv, const BlockRV &block_rv, int write_buffer_index, const String &storage_scope)=0
virtual void ReverseComputeAt(const BlockRV &block_rv, const LoopRV &loop_rv, bool preserve_unit_loops, int index=-1)=0
Move a consumer block under the specific loop, and regenerate the loops induced by the block so that ...
virtual BlockRV Blockize(const Array< BlockRV > &blocks, bool preserve_unit_iters=true)=0
Convert specified blocks into a nested block.
virtual For Get(const LoopRV &loop_rv) const =0
Get the for loop corresponding to the specific LoopRV.
virtual ~ScheduleNode()=default
virtual void UnsafeHideBufferAccess(const BlockRV &block_rv, const String &buf_type, const Array< IntImm > &buf_index_array)=0
Hide some buffer access in the given block.
virtual Block Get(const BlockRV &block_rv) const =0
Get the block corresponding to the specific BlockRV.
virtual BlockRV ReindexCacheWrite(const BlockRV &block_rv, int write_buffer_index, const String &storage_scope, const IndexMap &index_map)=0
Create a block that writes a buffer region into a write cache. It requires: 1) There is only one bloc...
virtual void TransformLayout(const BlockRV &block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap &index_map, const Optional< IndexMap > &pad_value=NullOpt, bool assume_injective_transform=false)=0
Apply a transformation represented by IndexMap to buffer.
virtual BlockRV DecomposePadding(const BlockRV &block_rv, const LoopRV &loop_rv)=0
Decompose a padding block into a block filling const pad values and a block writing in-bound values.
virtual void ComputeAt(const BlockRV &block_rv, const LoopRV &loop_rv, bool preserve_unit_loops, int index=-1)=0
Move a producer block under the specific loop, and regenerate the loops induced by the block so that ...
virtual void Tensorize(const BlockRV &block_rv, const String &intrin, bool preserve_unit_iters=true)=0
Tensorize the computation enclosed by loop with the tensor intrin.
Managed reference to ScheduleStateNode.
Definition: state.h:208
Managed reference to ScheduleNode.
Definition: schedule.h:877
static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check=true)
Construct a concrete TensorIR schedule from an IRModule.
static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check=true)
Construct a traced concrete TensorIR schedule from an IRModule.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode)
Base node of all statements.
Definition: stmt.h:38
Managed reference to StmtSRefNode.
Definition: block_scope.h:107
Container of all statements.
Definition: stmt.h:59
Defines a remapping of buffer indices.
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1458
BufferIndexType
Type of buffer index.
Definition: schedule.h:41
@ kRead
Index of a read buffer.
@ kWrite
Index of a written buffer.
ScheduleErrorRenderLevel
The level of detailed error message rendering.
Definition: schedule.h:31
@ kNone
No error message at all.
@ kDetail
Render a detailed error message.
@ kFast
Render the error in fast mode.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_gene...
This file defines ScheduleState, the core data structure of TensorIR scheduling.