tvm
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/tir/expr.h>
30 
31 #include <string>
32 #include <type_traits>
33 #include <utility>
34 
35 namespace tvm {
36 namespace tir {
37 
39 class StmtNode : public Object {
40  public:
45  mutable Span span;
46 
47  StmtNode() = default;
48  explicit StmtNode(Span span) : span(span) {}
49 
50  static void RegisterReflection() {
51  namespace refl = tvm::ffi::reflection;
52  refl::ObjectDef<StmtNode>().def_ro("span", &StmtNode::span);
53  }
54 
56 
57  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
58 
59  static constexpr const uint32_t _type_child_slots = 15;
60  TVM_FFI_DECLARE_OBJECT_INFO("tir.Stmt", StmtNode, Object);
61 };
62 
64 class Stmt : public ObjectRef {
65  public:
67 };
68 
72 class LetStmtNode : public StmtNode {
73  public:
80 
81  static void RegisterReflection() {
82  namespace refl = tvm::ffi::reflection;
83  refl::ObjectDef<LetStmtNode>()
84  .def_ro("var", &LetStmtNode::var, refl::AttachFieldFlag::SEqHashDef())
85  .def_ro("value", &LetStmtNode::value)
86  .def_ro("body", &LetStmtNode::body);
87  }
89 };
90 
95 class LetStmt : public Stmt {
96  public:
97  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
98 
101 };
102 
113 class AttrStmtNode : public StmtNode {
114  public:
116  ffi::Any node;
118  ffi::String attr_key;
123 
124  static void RegisterReflection() {
125  namespace refl = tvm::ffi::reflection;
126  refl::ObjectDef<AttrStmtNode>()
127  .def_ro("node", &AttrStmtNode::node)
128  .def_ro("attr_key", &AttrStmtNode::attr_key)
129  .def_ro("value", &AttrStmtNode::value)
130  .def_ro("body", &AttrStmtNode::body);
131  }
133 };
134 
139 class AttrStmt : public Stmt {
140  public:
141  TVM_DLL AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body,
142  Span span = Span());
143 
146 };
147 
151 class AssertStmtNode : public StmtNode {
152  public:
162 
163  static void RegisterReflection() {
164  namespace refl = tvm::ffi::reflection;
165  refl::ObjectDef<AssertStmtNode>()
166  .def_ro("condition", &AssertStmtNode::condition)
167  .def_ro("message", &AssertStmtNode::message)
168  .def_ro("body", &AssertStmtNode::body);
169  }
171 };
172 
177 class AssertStmt : public Stmt {
178  public:
179  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
180 
183 };
184 
195 class BufferStoreNode : public StmtNode {
196  public:
202  ffi::Array<PrimExpr> indices;
204  ffi::Optional<PrimExpr> predicate;
205 
206  static void RegisterReflection() {
207  namespace refl = tvm::ffi::reflection;
208  refl::ObjectDef<BufferStoreNode>()
209  .def_ro("buffer", &BufferStoreNode::buffer)
210  .def_ro("value", &BufferStoreNode::value)
211  .def_ro("indices", &BufferStoreNode::indices)
212  .def_ro("predicate", &BufferStoreNode::predicate);
213  }
215 };
216 
221 class BufferStore : public Stmt {
222  public:
223  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, ffi::Array<PrimExpr> indices,
224  ffi::Optional<PrimExpr> predicate = std::nullopt,
225  Span span = Span());
226 
229 };
230 
234 class AllocateNode : public StmtNode {
235  public:
241  ffi::Array<PrimExpr> extents;
252  ffi::Map<ffi::String, ffi::Any> annotations;
253 
254  static void RegisterReflection() {
255  namespace refl = tvm::ffi::reflection;
256  refl::ObjectDef<AllocateNode>()
257  .def_ro("buffer_var", &AllocateNode::buffer_var, refl::AttachFieldFlag::SEqHashDef())
258  .def_ro("dtype", &AllocateNode::dtype)
259  .def_ro("extents", &AllocateNode::extents)
260  .def_ro("condition", &AllocateNode::condition)
261  .def_ro("body", &AllocateNode::body)
262  .def_ro("annotations", &AllocateNode::annotations);
263  }
264 
277  TVM_DLL static int64_t ConstantAllocationSize(const ffi::Array<PrimExpr>& extents);
279 };
280 
285 class Allocate : public Stmt {
286  public:
287  TVM_DLL Allocate(Var buffer_var, DataType dtype, ffi::Array<PrimExpr> extents, PrimExpr condition,
288  Stmt body,
289  ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
290  Span span = Span());
291 
294 };
295 
297 class DeclBufferNode : public StmtNode {
298  public:
303 
304  static void RegisterReflection() {
305  namespace refl = tvm::ffi::reflection;
306  refl::ObjectDef<DeclBufferNode>()
307  .def_ro("buffer", &DeclBufferNode::buffer)
308  .def_ro("body", &DeclBufferNode::body);
309  }
311 };
312 
314 class DeclBuffer : public Stmt {
315  public:
316  TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
319 };
320 
325 class SeqStmtNode : public StmtNode {
326  public:
328  ffi::Array<Stmt> seq;
329 
331  size_t size() const { return seq.size(); }
335  Stmt operator[](size_t index) const { return seq[index]; }
336 
337  static void RegisterReflection() {
338  namespace refl = tvm::ffi::reflection;
339  refl::ObjectDef<SeqStmtNode>().def_ro("seq", &SeqStmtNode::seq);
340  }
342 };
343 
350 class EvaluateNode : public StmtNode {
351  public:
354 
355  static void RegisterReflection() {
356  namespace refl = tvm::ffi::reflection;
357  refl::ObjectDef<EvaluateNode>().def_ro("value", &EvaluateNode::value);
358  }
360 };
361 
366 class Evaluate : public Stmt {
367  public:
368  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
369 
370  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
371 
374 };
375 
377 class SeqStmt : public Stmt {
378  public:
384  TVM_DLL explicit SeqStmt(ffi::Array<Stmt> seq, Span span = Span());
385 
387  size_t size() const { return operator->()->size(); }
391  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
412  template <typename... Args>
413  static Stmt Flatten(Args&&... seq_args) {
414  ffi::Array<Stmt> seq;
415 
416  ffi::details::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
417 
418  if (seq.empty()) {
419  return Evaluate(0);
420  } else if (seq.size() == 1) {
421  return seq[0];
422  }
423 
424  // If the argument is a single SeqStmt argument with no
425  // flattening or unwrapping required, then we may
426  // return the SeqStmt as-is.
427  if constexpr (sizeof...(seq_args) == 1) {
428  if (auto opt = Flattener::AsSeqStmt(std::forward<Args>(seq_args)...)) {
429  SeqStmt original = opt.value();
430  bool all_same = [&]() {
431  if (original->seq.size() != seq.size()) {
432  return false;
433  }
434  for (size_t i = 0; i < seq.size(); i++) {
435  if (!original->seq[i].same_as(seq[i])) {
436  return false;
437  }
438  }
439  return true;
440  }();
441  if (all_same) {
442  return original;
443  }
444  }
445  }
446 
447  return SeqStmt(seq);
448  }
450  class Flattener {
451  public:
452  explicit Flattener(ffi::Array<Stmt>* seq) : seq_(seq) {}
453 
454  template <typename T>
455  static ffi::Optional<SeqStmt> AsSeqStmt(const T& t) {
456  if constexpr (std::is_same_v<T, SeqStmt>) {
457  return t;
458  }
459  if constexpr (!std::is_base_of_v<T, SeqStmt>) {
460  return std::nullopt;
461  }
462  if constexpr (std::is_base_of_v<Stmt, T>) {
463  if (const SeqStmtNode* ptr = t.template as<SeqStmtNode>()) {
464  return ffi::GetRef<SeqStmt>(ptr);
465  } else {
466  return std::nullopt;
467  }
468  }
469  return std::nullopt;
470  }
471 
472  template <typename T>
473  void operator()(size_t i, const T& stmt_or_seq) const {
474  if constexpr (std::is_base_of_v<ObjectRef, T>) {
475  // Early bail-out, applicable to any ObjectRef
476  if (!stmt_or_seq.defined()) {
477  return;
478  }
479  }
480 
481  if constexpr (std::is_same_v<T, SeqStmt>) {
482  // Static type-checking for a SeqStmt that could be flattened.
483  (*this)(0, stmt_or_seq->seq);
484  return;
485  }
486 
487  if constexpr (std::is_base_of_v<T, SeqStmt>) {
488  // Dynamic type-checking for a SeqStmt that could be
489  // flattened.
490  if (auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
491  operator()(0, op->seq);
492  return;
493  }
494  }
495 
496  if constexpr (std::is_base_of_v<T, Evaluate>) {
497  // Evaluate(0) is used to represent a no-op, and may be
498  // generated by previous calls to SeqStmt::Flatten(). These
499  // should be removed to ensure that Flatten(a+b) is equivalent
500  // to Flatten(Flatten(a), Flatten(b)).
501  if (auto* op = stmt_or_seq.template as<EvaluateNode>()) {
502  if (auto* as_int = op->value.template as<IntImmNode>(); as_int && as_int->value == 0) {
503  return;
504  }
505  }
506  }
507 
508  if constexpr (std::is_base_of_v<Stmt, T>) {
509  // Any other Stmt type just gets appended.
510  seq_->push_back(stmt_or_seq);
511  } else {
512  // Anything else is treated as an iterable of Stmt.
513  for (auto v : stmt_or_seq) {
514  this->operator()(0, v);
515  }
516  }
517  }
518 
519  private:
520  ffi::Array<Stmt>* seq_;
521  };
522 
525 };
526 
530 class IfThenElseNode : public StmtNode {
531  public:
537  ffi::Optional<Stmt> else_case;
538 
539  static void RegisterReflection() {
540  namespace refl = tvm::ffi::reflection;
541  refl::ObjectDef<IfThenElseNode>()
542  .def_ro("condition", &IfThenElseNode::condition)
543  .def_ro("then_case", &IfThenElseNode::then_case)
544  .def_ro("else_case", &IfThenElseNode::else_case);
545  }
547 };
548 
553 class IfThenElse : public Stmt {
554  public:
555  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case,
556  ffi::Optional<Stmt> else_case = std::nullopt, Span span = Span());
557 
560 };
561 
/*!
 * \brief The kind of a For loop.
 *
 * The underlying values are fixed explicitly. ForKind2String gives the name used for each
 * kind in the comments below.
 */
