24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
30 #include <type_traits>
51 static constexpr
const char*
_type_key =
"tir.Stmt";
77 v->Visit(
"var", &
var);
78 v->Visit(
"value", &
value);
79 v->Visit(
"body", &
body);
80 v->Visit(
"span", &
span);
94 static constexpr
const char*
_type_key =
"tir.LetStmt";
132 v->Visit(
"node", &
node);
134 v->Visit(
"value", &
value);
135 v->Visit(
"body", &
body);
136 v->Visit(
"span", &
span);
151 static constexpr
const char*
_type_key =
"tir.AttrStmt";
185 v->Visit(
"body", &
body);
186 v->Visit(
"span", &
span);
200 static constexpr
const char*
_type_key =
"tir.AssertStmt";
238 v->Visit(
"buffer", &
buffer);
239 v->Visit(
"value", &
value);
242 v->Visit(
"span", &
span);
257 static constexpr
const char*
_type_key =
"tir.BufferStore";
297 v->Visit(
"buffer", &
buffer);
298 v->Visit(
"bounds", &
bounds);
300 v->Visit(
"body", &
body);
301 v->Visit(
"span", &
span);
321 static constexpr
const char*
_type_key =
"tir.BufferRealize";
359 v->Visit(
"value", &
value);
361 v->Visit(
"span", &
span);
375 static constexpr
const char*
_type_key =
"tir.ProducerStore";
418 v->Visit(
"bounds", &
bounds);
420 v->Visit(
"body", &
body);
422 v->Visit(
"span", &
span);
439 static constexpr
const char*
_type_key =
"tir.ProducerRealize";
481 v->Visit(
"dtype", &
dtype);
484 v->Visit(
"body", &
body);
486 v->Visit(
"span", &
span);
518 static constexpr
const char*
_type_key =
"tir.Allocate";
569 v->Visit(
"data", &
data);
571 v->Visit(
"dtype", &
dtype);
573 v->Visit(
"body", &
body);
575 v->Visit(
"span", &
span);
607 static constexpr
const char*
_type_key =
"tir.AllocateConst";
640 v->Visit(
"buffer", &
buffer);
641 v->Visit(
"body", &
body);
642 v->Visit(
"span", &
span);
654 static constexpr
const char*
_type_key =
"tir.DeclBuffer";
683 v->Visit(
"seq", &
seq);
684 v->Visit(
"span", &
span);
709 v->Visit(
"value", &
value);
710 v->Visit(
"span", &
span);
719 static constexpr
const char*
_type_key =
"tir.Evaluate";
773 template <
typename... Args>
776 runtime::detail::for_each(
Flattener(&seq), std::forward<Args>(seq_args)...);
780 }
else if (seq.
size() == 1) {
787 if constexpr (
sizeof...(seq_args) == 1) {
789 SeqStmt original = opt.value();
790 bool all_same = [&]() {
791 if (original->seq.
size() != seq.
size()) {
794 for (
size_t i = 0; i < seq.
size(); i++) {
795 if (!original->seq[i].
same_as(seq[i])) {
814 template <
typename T>
816 if constexpr (std::is_same_v<T, SeqStmt>) {
818 }
else if constexpr (!std::is_base_of_v<T, SeqStmt>) {
820 }
else if (
auto* ptr = t.template as<SeqStmtNode>()) {
821 return GetRef<SeqStmt>(ptr);
827 template <
typename T>
829 if constexpr (std::is_base_of_v<ObjectRef, T>) {
831 if (!stmt_or_seq.defined()) {
836 if constexpr (std::is_same_v<T, SeqStmt>) {
838 (*this)(0, stmt_or_seq->seq);
842 if constexpr (std::is_base_of_v<T, SeqStmt>) {
845 if (
auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
851 if constexpr (std::is_base_of_v<T, Evaluate>) {
856 if (
auto* op = stmt_or_seq.template as<EvaluateNode>()) {
857 if (
auto* as_int = op->value.template as<IntImmNode>(); as_int && as_int->value == 0) {
863 if constexpr (std::is_base_of_v<Stmt, T>) {
865 seq_->push_back(stmt_or_seq);
868 for (
auto v : stmt_or_seq) {
898 v->Visit(
"span", &
span);
912 static constexpr
const char*
_type_key =
"tir.IfThenElse";
996 v->Visit(
"min", &
min);
997 v->Visit(
"extent", &
extent);
998 v->Visit(
"kind", &
kind);
999 v->Visit(
"body", &
body);
1002 v->Visit(
"span", &
span);
1058 v->Visit(
"body", &
body);
1059 v->Visit(
"span", &
span);
1098 v->Visit(
"buffer", &
buffer);
1099 v->Visit(
"bounds", &
bounds);
1100 v->Visit(
"span", &
span);
1143 v->Visit(
"buffer", &
buffer);
1144 v->Visit(
"region", &
region);
1156 static constexpr
const char*
_type_key =
"tir.BufferRegion";
1206 v->Visit(
"buffer", &
buffer);
1207 v->Visit(
"source", &
source);
1219 static constexpr
const char*
_type_key =
"tir.MatchBufferRegion";
1287 v->Visit(
"reads", &
reads);
1288 v->Visit(
"writes", &
writes);
1290 v->Visit(
"body", &
body);
1291 v->Visit(
"init", &
init);
1357 v->Visit(
"block", &
block);
1371 static constexpr
const char*
_type_key =
"tir.BlockRealize";
1595 "meta_schedule.thread_extent_low_inclusive";
1599 "meta_schedule.thread_extent_high_inclusive";
1603 "meta_schedule.random_compute_producer";
1683 return attr_key.compare(0, 7,
"pragma_") == 0;
1706 return "vectorized";
1710 return "thread_binding";
1712 LOG(FATAL) <<
"Unknown ForKind" << t;
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:115
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool empty() const
Definition: array.h:432
size_t size() const
Definition: array.h:420
Runtime primitive data type.
Definition: data_type.h:43
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:519
const Object * operator->() const
Definition: object.h:556
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:530
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Allocate a buffer that can be used in body.
Definition: stmt.h:541
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:565
Optional< Integer > irmod_storage_idx
If the PrimFunc containing the Stmt is added to IRModule, this is an optional index to indicate the i...
Definition: stmt.h:552
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:556
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:609
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:584
static int64_t ConstantAllocationSize(const Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode)
Stmt body
The body to be executed.
Definition: stmt.h:558
DataType dtype
The type of the buffer.
Definition: stmt.h:554
Var buffer_var
The buffer variable.
Definition: stmt.h:544
static constexpr const char * _type_key
Definition: stmt.h:607
bool SEqualReduce(const AllocateConstNode *other, SEqualReducer equal) const
Definition: stmt.h:578
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:598
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:608
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:567
Optional< runtime::NDArray > data
The optional data associated to the constant.
Definition: stmt.h:547
Managed reference to AllocateConstNode.
Definition: stmt.h:617
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode)
AllocateConst(Var buffer_var, DataType dtype, Array< PrimExpr > extents, ObjectRef data_or_idx, Stmt body, Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode)
Allocate a buffer that can be used in body.
Definition: stmt.h:459
bool SEqualReduce(const AllocateNode *other, SEqualReducer equal) const
Definition: stmt.h:489
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:466
static int64_t ConstantAllocationSize(const Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:477
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:468
Stmt body
The body to be executed.
Definition: stmt.h:470
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode)
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:509
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:520
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:519
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:479
DataType dtype
The type of the buffer.
Definition: stmt.h:464
static constexpr const char * _type_key
Definition: stmt.h:518
Var buffer_var
The buffer variable.
Definition: stmt.h:462
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:495
Managed reference to AllocateNode.
Definition: stmt.h:528
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode)
Allocate(Var buffer_var, DataType dtype, Array< PrimExpr > extents, PrimExpr condition, Stmt body, Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode)
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:170
PrimExpr condition
Condition to be checked.
Definition: stmt.h:173
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:175
bool SEqualReduce(const AssertStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:189
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:200
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:194
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:182
Stmt body
Body for which this assertion holds true. Will be executed after the assertion.
Definition: stmt.h:180
Managed reference to AssertStmtNode.
Definition: stmt.h:208
TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode)
AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode)
Define certain auxiliary attribute for the body to be a symbolic value. This provide auxiliary inform...
Definition: stmt.h:120
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:127
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:144
Stmt body
The body statement to be executed.
Definition: stmt.h:129
String attr_key
the type key of the attribute
Definition: stmt.h:125
ObjectRef node
The node that this attribute is about.
Definition: stmt.h:123
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:131
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode)
bool SEqualReduce(const AttrStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:139
static constexpr const char * _type_key
Definition: stmt.h:151
Managed reference to AttrStmtNode.
Definition: stmt.h:159
TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode)
AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode)
A block is a basic schedule unit in TIR.
Definition: stmt.h:1258
static constexpr const char * _type_key
Definition: stmt.h:1317
Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:1263
Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:1281
Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:1261
Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:1265
Map< String, ObjectRef > annotations
The annotation of the block.
Definition: stmt.h:1283
Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block....
Definition: stmt.h:1277
String name_hint
The name_hint of the block.
Definition: stmt.h:1267
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1285
Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:1279
Stmt body
The body of the block.
Definition: stmt.h:1269
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1306
bool SEqualReduce(const BlockNode *other, SEqualReducer equal) const
Definition: stmt.h:1297
TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode)
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:1342
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1365
bool SEqualReduce(const BlockRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:1360
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:1371
Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:1345
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true.
Definition: stmt.h:1350
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1354
Block block
The block to be realized.
Definition: stmt.h:1352
Managed reference to BlockRealizeNode.
Definition: stmt.h:1379
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode)
BlockRealize(Array< PrimExpr > iter_values, PrimExpr predicate, Block block, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode)
Managed reference to BlockNode.
Definition: stmt.h:1325
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode)
Block(Array< IterVar > iter_vars, Array< BufferRegion > reads, Array< BufferRegion > writes, String name_hint, Stmt body, Optional< Stmt > init=NullOpt, Array< Buffer > alloc_buffers=Array< Buffer >(), Array< MatchBufferRegion > match_buffers=Array< MatchBufferRegion >(), Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode)
Annotate the region where the buffer needs to be read and written in the body. We only need to allocate ...
Definition: stmt.h:285
BufferRealizeNode()=default
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:296
Buffer buffer
The buffer variable.
Definition: stmt.h:288
bool SEqualReduce(const BufferRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:304
BufferRealizeNode(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
Definition: stmt.h:317
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:292
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:309
Stmt body
The body of realization.
Definition: stmt.h:294
Array< Range > bounds
Bounds to be realized.
Definition: stmt.h:290
static constexpr const char * _type_key
Definition: stmt.h:321
Managed reference to BufferRealizeNode.
Definition: stmt.h:329
BufferRealize(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode)
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:1135
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:1138
bool SEqualReduce(const BufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1147
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object)
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:1158
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1142
static constexpr const char * _type_key
Definition: stmt.h:1156
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:1157
Array< Range > region
The region array of the buffer region.
Definition: stmt.h:1140
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1151
Managed reference to BufferRegionNode.
Definition: stmt.h:1166
static BufferRegion FullRegion(Buffer buffer)
Create a BufferRegion which is full region of the given buffer.
static BufferRegion FromPoint(Buffer buffer, Array< PrimExpr > indices)
Create a BufferRegion which is a single point of the given buffer.
BufferRegion(Buffer buffer, Array< Range > region)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode)
TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode)
Store a value into a high-dimensional buffer.
Definition: stmt.h:226
Buffer buffer
The buffer variable.
Definition: stmt.h:229
Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:233
Optional< PrimExpr > predicate
The predicate mask for storing values.
Definition: stmt.h:235
PrimExpr value
The value to be stored.
Definition: stmt.h:231
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:250
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:237
static constexpr const char * _type_key
Definition: stmt.h:257
bool SEqualReduce(const BufferStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:245
Managed reference to BufferStoreNode.
Definition: stmt.h:265
BufferStore(Buffer buffer, PrimExpr value, Array< PrimExpr > indices, Optional< PrimExpr > predicate=NullOpt, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode)
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:174
Managed reference to DataProducerNode.
Definition: buffer.h:313
Declare a buffer that can be used in the body.
Definition: stmt.h:632
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:649
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:639
static constexpr const char * _type_key
Definition: stmt.h:654
Buffer buffer
The buffer being declared.
Definition: stmt.h:635
TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode)
Stmt body
The body to be executed.
Definition: stmt.h:637
bool SEqualReduce(const DeclBufferNode *other, SEqualReducer equal) const
Definition: stmt.h:645
Managed reference to DeclBufferNode.
Definition: stmt.h:659
DeclBuffer(Buffer buffer, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode)
TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode)
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:703
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:717
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:708
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:719
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:706
bool SEqualReduce(const EvaluateNode *other, SEqualReducer equal) const
Definition: stmt.h:713
Managed reference to EvaluateNode.
Definition: stmt.h:727
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode)
Evaluate(int value, Span span=Span())
Definition: stmt.h:731
TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode)
Evaluate(PrimExpr value, Span span=Span())
A for loop, with possible type annotations.
Definition: stmt.h:967
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:994
Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to.
Definition: stmt.h:983
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:972
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1011
ForKind kind
The kind of the for loop.
Definition: stmt.h:976
static constexpr const char * _type_key
Definition: stmt.h:1021
Var loop_var
The loop variable.
Definition: stmt.h:970
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:974
Stmt body
The body of the for loop.
Definition: stmt.h:978
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode)
bool SEqualReduce(const ForNode *other, SEqualReducer equal) const
Definition: stmt.h:1005
Map< String, ObjectRef > annotations
Additional annotations about the loop.
Definition: stmt.h:992
Managed reference to ForNode.
Definition: stmt.h:1029
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode)
For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional< IterVar > thread_binding=NullOpt, Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode)
IfThenElse statement.
Definition: stmt.h:885
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:894
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:906
PrimExpr condition
The condition.
Definition: stmt.h:888
Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:892
bool SEqualReduce(const IfThenElseNode *other, SEqualReducer equal) const
Definition: stmt.h:901
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:912
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:890
Managed reference to IfThenElseNode.
Definition: stmt.h:920
IfThenElse(PrimExpr condition, Stmt then_case, Optional< Stmt > else_case=NullOpt, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode)
TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode)
Let binding, bind var to value, then run body.
Definition: stmt.h:67
PrimExpr value
The value to be bound.
Definition: stmt.h:72
static constexpr const char * _type_key
Definition: stmt.h:94
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:76
bool SEqualReduce(const LetStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:83
Stmt body
The body block.
Definition: stmt.h:74
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode)
Var var
The variable.
Definition: stmt.h:70
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:88
Managed reference to LetStmtNode.
Definition: stmt.h:102
TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode)
LetStmt(Var var, PrimExpr value, Stmt body, Span span=Span())
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:1198
TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object)
static constexpr const char * _type_key
Definition: stmt.h:1219
Buffer buffer
The target buffer.
Definition: stmt.h:1201
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:1221
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1205
bool SEqualReduce(const MatchBufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1210
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:1220
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1214
BufferRegion source
The source buffer region.
Definition: stmt.h:1203
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:1229
MatchBufferRegion(Buffer buffer, BufferRegion source)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode)
TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode)
A prefetch hint for a buffer.
Definition: stmt.h:1090
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode)
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1097
static constexpr const char * _type_key
Definition: stmt.h:1116
PrefetchNode(Buffer buffer, Array< Range > bounds, Span span=Span())
Definition: stmt.h:1113
Array< Range > bounds
Bounds to be prefetched.
Definition: stmt.h:1095
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1107
Buffer buffer
The buffer to be prefetched.
Definition: stmt.h:1093
bool SEqualReduce(const PrefetchNode *other, SEqualReducer equal) const
Definition: stmt.h:1103
Managed reference to PrefetchNode.
Definition: stmt.h:1124
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode)
Prefetch(Buffer buffer, Array< Range > bounds, Span span=Span())
Annotate the bounds where the data produced by the producer need to be written and read in body....
Definition: stmt.h:403
Stmt body
The body of realization.
Definition: stmt.h:412
DataProducer producer
The producer that produces the data.
Definition: stmt.h:406
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:410
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:431
static constexpr const char * _type_key
Definition: stmt.h:439
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:416
String storage_scope
The storage scope associated with this realization.
Definition: stmt.h:414
bool SEqualReduce(const ProducerRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:425
Region bounds
Bounds to be realized.
Definition: stmt.h:408
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode)
Managed reference to ProducerRealizeNode.
Definition: stmt.h:447
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode)
TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode)
ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope="", Span span=Span())
Store a value into a multi-dimensional array that will be read by the consumer of the producer.
Definition: stmt.h:348
PrimExpr value
The value to be stored.
Definition: stmt.h:353
DataProducer producer
The producer to store the results into.
Definition: stmt.h:351
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:369
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode)
Array< PrimExpr > indices
The index arguments of the function.
Definition: stmt.h:355
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:357
bool SEqualReduce(const ProducerStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:364
static constexpr const char * _type_key
Definition: stmt.h:375
Managed reference to ProducerStoreNode.
Definition: stmt.h:383
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode)
TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode)
ProducerStore(DataProducer producer, PrimExpr value, Array< PrimExpr > indices, Span span=Span())
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:670
size_t size() const
Definition: stmt.h:676
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:673
static constexpr const char * _type_key
Definition: stmt.h:693
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:682
bool SEqualReduce(const SeqStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:687
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:691
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:680
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:810
Flattener(Array< Stmt > *seq)
Definition: stmt.h:812
static Optional< SeqStmt > AsSeqStmt(const T &t)
Definition: stmt.h:815
void operator()(size_t i, const T &stmt_or_seq) const
Definition: stmt.h:828
Sequence statement.
Definition: stmt.h:738
TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode)
size_t size() const
Definition: stmt.h:748
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode)
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:752
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:774
SeqStmt(Array< Stmt > seq, Span span=Span())
Construct SeqStmt.
Base node of all statements.
Definition: stmt.h:38
TVM_OBJECT_ENABLE_SCRIPT_PRINTER()
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:54
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:52
StmtNode(Span span)
Definition: stmt.h:47
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:44
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object)
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:53
static constexpr const char * _type_key
Definition: stmt.h:51
Container of all statements.
Definition: stmt.h:59
TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode)
a named variable in TIR
Definition: var.h:89
A While loop.
Definition: stmt.h:1049
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1056
static constexpr const char * _type_key
Definition: stmt.h:1071
TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode)
Stmt body
The body of the while loop.
Definition: stmt.h:1054
PrimExpr condition
The termination condition.
Definition: stmt.h:1052
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1066
bool SEqualReduce(const WhileNode *other, SEqualReducer equal) const
Definition: stmt.h:1062
Managed reference to WhileNode.
Definition: stmt.h:1079
While(PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode)
tvm::Span Span
Definition: base.h:65
void Evaluate(PrimExpr value)
Evaluate the input expression.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * compute_scope
Mark the scope as the point where computation starts to happen. This can hint some code generators to create a new ...
Definition: stmt.h:1414
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op. When this scope occurs, the stmt....
Definition: stmt.h:1491
constexpr const char * software_pipeline_order
Mark the order of a statement in the software pipeline.
Definition: stmt.h:1570
constexpr const char * meta_schedule_unroll_explicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1612
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1550
constexpr const char * buffer_dim_align
Mark alignment of buffer dimension. stmt.node is Tensor; stmt.value is tvm_tuple(dim,...
Definition: stmt.h:1479
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1496
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:1403
constexpr const char * meta_schedule_auto_tensorize
Mark that a block should be further rewritten using tensorization.
Definition: stmt.h:1618
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1502
constexpr const char * manifest_shared_memory_local_stage
Mark the local stage for the shared memory access should be added.
Definition: stmt.h:1582
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1434
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:1430
constexpr const char * meta_schedule_tiling_structure
Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.
Definition: stmt.h:1585
constexpr const char * software_pipeline_stage
Mark the stage of a statement in the software pipeline.
Definition: stmt.h:1567
constexpr const char * meta_schedule_random_compute_producer
Mark the block whose producer needs to be applied by rule Random-Compute-Location.
Definition: stmt.h:1602
constexpr const char * warp_execution
Mark that a block is executed by a warp. This implies the extent of threadIdx.x is warp size.
Definition: stmt.h:1662
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1509
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1682
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1392
constexpr const char * script_parsing_detect_access
Mark whether the script-completer needs to fill in missing access regions during script parsing.
Definition: stmt.h:1559
constexpr const char * meta_schedule_tensor_core_enabled
Mark that tensor core is enabled in the PrimExpr.
Definition: stmt.h:1633
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:1394
constexpr const char * extern_scope
Mark the scope as generated by an extern primitive. Such a scope can contain arbitrary ir program and we n...
Definition: stmt.h:1409
constexpr const char * meta_schedule_cache_type
Mark a block as generated by cache_read or cache_write block. 0 means cache_read; 1 means cache_write...
Definition: stmt.h:1641
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:1426
constexpr const char * meta_schedule_unroll_implicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1615
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1498
constexpr const char * auto_copy
Mark auto copy for memhammer.
Definition: stmt.h:1650
constexpr const char * rolling_buffer_scope
Mark realization for rolling buffer optimization.
Definition: stmt.h:1468
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1533
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1420
constexpr const char * meta_schedule_thread_extent_low_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1594
constexpr const char * meta_schedule_layout_rewrite_preproc
Mark that a block is a preprocessor block for layout rewrite.
Definition: stmt.h:1621
constexpr const char * vector_bytes
Mark vectorization length constraint on block.
Definition: stmt.h:1656
constexpr const char * device_type
The device type.
Definition: stmt.h:1422
constexpr const char * explicit_write_region
Mark that a block has an explicitly specified write region. This is used to override the default writ...
Definition: stmt.h:1675
constexpr const int meta_schedule_cache_type_read
Definition: stmt.h:1644
constexpr const char * software_pipeline_async_stages
List stages in the software pipeline that should run asynchronously.
Definition: stmt.h:1576
constexpr const char * meta_schedule_auto_tensorize_init
Mark that the init statement of a block should be further rewritten using tensorization.
Definition: stmt.h:1625
constexpr const char * meta_schedule_cooperative_fetch
Mark that the loop should be further split and bound to environment threads to enable cooperative fetc...
Definition: stmt.h:1591
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1470
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:1428
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:1424
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1462
constexpr const char * fragment_shape
Mark the shape of the TensorCore fragment.
Definition: stmt.h:1540
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1438
constexpr const char * fragment_layout
Mark the layout of the TensorCore fragment.
Definition: stmt.h:1545
constexpr const char * layout_free_buffers
Mark the buffers which are const-accessed and whose layout can be transformed.
Definition: stmt.h:1579
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1534
constexpr const char * explicit_read_region
Mark that a block has an explicitly specified read region. This is used to override the default read ...
Definition: stmt.h:1670
constexpr const int meta_schedule_cache_type_write
Definition: stmt.h:1647
constexpr const char * coproc_scope
Mark region is processed by a co-processor.
Definition: stmt.h:1396
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:1481
constexpr const char * prefetch_scope
Mark of prefetch scope, value=offset, run prefetch of Tensor on the current loop scope.
Definition: stmt.h:1443
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1535
constexpr const char * layout_transforms
Marks the layout transforms to be used for a tensor.
Definition: stmt.h:1450
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1458
constexpr const char * realize_scope
Mark storage scope of realization.
Definition: stmt.h:1418
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1494
constexpr const char * meta_schedule_thread_extent_high_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1598
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1500
constexpr const char * coproc_uop_scope
Mark region creates coprocessor micro ops, can be reused if corresponding variable is independent.
Definition: stmt.h:1401
constexpr const char * meta_schedule_inline_rule
Mark that a block is disallowed in auto inline.
Definition: stmt.h:1665
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1564
constexpr const char * pipeline_exec_scope
pipeline execution scope, implies the scope can be pipelined.
Definition: stmt.h:1504
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1436
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1432
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1472
constexpr const char * require_block_var_bound_predicate
Mark that the block needs to add a predicate for block var bounds during lowering.
Definition: stmt.h:1630
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1416
constexpr const char * local_stage
Mark local stage constraint on data copy.
Definition: stmt.h:1653
constexpr const char * meta_schedule_vectorize
Mark auto-vectorize setting on the block.
Definition: stmt.h:1609
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1466
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1514
constexpr const char * meta_schedule_parallel
Mark auto-parallel setting on the block.
Definition: stmt.h:1606
const char * ForKind2String(ForKind t)
Definition: stmt.h:1699
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:123
ForKind
The kind of the loop.
Definition: stmt.h:936
@ kThreadBinding
The loop variable is bound to a thread in an environment. In the final stage of lowering,...
@ kParallel
Parallel execution on CPU.
@ kSerial
default semantics – serial execution.
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
@ kVectorized
The loop is vectorized.
Definition: var.h:244
@ kUnrolled
The execution is unrolled.
Definition: var.h:240
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169