enum class ForKind : int {
  kSerial = 0,        // "serial"
  kParallel = 1,      // "parallel"
  kVectorized = 2,    // "vectorized"
  kUnrolled = 3,      // "unroll"
  kThreadBinding = 4  // "thread_binding"; ForNode::thread_binding is only valid for this kind
};
589 
600 class ForNode : public StmtNode {
601  public:
616  ffi::Optional<IterVar> thread_binding;
625  ffi::Map<ffi::String, ffi::Any> annotations;
629  ffi::Optional<PrimExpr> step;
630 
631  static void RegisterReflection() {
632  namespace refl = tvm::ffi::reflection;
633  refl::ObjectDef<ForNode>()
634  .def_ro("loop_var", &ForNode::loop_var, refl::AttachFieldFlag::SEqHashDef())
635  .def_ro("min", &ForNode::min)
636  .def_ro("extent", &ForNode::extent)
637  .def_ro("kind", &ForNode::kind)
638  .def_ro("body", &ForNode::body)
639  .def_ro("thread_binding", &ForNode::thread_binding)
640  .def_ro("annotations", &ForNode::annotations)
641  .def_ro("step", &ForNode::step);
642  }
643 
645  bool HasTrivialStep() const;
646 
648 };
649 
654 class For : public Stmt {
655  public:
656  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
657  ffi::Optional<IterVar> thread_binding = std::nullopt,
658  ffi::Map<ffi::String, ffi::Any> annotations = {},
659  ffi::Optional<PrimExpr> step = std::nullopt, Span span = Span());
660 
663 };
664 
675 class WhileNode : public StmtNode {
676  public:
681 
682  static void RegisterReflection() {
683  namespace refl = tvm::ffi::reflection;
684  refl::ObjectDef<WhileNode>()
685  .def_ro("condition", &WhileNode::condition)
686  .def_ro("body", &WhileNode::body);
687  }
689 };
690 
695 class While : public Stmt {
696  public:
697  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
698 
701 };
702 
707  public:
711  ffi::Array<Range> region;
712 
713  static void RegisterReflection() {
714  namespace refl = tvm::ffi::reflection;
715  refl::ObjectDef<BufferRegionNode>()
716  .def_ro("buffer", &BufferRegionNode::buffer)
717  .def_ro("region", &BufferRegionNode::region);
718  }
719 
720  TVM_DLL PrimExpr ToPrimExpr() const final;
721 
722  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
724 };
725 
731  public:
732  TVM_DLL explicit BufferRegion(Buffer buffer, ffi::Array<Range> region);
733 
740 
747  TVM_DLL static BufferRegion FromPoint(Buffer buffer, ffi::Array<PrimExpr> indices);
748 
751 };
752 
762 class MatchBufferRegionNode : public Object {
763  public:
768 
769  static void RegisterReflection() {
770  namespace refl = tvm::ffi::reflection;
771  refl::ObjectDef<MatchBufferRegionNode>()
772  .def_ro("buffer", &MatchBufferRegionNode::buffer)
773  .def_ro("source", &MatchBufferRegionNode::source);
774  }
775 
776  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
777  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.MatchBufferRegion", MatchBufferRegionNode, Object);
778 };
779 
784 class MatchBufferRegion : public ObjectRef {
785  public:
786  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
787 
790 };
791 
813 class SBlockNode : public StmtNode {
814  public:
816  ffi::Array<IterVar> iter_vars;
818  ffi::Array<BufferRegion> reads;
820  ffi::Array<BufferRegion> writes;
822  ffi::String name_hint;
824  ffi::Array<Buffer> alloc_buffers;
826  ffi::Array<MatchBufferRegion> match_buffers;
828  ffi::Map<ffi::String, ffi::Any> annotations;
836  ffi::Optional<Stmt> init;
839 
840  static void RegisterReflection() {
841  namespace refl = tvm::ffi::reflection;
842  refl::ObjectDef<SBlockNode>()
843  .def_ro("iter_vars", &SBlockNode::iter_vars, refl::AttachFieldFlag::SEqHashDef())
844  .def_ro("reads", &SBlockNode::reads)
845  .def_ro("writes", &SBlockNode::writes)
846  .def_ro("name_hint", &SBlockNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore())
847  .def_ro("alloc_buffers", &SBlockNode::alloc_buffers)
848  .def_ro("match_buffers", &SBlockNode::match_buffers)
849  .def_ro("annotations", &SBlockNode::annotations)
850  .def_ro("init", &SBlockNode::init)
851  .def_ro("body", &SBlockNode::body);
852  }
854 };
855 
860 class SBlock : public Stmt {
861  public:
862  TVM_DLL explicit SBlock(
863  ffi::Array<IterVar> iter_vars, ffi::Array<BufferRegion> reads,
864  ffi::Array<BufferRegion> writes, ffi::String name_hint, Stmt body,
865  ffi::Optional<Stmt> init = std::nullopt,
866  ffi::Array<Buffer> alloc_buffers = ffi::Array<Buffer>(),
867  ffi::Array<MatchBufferRegion> match_buffers = ffi::Array<MatchBufferRegion>(),
868  ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
869  Span span = Span());
870 
873 };
874 
878 class SBlockRealizeNode : public StmtNode {
879  public:
881  ffi::Array<PrimExpr> iter_values;
889 
890  static void RegisterReflection() {
891  namespace refl = tvm::ffi::reflection;
892  refl::ObjectDef<SBlockRealizeNode>()
893  .def_ro("iter_values", &SBlockRealizeNode::iter_values)
894  .def_ro("predicate", &SBlockRealizeNode::predicate)
895  .def_ro("block", &SBlockRealizeNode::block);
896  }
898 };
899 
904 class SBlockRealize : public Stmt {
905  public:
906  TVM_DLL explicit SBlockRealize(ffi::Array<PrimExpr> iter_values, PrimExpr predicate, SBlock block,
907  Span span = Span());
908 
911 };
912 
namespace attr {
// Well-known attribute key strings; presumably used as AttrStmtNode::attr_key values
// (NOTE(review): confirm against the users of these constants).
constexpr const char* thread_extent = "thread_extent";
constexpr const char* virtual_thread = "virtual_thread";
constexpr const char* coproc_scope = "coproc_scope";
constexpr const char* coproc_uop_scope = "coproc_uop_scope";
constexpr const char* volatile_scope = "volatile_scope";
constexpr const char* extern_scope = "extern_scope";
/*!
 * \brief Mark the scope as where computation starts to happen.
 *        This can hint some code generators to create a new function for the scope.
 */
constexpr const char* compute_scope = "compute_scope";
constexpr const char* storage_alignment = "storage_alignment";
constexpr const char* device_id = "device_id";
constexpr const char* device_type = "device_type";
constexpr const char* loop_scope = "loop_scope";
constexpr const char* reduce_scope = "reduce_scope";
constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
/*! \brief Common prefix of all pragma keys; see IsPragmaKey. */
constexpr const char* pragma_scope_prefix = "pragma_";
constexpr const char* pragma_import_c = "pragma_import_c";
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
constexpr const char* double_buffer_scope = "double_buffer_scope";
constexpr const char* double_buffer_write = "double_buffer_write";
constexpr const char* scan_update_scope = "scan_update_scope";
constexpr const char* scan_init_scope = "scan_init_scope";
constexpr const char* buffer_bound = "buffer_bound";
constexpr const char* buffer_bind_scope = "buffer_bind_scope";
// Pipeline related attributes
constexpr const char* channel_read_scope = "channel_read_scope";
constexpr const char* channel_read_advance = "channel_read_advance";
constexpr const char* channel_write_scope = "channel_write_scope";
constexpr const char* channel_write_advance = "channel_write_advance";
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";

constexpr const char* device_scope = "device_scope";

constexpr const char* async_scope = "async_scope";

constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";

constexpr const char* fragment_shape = "fragment_shape";

constexpr const char* fragment_layout = "fragment_layout";

constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*!
 * \brief Check whether an attribute key is a pragma key.
 * \param attr_key The attribute key to check.
 * \return true iff attr_key starts with pragma_scope_prefix ("pragma_").
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // rfind(prefix, 0) can only match at position 0, i.e. this is a starts-with test. Deriving
  // it from pragma_scope_prefix keeps the check in sync with the prefix constant (no magic 7).
  return attr_key.rfind(pragma_scope_prefix, 0) == 0;
}

}  // namespace attr
1062 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1063 
1064 // Overload of operator<< for printing a ForKind.
1065 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1066 
1067 // inline implementations
1068 inline const char* ForKind2String(ForKind t) {
1069  switch (t) {
1070  case ForKind::kSerial:
1071  return "serial";
1072  case ForKind::kParallel:
1073  return "parallel";
1074  case ForKind::kVectorized:
1075  return "vectorized";
1076  case ForKind::kUnrolled:
1077  return "unroll";
1079  return "thread_binding";
1080  }
1081  TVM_FFI_THROW(InternalError) << "Unknown ForKind" << t;
1082 }
1083 
1084 } // namespace tir
1085 } // namespace tvm
1086 #endif // TVM_TIR_STMT_H_
Base class for other IR constructs that can be converted to PrimExpr. This is useful for the FFI to c...
Definition: expr.h:156
Managed reference to PrimExprConvertibleNode.
Definition: expr.h:167
Reference to PrimExprNode.
Definition: expr.h:126
Definition: source_map.h:111
Runtime primitive data type.
Definition: data_type.h:47
Allocate a buffer that can be used in body.
Definition: stmt.h:234
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Allocate", AllocateNode, StmtNode)
ffi::Map< ffi::String, ffi::Any > annotations
Additional annotations about the allocation.
Definition: stmt.h:252
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:243
Stmt body
The body to be executed.
Definition: stmt.h:245
static int64_t ConstantAllocationSize(const ffi::Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:270
static void RegisterReflection()
Definition: stmt.h:254
DataType dtype
The type of the buffer.
Definition: stmt.h:239
Var buffer_var
The buffer variable.
Definition: stmt.h:237
ffi::Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:241
Managed reference to AllocateNode.
Definition: stmt.h:285
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Allocate, Stmt, AllocateNode)
Allocate(Var buffer_var, DataType dtype, ffi::Array< PrimExpr > extents, PrimExpr condition, Stmt body, ffi::Map< ffi::String, ffi::Any > annotations=ffi::Map< ffi::String, ffi::Any >(), Span span=Span())
Assert a condition; if the condition does not hold, the error message is reported.
Definition: stmt.h:151
PrimExpr condition
Condition to be checked.
Definition: stmt.h:154
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:156
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AssertStmt", AssertStmtNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:163
Stmt body
Body in which this assertion holds true. It is executed after the assertion.
Definition: stmt.h:161
Managed reference to AssertStmtNode.
Definition: stmt.h:177
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AssertStmt, Stmt, AssertStmtNode)
AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode)
Define a certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:113
ffi::Any node
The node that this attribute is about.
Definition: stmt.h:116
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:120
Stmt body
The body statement to be executed.
Definition: stmt.h:122
static void RegisterReflection()
Definition: stmt.h:124
ffi::String attr_key
The type key of the attribute.
Definition: stmt.h:118
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AttrStmt", AttrStmtNode, StmtNode)
Managed reference to AttrStmtNode.
Definition: stmt.h:139
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrStmt, Stmt, AttrStmtNode)
AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode)
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:706
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferRegion", BufferRegionNode, PrimExprConvertibleNode)
PrimExpr ToPrimExpr() const final
static void RegisterReflection()
Definition: stmt.h:713
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:709
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:722
ffi::Array< Range > region
The region array of the buffer region.
Definition: stmt.h:711
Managed reference to BufferRegionNode.
Definition: stmt.h:730
static BufferRegion FullRegion(Buffer buffer)
Create a BufferRegion which is full region of the given buffer.
static BufferRegion FromPoint(Buffer buffer, ffi::Array< PrimExpr > indices)
Create a BufferRegion which is a single point of the given buffer.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferRegion, PrimExprConvertible, BufferRegionNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode)
BufferRegion(Buffer buffer, ffi::Array< Range > region)
Store value to the high dimension buffer.
Definition: stmt.h:195
Buffer buffer
The buffer that the value is stored to.
Definition: stmt.h:198
ffi::Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:202
PrimExpr value
The value to be stored.
Definition: stmt.h:200
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferStore", BufferStoreNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:206
ffi::Optional< PrimExpr > predicate
The predicate mask for storing values.
Definition: stmt.h:204
Managed reference to BufferStoreNode.
Definition: stmt.h:221
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferStore, Stmt, BufferStoreNode)
BufferStore(Buffer buffer, PrimExpr value, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode)
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
Declare a buffer that can be used in the body.
Definition: stmt.h:297
static void RegisterReflection()
Definition: stmt.h:304
Buffer buffer
The buffer being declared.
Definition: stmt.h:300
Stmt body
The body to be executed.
Definition: stmt.h:302
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.DeclBuffer", DeclBufferNode, StmtNode)
Managed reference to DeclBufferNode.
Definition: stmt.h:314
DeclBuffer(Buffer buffer, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DeclBuffer, Stmt, DeclBufferNode)
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:350
static void RegisterReflection()
Definition: stmt.h:355
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:353
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Evaluate", EvaluateNode, StmtNode)
Managed reference to EvaluateNode.
Definition: stmt.h:366
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode)
Evaluate(int value, Span span=Span())
Definition: stmt.h:370
Evaluate(PrimExpr value, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Evaluate, Stmt, EvaluateNode)
A for loop, with possible type annotations.
Definition: stmt.h:600
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:605
ForKind kind
The kind of the for loop.
Definition: stmt.h:609
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:631
Var loop_var
The loop variable.
Definition: stmt.h:603
bool HasTrivialStep() const
Check that the loop does not have a nontrivial loop step.
ffi::Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to.
Definition: stmt.h:616
ffi::Map< ffi::String, ffi::Any > annotations
Additional annotations about the loop.
Definition: stmt.h:625
ffi::Optional< PrimExpr > step
The loop step. It is one if not specified.
Definition: stmt.h:629
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:607
Stmt body
The body of the for loop.
Definition: stmt.h:611
Managed reference to ForNode.
Definition: stmt.h:654
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode)
For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ffi::Optional< IterVar > thread_binding=std::nullopt, ffi::Map< ffi::String, ffi::Any > annotations={}, ffi::Optional< PrimExpr > step=std::nullopt, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode)
IfThenElse statement.
Definition: stmt.h:530
PrimExpr condition
The condition.
Definition: stmt.h:533
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.IfThenElse", IfThenElseNode, StmtNode)
ffi::Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:537
static void RegisterReflection()
Definition: stmt.h:539
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:535
Managed reference to IfThenElseNode.
Definition: stmt.h:553
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IfThenElse, Stmt, IfThenElseNode)
IfThenElse(PrimExpr condition, Stmt then_case, ffi::Optional< Stmt > else_case=std::nullopt, Span span=Span())
Let binding, bind var to value, then run body.
Definition: stmt.h:72
PrimExpr value
The value to be bound.
Definition: stmt.h:77
Stmt body
The body block.
Definition: stmt.h:79
static void RegisterReflection()
Definition: stmt.h:81
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.LetStmt", LetStmtNode, StmtNode)
Var var
The variable.
Definition: stmt.h:75
Managed reference to LetStmtNode.
Definition: stmt.h:95
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode)
LetStmt(Var var, PrimExpr value, Stmt body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LetStmt, Stmt, LetStmtNode)
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:762
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.MatchBufferRegion", MatchBufferRegionNode, Object)
Buffer buffer
The target buffer.
Definition: stmt.h:765
static void RegisterReflection()
Definition: stmt.h:769
BufferRegion source
The source buffer region.
Definition: stmt.h:767
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:784
MatchBufferRegion(Buffer buffer, BufferRegion source)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchBufferRegion, ObjectRef, MatchBufferRegionNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode)
A block is a basic schedule unit in TIR.
Definition: stmt.h:813
ffi::Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:824
ffi::Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:816
ffi::Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:818
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlock", SBlockNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:840
ffi::String name_hint
The name_hint of the block.
Definition: stmt.h:822
ffi::Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block....
Definition: stmt.h:836
ffi::Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:826
ffi::Map< ffi::String, ffi::Any > annotations
The annotation of the block.
Definition: stmt.h:828
ffi::Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:820
Stmt body
The body of the block.
Definition: stmt.h:838
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:878
ffi::Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:881
static void RegisterReflection()
Definition: stmt.h:890
PrimExpr predicate
The predicate of the block realization; the block will only be executed when the predicate is true.
Definition: stmt.h:886
SBlock block
The block to be realized.
Definition: stmt.h:888
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockRealize", SBlockRealizeNode, StmtNode)
Managed reference to SBlockRealizeNode.
Definition: stmt.h:904
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SBlockRealize, Stmt, SBlockRealizeNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockRealizeNode)
SBlockRealize(ffi::Array< PrimExpr > iter_values, PrimExpr predicate, SBlock block, Span span=Span())
Managed reference to SBlockNode.
Definition: stmt.h:860
TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockNode)
SBlock(ffi::Array< IterVar > iter_vars, ffi::Array< BufferRegion > reads, ffi::Array< BufferRegion > writes, ffi::String name_hint, Stmt body, ffi::Optional< Stmt > init=std::nullopt, ffi::Array< Buffer > alloc_buffers=ffi::Array< Buffer >(), ffi::Array< MatchBufferRegion > match_buffers=ffi::Array< MatchBufferRegion >(), ffi::Map< ffi::String, ffi::Any > annotations=ffi::Map< ffi::String, ffi::Any >(), Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SBlock, Stmt, SBlockNode)
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:325
size_t size() const
Definition: stmt.h:331
ffi::Array< Stmt > seq
Internal sequence content.
Definition: stmt.h:328
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SeqStmt", SeqStmtNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:337
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:335
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:450
Flattener(ffi::Array< Stmt > *seq)
Definition: stmt.h:452
static ffi::Optional< SeqStmt > AsSeqStmt(const T &t)
Definition: stmt.h:455
void operator()(size_t i, const T &stmt_or_seq) const
Definition: stmt.h:473
Sequence statement.
Definition: stmt.h:377
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SeqStmt, Stmt, SeqStmtNode)
size_t size() const
Definition: stmt.h:387
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode)
SeqStmt(ffi::Array< Stmt > seq, Span span=Span())
Construct SeqStmt.
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:391
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:413
Base node of all statements.
Definition: stmt.h:39
TVM_FFI_DECLARE_OBJECT_INFO("tir.Stmt", StmtNode, Object)
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:59
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:57
StmtNode(Span span)
Definition: stmt.h:48
static void RegisterReflection()
Definition: stmt.h:50
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:45
Container of all statements.
Definition: stmt.h:64
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Stmt, ObjectRef, StmtNode)
a named variable in TIR
Definition: var.h:76
A While loop.
Definition: stmt.h:675
Stmt body
The body of the while loop.
Definition: stmt.h:680
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.While", WhileNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:682
PrimExpr condition
The termination condition.
Definition: stmt.h:678
Managed reference to WhileNode.
Definition: stmt.h:695
While(PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(While, Stmt, WhileNode)
Definition: repr_printer.h:91
void Evaluate(PrimExpr value)
Evaluate the input expression.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * compute_scope
Mark the scope where computation starts to happen. This can hint some code generators to create a new ...
Definition: stmt.h:939
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op. When this scope occurs, the stmt....
Definition: stmt.h:985
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:990
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:928
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:996
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:957
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:953
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1003
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1051
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:917
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:919
constexpr const char * extern_scope
Mark the scope as generated by an extern primitive. Such a scope can contain an arbitrary IR program and we n...
Definition: stmt.h:934
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:949
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:992
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1027
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:943
constexpr const char * device_type
The device type.
Definition: stmt.h:945
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:971
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:951
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:947
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:965
constexpr const char * fragment_shape
Mark the shape of a TensorCore fragment.
Definition: stmt.h:1034
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:961
constexpr const char * fragment_layout
Mark the layout of a TensorCore fragment.
Definition: stmt.h:1039
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1028
constexpr const char * coproc_scope
Mark that the region is processed by a co-processor.
Definition: stmt.h:921
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:975
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1029
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:988
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:994
constexpr const char * coproc_uop_scope
Mark that the region creates coprocessor micro ops, which can be reused if the corresponding variable is independent.
Definition: stmt.h:926
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1044
constexpr const char * pipeline_exec_scope
pipeline execution scope, implies the scope can be pipelined.
Definition: stmt.h:998
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:959
constexpr const char * pragma_scope_prefix
Mark that the region is guarded by the pragma extension.
Definition: stmt.h:955
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:973
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:941
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:969
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1008
const char * ForKind2String(ForKind t)
Definition: stmt.h:1068
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:123
ForKind
The kind of the loop.
Definition: stmt.h:569
@ kThreadBinding
The loop variable is bound to a thread in an environment. In the final stage of lowering,...
@ kParallel
Parallel execution on CPU.
@ kSerial
default semantics – serial execution.
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
@ kVectorized
The loop is vectorized.
Definition: var.h:236
@ kUnrolled
The execution is unrolled.
Definition: var.h:232
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
Take the minimum of two values.
TIR expressions.