tvm
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/ffi/reflection/registry.h>
28 #include <tvm/tir/expr.h>
29 
30 #include <string>
31 #include <type_traits>
32 #include <utility>
33 
34 namespace tvm {
35 namespace tir {
36 
// Base class of the TIR statement nodes in this file (each *Node statement class
// below derives from it). Holds only a source span.
38 class StmtNode : public Object {
39  public:
// Source location of this statement. Declared `mutable`, so it may be modified
// even through a const StmtNode.
44  mutable Span span;
45 
46  StmtNode() = default;
47  explicit StmtNode(Span span) : span(span) {}
48 
// Registers `span` as a read-only reflected field.
49  static void RegisterReflection() {
50  namespace refl = tvm::ffi::reflection;
51  refl::ObjectDef<StmtNode>().def_ro("span", &StmtNode::span);
52  }
53 
55 
56  static constexpr const char* _type_key = "tir.Stmt";
// Selects the tree-node kind for structural equality/hashing.
57  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
58 
// Number of child type slots reserved for subclasses (presumably sized for the
// statement subclasses; TODO confirm against the type registry).
59  static constexpr const uint32_t _type_child_slots = 15;
61 };
62 
// Reference type for StmtNode.
// NOTE(review): the class body is empty in this listing.
64 class Stmt : public ObjectRef {
65  public:
67 };
68 
// Let-binding statement node. Judging by the reflected field names, it binds
// `var` to `value` for the scope of `body` (TODO confirm).
72 class LetStmtNode : public StmtNode {
73  public:
// NOTE(review): `var`, `value` and `body` are referenced by RegisterReflection
// below but are not declared in this listing.
80 
// Reflects fields in this order. `var` carries SEqHashDef, which presumably
// marks it as a binding site for structural equality/hashing.
81  static void RegisterReflection() {
82  namespace refl = tvm::ffi::reflection;
83  refl::ObjectDef<LetStmtNode>()
84  .def_ro("var", &LetStmtNode::var, refl::AttachFieldFlag::SEqHashDef())
85  .def_ro("value", &LetStmtNode::value)
86  .def_ro("body", &LetStmtNode::body);
87  }
88 
89  static constexpr const char* _type_key = "tir.LetStmt";
91 };
92 
// Reference type for LetStmtNode. The constructor is exported (TVM_DLL) and
// defined outside this header.
97 class LetStmt : public Stmt {
98  public:
99  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
100 
103 };
104 
115 class AttrStmtNode : public StmtNode {
116  public:
118  ffi::Any node;
120  String attr_key;
125 
126  static void RegisterReflection() {
127  namespace refl = tvm::ffi::reflection;
128  refl::ObjectDef<AttrStmtNode>()
129  .def_ro("node", &AttrStmtNode::node)
130  .def_ro("attr_key", &AttrStmtNode::attr_key)
131  .def_ro("value", &AttrStmtNode::value)
132  .def_ro("body", &AttrStmtNode::body);
133  }
134 
135  static constexpr const char* _type_key = "tir.AttrStmt";
137 };
138 
143 class AttrStmt : public Stmt {
144  public:
145  TVM_DLL AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
146 
149 };
150 
154 class AssertStmtNode : public StmtNode {
155  public:
165 
166  static void RegisterReflection() {
167  namespace refl = tvm::ffi::reflection;
168  refl::ObjectDef<AssertStmtNode>()
169  .def_ro("condition", &AssertStmtNode::condition)
170  .def_ro("message", &AssertStmtNode::message)
171  .def_ro("body", &AssertStmtNode::body);
172  }
173 
174  static constexpr const char* _type_key = "tir.AssertStmt";
176 };
177 
182 class AssertStmt : public Stmt {
183  public:
184  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
185 
188 };
189 
// Store a value to a high-dimensional buffer.
200 class BufferStoreNode : public StmtNode {
201  public:
// Index expressions of the store location.
207  Array<PrimExpr> indices;
// Optional predicate on the store; std::nullopt when absent (the BufferStore
// constructor defaults it to std::nullopt).
209  Optional<PrimExpr> predicate;
// NOTE(review): `buffer` and `value` are referenced by RegisterReflection below
// but are not declared in this listing.
210 
// Reflects fields in the order buffer, value, indices, predicate.
211  static void RegisterReflection() {
212  namespace refl = tvm::ffi::reflection;
213  refl::ObjectDef<BufferStoreNode>()
214  .def_ro("buffer", &BufferStoreNode::buffer)
215  .def_ro("value", &BufferStoreNode::value)
216  .def_ro("indices", &BufferStoreNode::indices)
217  .def_ro("predicate", &BufferStoreNode::predicate);
218  }
219 
220  static constexpr const char* _type_key = "tir.BufferStore";
222 };
223 
// Reference type for BufferStoreNode.
228 class BufferStore : public Stmt {
229  public:
230  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
231  Optional<PrimExpr> predicate = std::nullopt, Span span = Span());
232 
235 };
236 
// Annotates the region of a buffer that needs to be read and written in the body.
248 class BufferRealizeNode : public StmtNode {
249  public:
// Bounds to be realized.
253  Array<Range> bounds;
// NOTE(review): `buffer`, `condition` and `body` are referenced by
// RegisterReflection below but are not declared in this listing.
258 
259  static void RegisterReflection() {
260  namespace refl = tvm::ffi::reflection;
261  refl::ObjectDef<BufferRealizeNode>()
262  .def_ro("buffer", &BufferRealizeNode::buffer)
263  .def_ro("bounds", &BufferRealizeNode::bounds)
264  .def_ro("condition", &BufferRealizeNode::condition)
265  .def_ro("body", &BufferRealizeNode::body);
266  }
267 
268  BufferRealizeNode() = default;
// NOTE(review): the following line is the tail of a truncated constructor
// declaration, BufferRealizeNode(Buffer buffer, Array<Range> bounds,
// PrimExpr condition, Stmt body, Span span = Span()). Its head and initializer
// list are missing from this listing.
270  Span span = Span())
272 
273  static constexpr const char* _type_key = "tir.BufferRealize";
275 };
276 
281 class BufferRealize : public Stmt {
282  public:
283  TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
284  Span span = Span());
285 
288 };
289 
// Allocate a buffer that can be used in body.
293 class AllocateNode : public StmtNode {
294  public:
// The extents of the buffer.
300  Array<PrimExpr> extents;
// Additional annotations about the allocation.
311  Map<String, ffi::Any> annotations;
// NOTE(review): `buffer_var`, `dtype`, `condition` and `body` are referenced by
// RegisterReflection below but are not declared in this listing.
312 
// `buffer_var` carries SEqHashDef, which presumably marks it as a definition
// site for structural equality/hashing.
313  static void RegisterReflection() {
314  namespace refl = tvm::ffi::reflection;
315  refl::ObjectDef<AllocateNode>()
316  .def_ro("buffer_var", &AllocateNode::buffer_var, refl::AttachFieldFlag::SEqHashDef())
317  .def_ro("dtype", &AllocateNode::dtype)
318  .def_ro("extents", &AllocateNode::extents)
319  .def_ro("condition", &AllocateNode::condition)
320  .def_ro("body", &AllocateNode::body)
321  .def_ro("annotations", &AllocateNode::annotations);
322  }
323 
// If the buffer size given by `extents` is constant, return it; otherwise
// return 0. Defined outside this header.
336  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
337 
338  static constexpr const char* _type_key = "tir.Allocate";
339 
341 };
342 
347 class Allocate : public Stmt {
348  public:
349  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
350  Stmt body, Map<String, ffi::Any> annotations = Map<String, ffi::Any>(),
351  Span span = Span());
352 
355 };
356 
// Allocate a buffer, backed by constant data, that can be used in body.
360 class AllocateConstNode : public StmtNode {
361  public:
// The optional data associated with the constant.
366  Optional<runtime::NDArray> data;
// If the containing PrimFunc is added to an IRModule, this is an optional index
// into the module's constant storage (exact semantics defined elsewhere).
371  Optional<Integer> irmod_storage_idx;
// The extents of the buffer.
375  Array<PrimExpr> extents;
// Additional annotations about the allocation.
384  Map<String, ffi::Any> annotations;
// NOTE(review): `buffer_var`, `dtype` and `body` are referenced by
// RegisterReflection below but are not declared in this listing.
385 
386  static void RegisterReflection() {
387  namespace refl = tvm::ffi::reflection;
388  refl::ObjectDef<AllocateConstNode>()
389  .def_ro("buffer_var", &AllocateConstNode::buffer_var, refl::AttachFieldFlag::SEqHashDef())
390  .def_ro("data", &AllocateConstNode::data)
391  .def_ro("irmod_storage_idx", &AllocateConstNode::irmod_storage_idx)
392  .def_ro("dtype", &AllocateConstNode::dtype)
393  .def_ro("extents", &AllocateConstNode::extents)
394  .def_ro("body", &AllocateConstNode::body)
395  .def_ro("annotations", &AllocateConstNode::annotations);
396  }
397 
// If the buffer size given by `extents` is constant, return it; otherwise
// return 0. Defined outside this header.
410  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
411 
412  static constexpr const char* _type_key = "tir.AllocateConst";
414 };
415 
420 class AllocateConst : public Stmt {
421  public:
422  /* The constructor to create a IRNode with constant data
423  * depending on the type of ObjectRef, it will either
424  * create AllocateConstNode with irmod_storage_idx or data
425  */
426  TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
427  ObjectRef data_or_idx, Stmt body,
428  Map<String, ffi::Any> annotations = Map<String, ffi::Any>(),
429  Span span = Span());
432 };
433 
// Statement node that declares `buffer` for `body` (roles inferred from the
// reflected field names and the DeclBuffer constructor; TODO confirm).
// NOTE(review): `buffer` and `body` are not declared in this listing.
435 class DeclBufferNode : public StmtNode {
436  public:
441 
442  static void RegisterReflection() {
443  namespace refl = tvm::ffi::reflection;
444  refl::ObjectDef<DeclBufferNode>()
445  .def_ro("buffer", &DeclBufferNode::buffer)
446  .def_ro("body", &DeclBufferNode::body);
447  }
448 
449  static constexpr const char* _type_key = "tir.DeclBuffer";
451 };
452 
// Reference type for DeclBufferNode.
454 class DeclBuffer : public Stmt {
455  public:
456  TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
459 };
460 
// Statement node holding a sequence of statements.
465 class SeqStmtNode : public StmtNode {
466  public:
// The statements in the sequence.
468  Array<Stmt> seq;
469 
// Number of statements in `seq`.
471  size_t size() const { return seq.size(); }
// Statement at position `index` in `seq` (bounds behavior follows Array).
475  Stmt operator[](size_t index) const { return seq[index]; }
476 
477  static void RegisterReflection() {
478  namespace refl = tvm::ffi::reflection;
479  refl::ObjectDef<SeqStmtNode>().def_ro("seq", &SeqStmtNode::seq);
480  }
481 
482  static constexpr const char* _type_key = "tir.SeqStmt";
484 };
485 
// Statement node that evaluates the expression `value`.
// NOTE(review): `value` is referenced by RegisterReflection but is not declared
// in this listing.
492 class EvaluateNode : public StmtNode {
493  public:
496 
497  static void RegisterReflection() {
498  namespace refl = tvm::ffi::reflection;
499  refl::ObjectDef<EvaluateNode>().def_ro("value", &EvaluateNode::value);
500  }
501 
502  static constexpr const char* _type_key = "tir.Evaluate";
504 };
505 
// Reference type for EvaluateNode. Evaluate(0) is used elsewhere in this file
// (SeqStmt::Flatten) to represent a no-op.
510 class Evaluate : public Stmt {
511  public:
512  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
513 
// Convenience overload: wraps an int in PrimExpr and delegates.
514  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
515 
518 };
519 
521 class SeqStmt : public Stmt {
522  public:
528  TVM_DLL explicit SeqStmt(Array<Stmt> seq, Span span = Span());
529 
531  size_t size() const { return operator->()->size(); }
535  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
556  template <typename... Args>
557  static Stmt Flatten(Args&&... seq_args) {
558  Array<Stmt> seq;
559 
560  ffi::details::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
561 
562  if (seq.empty()) {
563  return Evaluate(0);
564  } else if (seq.size() == 1) {
565  return seq[0];
566  }
567 
568  // If the argument is a single SeqStmt argument with no
569  // flattening or unwrapping required, then we may
570  // return the SeqStmt as-is.
571  if constexpr (sizeof...(seq_args) == 1) {
572  if (auto opt = Flattener::AsSeqStmt(std::forward<Args>(seq_args)...)) {
573  SeqStmt original = opt.value();
574  bool all_same = [&]() {
575  if (original->seq.size() != seq.size()) {
576  return false;
577  }
578  for (size_t i = 0; i < seq.size(); i++) {
579  if (!original->seq[i].same_as(seq[i])) {
580  return false;
581  }
582  }
583  return true;
584  }();
585  if (all_same) {
586  return original;
587  }
588  }
589  }
590 
591  return SeqStmt(seq);
592  }
594  class Flattener {
595  public:
596  explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}
597 
598  template <typename T>
599  static Optional<SeqStmt> AsSeqStmt(const T& t) {
600  if constexpr (std::is_same_v<T, SeqStmt>) {
601  return t;
602  }
603  if constexpr (!std::is_base_of_v<T, SeqStmt>) {
604  return std::nullopt;
605  }
606  if constexpr (std::is_base_of_v<Stmt, T>) {
607  if (const SeqStmtNode* ptr = t.template as<SeqStmtNode>()) {
608  return GetRef<SeqStmt>(ptr);
609  } else {
610  return std::nullopt;
611  }
612  }
613  return std::nullopt;
614  }
615 
616  template <typename T>
617  void operator()(size_t i, const T& stmt_or_seq) const {
618  if constexpr (std::is_base_of_v<ObjectRef, T>) {
619  // Early bail-out, applicable to any ObjectRef
620  if (!stmt_or_seq.defined()) {
621  return;
622  }
623  }
624 
625  if constexpr (std::is_same_v<T, SeqStmt>) {
626  // Static type-checking for a SeqStmt that could be flattened.
627  (*this)(0, stmt_or_seq->seq);
628  return;
629  }
630 
631  if constexpr (std::is_base_of_v<T, SeqStmt>) {
632  // Dynamic type-checking for a SeqStmt that could be
633  // flattened.
634  if (auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
635  operator()(0, op->seq);
636  return;
637  }
638  }
639 
640  if constexpr (std::is_base_of_v<T, Evaluate>) {
641  // Evaluate(0) is used to represent a no-op, and may be
642  // generated by previous calls to SeqStmt::Flatten(). These
643  // should be removed to ensure that Flatten(a+b) is equivalent
644  // to Flatten(Flatten(a), Flatten(b)).
645  if (auto* op = stmt_or_seq.template as<EvaluateNode>()) {
646  if (auto* as_int = op->value.template as<IntImmNode>(); as_int && as_int->value == 0) {
647  return;
648  }
649  }
650  }
651 
652  if constexpr (std::is_base_of_v<Stmt, T>) {
653  // Any other Stmt type just gets appended.
654  seq_->push_back(stmt_or_seq);
655  } else {
656  // Anything else is treated as an iterable of Stmt.
657  for (auto v : stmt_or_seq) {
658  this->operator()(0, v);
659  }
660  }
661  }
662 
663  private:
664  Array<Stmt>* seq_;
665  };
666 
669 };
670 
// Conditional statement node with a condition, a then-branch and an optional
// else-branch.
674 class IfThenElseNode : public StmtNode {
675  public:
// Optional else branch; std::nullopt when absent (the constructor's default).
681  Optional<Stmt> else_case;
// NOTE(review): `condition` and `then_case` are referenced by RegisterReflection
// below but are not declared in this listing.
682 
683  static void RegisterReflection() {
684  namespace refl = tvm::ffi::reflection;
685  refl::ObjectDef<IfThenElseNode>()
686  .def_ro("condition", &IfThenElseNode::condition)
687  .def_ro("then_case", &IfThenElseNode::then_case)
688  .def_ro("else_case", &IfThenElseNode::else_case);
689  }
690 
691  static constexpr const char* _type_key = "tir.IfThenElse";
693 };
694 
// Reference type for IfThenElseNode.
699 class IfThenElse : public Stmt {
700  public:
701  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case = std::nullopt,
702  Span span = Span());
703 
706 };
707 
// Kind of a For loop. Values are explicit and stable (0..4), with int as the
// underlying type.
715 enum class ForKind : int {
// Serial loop; ForKind2String prints it as "serial".
717  kSerial = 0,
// Parallel loop; printed as "parallel".
719  kParallel = 1,
// Vectorized loop; printed as "vectorized".
724  kVectorized = 2,
// Unrolled loop; printed as "unroll".
726  kUnrolled = 3,
// Loop bound to a thread (presumably paired with ForNode::thread_binding;
// TODO confirm).
733  kThreadBinding = 4
734 };
735 
// Loop statement node: iterates `loop_var` from `min` over `extent`, with loop
// kind `kind` (field roles inferred from the reflected names; TODO confirm).
746 class ForNode : public StmtNode {
747  public:
// Optional thread binding for the loop; std::nullopt by default in the For
// constructor.
762  Optional<IterVar> thread_binding;
// Additional annotations on the loop.
771  Map<String, ffi::Any> annotations;
// NOTE(review): `loop_var`, `min`, `extent`, `kind` and `body` are referenced by
// RegisterReflection below but are not declared in this listing.
772 
// `loop_var` carries SEqHashDef, which presumably marks it as a definition site
// for structural equality/hashing.
773  static void RegisterReflection() {
774  namespace refl = tvm::ffi::reflection;
775  refl::ObjectDef<ForNode>()
776  .def_ro("loop_var", &ForNode::loop_var, refl::AttachFieldFlag::SEqHashDef())
777  .def_ro("min", &ForNode::min)
778  .def_ro("extent", &ForNode::extent)
779  .def_ro("kind", &ForNode::kind)
780  .def_ro("body", &ForNode::body)
781  .def_ro("thread_binding", &ForNode::thread_binding)
782  .def_ro("annotations", &ForNode::annotations);
783  }
784 
785  static constexpr const char* _type_key = "tir.For";
787 };
788 
// Reference type for ForNode.
793 class For : public Stmt {
794  public:
795  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
796  Optional<IterVar> thread_binding = std::nullopt,
797  Map<String, ffi::Any> annotations = Map<String, ffi::Any>(), Span span = Span());
798 
801 };
802 
// While-loop statement node with a `condition` and a `body`.
// NOTE(review): `condition` and `body` are referenced by RegisterReflection but
// are not declared in this listing.
813 class WhileNode : public StmtNode {
814  public:
819 
820  static void RegisterReflection() {
821  namespace refl = tvm::ffi::reflection;
822  refl::ObjectDef<WhileNode>()
823  .def_ro("condition", &WhileNode::condition)
824  .def_ro("body", &WhileNode::body);
825  }
826 
827  static constexpr const char* _type_key = "tir.While";
829 };
830 
// Reference type for WhileNode.
835 class While : public Stmt {
836  public:
837  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
838 
841 };
842 
847  public:
851  Array<Range> region;
852 
853  static void RegisterReflection() {
854  namespace refl = tvm::ffi::reflection;
855  refl::ObjectDef<BufferRegionNode>()
856  .def_ro("buffer", &BufferRegionNode::buffer)
857  .def_ro("region", &BufferRegionNode::region);
858  }
859 
860  TVM_DLL PrimExpr ToPrimExpr() const final;
861 
862  static constexpr const char* _type_key = "tir.BufferRegion";
863  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
865 };
866 
872  public:
873  TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
874 
881 
888  TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
889 
892 };
893 
// Object pairing a `buffer` with a `source` BufferRegion (see the
// MatchBufferRegion constructor). Derives directly from Object, not StmtNode.
// NOTE(review): `buffer` and `source` are referenced by RegisterReflection but
// are not declared in this listing.
903 class MatchBufferRegionNode : public Object {
904  public:
909 
910  static void RegisterReflection() {
911  namespace refl = tvm::ffi::reflection;
912  refl::ObjectDef<MatchBufferRegionNode>()
913  .def_ro("buffer", &MatchBufferRegionNode::buffer)
914  .def_ro("source", &MatchBufferRegionNode::source);
915  }
916 
917  static constexpr const char* _type_key = "tir.MatchBufferRegion";
// Selects the tree-node kind for structural equality/hashing.
918  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
920 };
921 
// Reference type for MatchBufferRegionNode.
926 class MatchBufferRegion : public ObjectRef {
927  public:
928  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
929 
932 };
933 
955 class BlockNode : public StmtNode {
956  public:
958  Array<IterVar> iter_vars;
960  Array<BufferRegion> reads;
962  Array<BufferRegion> writes;
964  String name_hint;
966  Array<Buffer> alloc_buffers;
968  Array<MatchBufferRegion> match_buffers;
970  Map<String, ffi::Any> annotations;
978  Optional<Stmt> init;
981 
982  static void RegisterReflection() {
983  namespace refl = tvm::ffi::reflection;
984  refl::ObjectDef<BlockNode>()
985  .def_ro("iter_vars", &BlockNode::iter_vars, refl::AttachFieldFlag::SEqHashDef())
986  .def_ro("reads", &BlockNode::reads)
987  .def_ro("writes", &BlockNode::writes)
988  .def_ro("name_hint", &BlockNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore())
989  .def_ro("alloc_buffers", &BlockNode::alloc_buffers)
990  .def_ro("match_buffers", &BlockNode::match_buffers)
991  .def_ro("annotations", &BlockNode::annotations)
992  .def_ro("init", &BlockNode::init)
993  .def_ro("body", &BlockNode::body);
994  }
995 
996  static constexpr const char* _type_key = "tir.Block";
998 };
999 
1004 class Block : public Stmt {
1005  public:
1006  TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
1007  Array<BufferRegion> writes, String name_hint, Stmt body,
1008  Optional<Stmt> init = std::nullopt,
1009  Array<Buffer> alloc_buffers = Array<Buffer>(),
1010  Array<MatchBufferRegion> match_buffers = Array<MatchBufferRegion>(),
1011  Map<String, ffi::Any> annotations = Map<String, ffi::Any>(),
1012  Span span = Span());
1013 
1016 };
1017 
1021 class BlockRealizeNode : public StmtNode {
1022  public:
1024  Array<PrimExpr> iter_values;
1032 
1033  static void RegisterReflection() {
1034  namespace refl = tvm::ffi::reflection;
1035  refl::ObjectDef<BlockRealizeNode>()
1036  .def_ro("iter_values", &BlockRealizeNode::iter_values)
1037  .def_ro("predicate", &BlockRealizeNode::predicate)
1038  .def_ro("block", &BlockRealizeNode::block);
1039  }
1040 
1041  static constexpr const char* _type_key = "tir.BlockRealize";
1043 };
1044 
1049 class BlockRealize : public Stmt {
1050  public:
1051  TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
1052  Span span = Span());
1053 
1056 };
1057 
1059 namespace attr {
1060 // NOTE(review): original comment read "The above attr does not pass to ir stage.", but no attribute precedes it in this listing, so its referent is unclear.
// Attribute-key string constants. They are presumably used as AttrStmtNode::attr_key
// values or as annotation keys (TODO confirm per key).
1062 constexpr const char* thread_extent = "thread_extent";
1064 constexpr const char* virtual_thread = "virtual_thread";
1066 constexpr const char* coproc_scope = "coproc_scope";
1071 constexpr const char* coproc_uop_scope = "coproc_uop_scope";
1073 constexpr const char* volatile_scope = "volatile_scope";
1079 constexpr const char* extern_scope = "extern_scope";
1084 constexpr const char* compute_scope = "compute_scope";
1086 constexpr const char* storage_alignment = "storage_alignment";
1088 constexpr const char* realize_scope = "realize_scope";
1090 constexpr const char* device_id = "device_id";
1092 constexpr const char* device_type = "device_type";
1094 constexpr const char* loop_scope = "loop_scope";
1096 constexpr const char* reduce_scope = "reduce_scope";
// Pragma keys. Keys starting with "pragma_" (pragma_scope_prefix) are the ones
// accepted by IsPragmaKey at the end of this namespace.
1098 constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
1100 constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
1102 constexpr const char* pragma_scope_prefix = "pragma_";
1104 constexpr const char* pragma_import_c = "pragma_import_c";
1106 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
1108 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
// Buffer layout / buffer-scope keys.
1115 constexpr const char* layout_transforms = "layout_transforms";
1123 constexpr const char* axis_separators = "axis_separators";
1127 constexpr const char* double_buffer_scope = "double_buffer_scope";
1131 constexpr const char* double_buffer_write = "double_buffer_write";
1133 constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
1135 constexpr const char* scan_update_scope = "scan_update_scope";
1137 constexpr const char* scan_init_scope = "scan_init_scope";
1144 constexpr const char* buffer_dim_align = "buffer_dim_align";
1146 constexpr const char* buffer_bound = "buffer_bound";
1156 constexpr const char* buffer_bind_scope = "buffer_bind_scope";
1157 // Pipeline related attributes
1159 constexpr const char* channel_read_scope = "channel_read_scope";
1161 constexpr const char* channel_read_advance = "channel_read_advance";
1163 constexpr const char* channel_write_scope = "channel_write_scope";
1165 constexpr const char* channel_write_advance = "channel_write_advance";
1167 constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
1169 constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
1170
1174 constexpr const char* device_scope = "device_scope";
1175
1179 constexpr const char* async_scope = "async_scope";
1180
// Async queue keys.
1198 constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
1199 constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
1200 constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
1201
1205 constexpr const char* fragment_shape = "fragment_shape";
1206
1210 constexpr const char* fragment_layout = "fragment_layout";
1211
1215 constexpr const char* hand_threaded = "hand_threaded";
1216
// Note: the value carries a "tir." prefix, unlike the variable name.
1224 constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
1225
1229 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
1230
// Software pipelining keys.
1232 constexpr const char* software_pipeline_stage = "software_pipeline_stage";
1233
1235 constexpr const char* software_pipeline_order = "software_pipeline_order";
1236
1241 constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";
1242
1244 constexpr const char* layout_free_buffers = "layout_free_buffers";
1245
// Note: the value carries a "tir." prefix, unlike the variable name.
1247 constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";
1248
// Attribute-key string constants under the "meta_schedule." prefix.
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";

// NOTE(review): the three declarator heads below were missing from this listing,
// which left dangling string literals. They are restored following the
// `meta_schedule_<x> = "meta_schedule.<x>"` convention used by every other key in
// this namespace; confirm the names against upstream.
constexpr const char* meta_schedule_thread_extent_low_inclusive =
    "meta_schedule.thread_extent_low_inclusive";

constexpr const char* meta_schedule_thread_extent_high_inclusive =
    "meta_schedule.thread_extent_high_inclusive";

constexpr const char* meta_schedule_random_compute_producer =
    "meta_schedule.random_compute_producer";
// More "meta_schedule." keys, plus assorted schedule/annotation keys.
1271 constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
1272
1274 constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
1275
1277 constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
1278
1280 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
1281
1283 constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1284
1286 constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
1290 constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
1291
// NOTE: the variable name (require_block_var_bound_predicate) differs from its
// string value ("require_bound_predicate"). Code matching on the raw string must
// use the value.
1295 constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";
1296
1298 constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";
1299
1306 constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
1307
// Integer codes, presumably the values stored under meta_schedule_cache_type
// (read = 0, write = 1). TODO confirm.
1309 constexpr const int meta_schedule_cache_type_read = 0;
1310
1312 constexpr const int meta_schedule_cache_type_write = 1;
1313
1315 constexpr const char* auto_copy = "auto_copy";
1316
1318 constexpr const char* local_stage = "local_stage";
1319
1321 constexpr const char* vector_bytes = "vector_bytes";
1322
1327 constexpr const char* warp_execution = "warp_execution";
1328
1330 constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
1331
1335 constexpr const char* explicit_read_region = "explicit_read_region";
1336
1340 constexpr const char* explicit_write_region = "explicit_write_region";
1341
1347 inline bool IsPragmaKey(const std::string& attr_key) {
1348  return attr_key.compare(0, 7, "pragma_") == 0;
1349 }
1350 
1351 } // namespace attr
// Create a PrimExpr that annotates the data type `dtype`. It is declared and
// exported here but defined elsewhere; exact semantics: TODO confirm.
1358 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1359
// Stream a ForKind to an output stream (defined outside this header).
1360 // overload printing of for type.
1361 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1362 
1363 // inline implementations
1364 inline const char* ForKind2String(ForKind t) {
1365  switch (t) {
1366  case ForKind::kSerial:
1367  return "serial";
1368  case ForKind::kParallel:
1369  return "parallel";
1370  case ForKind::kVectorized:
1371  return "vectorized";
1372  case ForKind::kUnrolled:
1373  return "unroll";
1375  return "thread_binding";
1376  }
1377  LOG(FATAL) << "Unknown ForKind" << t;
1378 }
1379 
1380 } // namespace tir
1381 } // namespace tvm
1382 #endif // TVM_TIR_STMT_H_
Base class for other IR constructs that can be converted to PrimExpr. This is useful for the FFI to c...
Definition: expr.h:159
Managed reference to PrimExprConvertibleNode.
Definition: expr.h:172
Reference to PrimExprNode.
Definition: expr.h:129
Definition: source_map.h:113
Runtime primitive data type.
Definition: data_type.h:47
Allocate a buffer that can be used in body.
Definition: stmt.h:360
Optional< Integer > irmod_storage_idx
If the PrimFunc containing the Stmt is added to IRModule, this is an optional index to indicate the i...
Definition: stmt.h:371
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:375
static int64_t ConstantAllocationSize(const Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode)
Stmt body
The body to be executed.
Definition: stmt.h:377
DataType dtype
The type of the buffer.
Definition: stmt.h:373
Var buffer_var
The buffer variable.
Definition: stmt.h:363
static constexpr const char * _type_key
Definition: stmt.h:412
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:403
Map< String, ffi::Any > annotations
Additional annotations about the allocation.
Definition: stmt.h:384
Optional< runtime::NDArray > data
The optional data associated to the constant.
Definition: stmt.h:366
static void RegisterReflection()
Definition: stmt.h:386
Managed reference to AllocateConstNode.
Definition: stmt.h:420
AllocateConst(Var buffer_var, DataType dtype, Array< PrimExpr > extents, ObjectRef data_or_idx, Stmt body, Map< String, ffi::Any > annotations=Map< String, ffi::Any >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode)
Allocate a buffer that can be used in body.
Definition: stmt.h:293
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:300
static int64_t ConstantAllocationSize(const Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
Map< String, ffi::Any > annotations
Additional annotations about the allocation.
Definition: stmt.h:311
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:302
Stmt body
The body to be executed.
Definition: stmt.h:304
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode)
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:329
static void RegisterReflection()
Definition: stmt.h:313
DataType dtype
The type of the buffer.
Definition: stmt.h:298
static constexpr const char * _type_key
Definition: stmt.h:338
Var buffer_var
The buffer variable.
Definition: stmt.h:296
Managed reference to AllocateNode.
Definition: stmt.h:347
Allocate(Var buffer_var, DataType dtype, Array< PrimExpr > extents, PrimExpr condition, Stmt body, Map< String, ffi::Any > annotations=Map< String, ffi::Any >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode)
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode)
Assert a condition; if it fails, return the error message.
Definition: stmt.h:154
PrimExpr condition
Condition to be checked.
Definition: stmt.h:157
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:159
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:174
static void RegisterReflection()
Definition: stmt.h:166
Stmt body
Body which this assertion holds true. Will be executed after the assertion.
Definition: stmt.h:164
Managed reference to AssertStmtNode.
Definition: stmt.h:182
TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode)
AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode)
Define certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:115
ffi::Any node
The node that this attribute is about.
Definition: stmt.h:118
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:122
Stmt body
The body statement to be executed.
Definition: stmt.h:124
String attr_key
the type key of the attribute
Definition: stmt.h:120
static void RegisterReflection()
Definition: stmt.h:126
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:135
Managed reference to AttrStmtNode.
Definition: stmt.h:143
AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode)
A block is a basic schedule unit in TIR.
Definition: stmt.h:955
Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:960
Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:968
Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:958
Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:962
Map< String, ffi::Any > annotations
The annotation of the block.
Definition: stmt.h:970
Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block....
Definition: stmt.h:978
String name_hint
The name_hint of the block.
Definition: stmt.h:964
Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:966
static void RegisterReflection()
Definition: stmt.h:982
Stmt body
The body of the block.
Definition: stmt.h:980
TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode)
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:1021
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode)
Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:1024
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true.
Definition: stmt.h:1029
static void RegisterReflection()
Definition: stmt.h:1033
Block block
The block to be realized.
Definition: stmt.h:1031
Managed reference to BlockRealizeNode.
Definition: stmt.h:1049
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode)
BlockRealize(Array< PrimExpr > iter_values, PrimExpr predicate, Block block, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode)
Managed reference to BlockNode.
Definition: stmt.h:1004
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode)
TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode)
Block(Array< IterVar > iter_vars, Array< BufferRegion > reads, Array< BufferRegion > writes, String name_hint, Stmt body, Optional< Stmt > init=std::nullopt, Array< Buffer > alloc_buffers=Array< Buffer >(), Array< MatchBufferRegion > match_buffers=Array< MatchBufferRegion >(), Map< String, ffi::Any > annotations=Map< String, ffi::Any >(), Span span=Span())
Annotate the region where the buffer needs to be read and written in the body. We only need to allocate ...
Definition: stmt.h:248
static void RegisterReflection()
Definition: stmt.h:259
Buffer buffer
The buffer variable.
Definition: stmt.h:251
BufferRealizeNode(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
Definition: stmt.h:269
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:255
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode)
Stmt body
The body of realization.
Definition: stmt.h:257
Array< Range > bounds
Bounds to be realized.
Definition: stmt.h:253
static constexpr const char * _type_key
Definition: stmt.h:273
Managed reference to BufferRealizeNode.
Definition: stmt.h:281
BufferRealize(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode)
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:846
PrimExpr ToPrimExpr() const final
static void RegisterReflection()
Definition: stmt.h:853
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:849
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:863
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, PrimExprConvertibleNode)
static constexpr const char * _type_key
Definition: stmt.h:862
Array< Range > region
The region array of the buffer region.
Definition: stmt.h:851
Managed reference to BufferRegionNode.
Definition: stmt.h:871
static BufferRegion FullRegion(Buffer buffer)
Create a BufferRegion which is full region of the given buffer.
static BufferRegion FromPoint(Buffer buffer, Array< PrimExpr > indices)
Create a BufferRegion which is a single point of the given buffer.
BufferRegion(Buffer buffer, Array< Range > region)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode)
TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, PrimExprConvertible, BufferRegionNode)
Store value to the high dimension buffer.
Definition: stmt.h:200
Buffer buffer
The buffer variable.
Definition: stmt.h:203
Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:207
Optional< PrimExpr > predicate
The predicate mask for storing values.
Definition: stmt.h:209
PrimExpr value
The value to be stored.
Definition: stmt.h:205
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:220
static void RegisterReflection()
Definition: stmt.h:211
Managed reference to BufferStoreNode.
Definition: stmt.h:228
BufferStore(Buffer buffer, PrimExpr value, Array< PrimExpr > indices, Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode)
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:157
Declare a buffer that can be used in the body.
Definition: stmt.h:435
static constexpr const char * _type_key
Definition: stmt.h:449
static void RegisterReflection()
Definition: stmt.h:442
Buffer buffer
The buffer being declared.
Definition: stmt.h:438
TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode)
Stmt body
The body to be executed.
Definition: stmt.h:440
Managed reference to DeclBufferNode.
Definition: stmt.h:454
DeclBuffer(Buffer buffer, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode)
TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode)
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:492
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:502
static void RegisterReflection()
Definition: stmt.h:497
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:495
Managed reference to EvaluateNode.
Definition: stmt.h:510
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode)
Evaluate(int value, Span span=Span())
Definition: stmt.h:514
TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode)
Evaluate(PrimExpr value, Span span=Span())
A for loop, with possible type annotations.
Definition: stmt.h:746
Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding The context thread that this loop variable bounds to.
Definition: stmt.h:762
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:751
ForKind kind
The kind of the for loop.
Definition: stmt.h:755
static void RegisterReflection()
Definition: stmt.h:773
Map< String, ffi::Any > annotations
Additional annotations about the loop.
Definition: stmt.h:771
static constexpr const char * _type_key
Definition: stmt.h:785
Var loop_var
The loop variable.
Definition: stmt.h:749
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:753
Stmt body
The body of the for loop.
Definition: stmt.h:757
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode)
Managed reference to ForNode.
Definition: stmt.h:793
For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional< IterVar > thread_binding=std::nullopt, Map< String, ffi::Any > annotations=Map< String, ffi::Any >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode)
TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode)
IfThenElse statement.
Definition: stmt.h:674
PrimExpr condition
The condition.
Definition: stmt.h:677
Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:681
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:683
static constexpr const char * _type_key
Definition: stmt.h:691
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:679
Managed reference to IfThenElseNode.
Definition: stmt.h:699
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode)
IfThenElse(PrimExpr condition, Stmt then_case, Optional< Stmt > else_case=std::nullopt, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode)
Let binding, bind var to value, then run body.
Definition: stmt.h:72
PrimExpr value
The value to be bound.
Definition: stmt.h:77
static constexpr const char * _type_key
Definition: stmt.h:89
Stmt body
The body block.
Definition: stmt.h:79
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:81
Var var
The variable.
Definition: stmt.h:75
Managed reference to LetStmtNode.
Definition: stmt.h:97
TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode)
LetStmt(Var var, PrimExpr value, Stmt body, Span span=Span())
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:903
TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object)
Buffer buffer
The target buffer.
Definition: stmt.h:906
static void RegisterReflection()
Definition: stmt.h:910
BufferRegion source
The source buffer region.
Definition: stmt.h:908
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:926
MatchBufferRegion(Buffer buffer, BufferRegion source)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode)
TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode)
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:465
size_t size() const
Definition: stmt.h:471
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:468
static constexpr const char * _type_key
Definition: stmt.h:482
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:477
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:475
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:594
Flattener(Array< Stmt > *seq)
Definition: stmt.h:596
static Optional< SeqStmt > AsSeqStmt(const T &t)
Definition: stmt.h:599
void operator()(size_t i, const T &stmt_or_seq) const
Definition: stmt.h:617
Sequence statement.
Definition: stmt.h:521
TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode)
size_t size() const
Definition: stmt.h:531
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode)
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:535
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:557
SeqStmt(Array< Stmt > seq, Span span=Span())
Construct SeqStmt.
Base node of all statements.
Definition: stmt.h:38
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:59
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:57
StmtNode(Span span)
Definition: stmt.h:47
static void RegisterReflection()
Definition: stmt.h:49
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:44
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object)
static constexpr const char * _type_key
Definition: stmt.h:56
Container of all statements.
Definition: stmt.h:64
TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode)
a named variable in TIR
Definition: var.h:78
A While loop.
Definition: stmt.h:813
static constexpr const char * _type_key
Definition: stmt.h:827
TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode)
Stmt body
The body of the while loop.
Definition: stmt.h:818
static void RegisterReflection()
Definition: stmt.h:820
PrimExpr condition
The termination condition.
Definition: stmt.h:816
Managed reference to WhileNode.
Definition: stmt.h:835
While(PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode)
Definition: repr_printer.h:91
void Evaluate(PrimExpr value)
Evaluate the input expression.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * compute_scope
Mark the scope as when computation start to happen This can hint some code generator to create a new ...
Definition: stmt.h:1084
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op When this scope occurs, the stmt....
Definition: stmt.h:1156
constexpr const char * software_pipeline_order
Mark the order of a statement in the software pipeline.
Definition: stmt.h:1235
constexpr const char * meta_schedule_unroll_explicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1277
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1215
constexpr const char * buffer_dim_align
Mark alignment of buffer dimension stmt.node is Tensor stmt.value is tvm_tuple(dim,...
Definition: stmt.h:1144
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1161
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:1073
constexpr const char * meta_schedule_auto_tensorize
Mark that a block should be further rewritten using tensorization.
Definition: stmt.h:1283
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1167
constexpr const char * manifest_shared_memory_local_stage
Mark the local stage for the shared memory access should be added.
Definition: stmt.h:1247
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1104
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:1100
constexpr const char * meta_schedule_tiling_structure
Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.
Definition: stmt.h:1250
constexpr const char * software_pipeline_stage
Mark the stage of a statement in the software pipeline.
Definition: stmt.h:1232
constexpr const char * meta_schedule_random_compute_producer
Mark the block whose producer needs to be applied by rule Random-Compute-Location.
Definition: stmt.h:1267
constexpr const char * warp_execution
Mark that a block is executed by a warp. This implies the extend of threadIdx.x is warp size.
Definition: stmt.h:1327
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1174
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1347
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1062
constexpr const char * script_parsing_detect_access
Mark whether the script-completer need to fill in missing access region during script parsing.
Definition: stmt.h:1224
constexpr const char * meta_schedule_tensor_core_enabled
Mark that tensor core is enabled in the PrimExpr.
Definition: stmt.h:1298
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:1064
constexpr const char * extern_scope
Mark the scope as generated by extern primitive. such scope can contain arbitrary ir program and we n...
Definition: stmt.h:1079
constexpr const char * meta_schedule_cache_type
Mark a block as generated by cache_read or cache_write block. 0 means cache_read; 1 means cache_write...
Definition: stmt.h:1306
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:1096
constexpr const char * meta_schedule_unroll_implicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1280
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1163
constexpr const char * auto_copy
Mark auto copy for memhammer.
Definition: stmt.h:1315
constexpr const char * rolling_buffer_scope
Mark realization for rolling buffer optimization.
Definition: stmt.h:1133
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1198
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1090
constexpr const char * meta_schedule_thread_extent_low_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1259
constexpr const char * meta_schedule_layout_rewrite_preproc
Mark that a block is a preprocessor block for layout rewrite.
Definition: stmt.h:1286
constexpr const char * vector_bytes
Mark vectorization length constraint on block.
Definition: stmt.h:1321
constexpr const char * device_type
The device type.
Definition: stmt.h:1092
constexpr const char * explicit_write_region
Mark that a block has an explicitly specified write region. This is used to override the default writ...
Definition: stmt.h:1340
constexpr const int meta_schedule_cache_type_read
Definition: stmt.h:1309
constexpr const char * software_pipeline_async_stages
List stages in the software pipeline that should run asynchronously.
Definition: stmt.h:1241
constexpr const char * meta_schedule_auto_tensorize_init
Mark that the init statement of a block should be further rewritten using tensorization.
Definition: stmt.h:1290
constexpr const char * meta_schedule_cooperative_fetch
Mark that the loop should be further skip and bound to environment threads to enable cooperative fetc...
Definition: stmt.h:1256
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1135
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:1098
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:1094
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1127
constexpr const char * fragment_shape
Mark that the shape of TensorCore fragment.
Definition: stmt.h:1205
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1108
constexpr const char * fragment_layout
Mark that the layout of TensorCore fragment.
Definition: stmt.h:1210
constexpr const char * layout_free_buffers
Mark the buffers which is const access and can be transformed layout.
Definition: stmt.h:1244
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1199
constexpr const char * explicit_read_region
Mark that a block has an explicitly specified read region. This is used to override the default read ...
Definition: stmt.h:1335
constexpr const int meta_schedule_cache_type_write
Definition: stmt.h:1312
constexpr const char * coproc_scope
Mark region is processed by a co-processor.
Definition: stmt.h:1066
constexpr const char * buffer_bound
Mark stores/loads with theirs bounds.
Definition: stmt.h:1146
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1200
constexpr const char * layout_transforms
Marks the layout transforms to be used for a tensor.
Definition: stmt.h:1115
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1123
constexpr const char * realize_scope
Mark storage scope of realization.
Definition: stmt.h:1088
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1159
constexpr const char * meta_schedule_thread_extent_high_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1263
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1165
constexpr const char * coproc_uop_scope
Mark region creates coprocessor micro ops, can be reused if corresponding variable is independent.
Definition: stmt.h:1071
constexpr const char * meta_schedule_inline_rule
Mark that a block is disallowed in auto inline.
Definition: stmt.h:1330
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1229
constexpr const char * pipeline_exec_scope
pipeline execution scope, implies the scope can be pipelined.
Definition: stmt.h:1169
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1106
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1102
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1137
constexpr const char * require_block_var_bound_predicate
Mark that the block need to add predicate for block var bounds during lowering.
Definition: stmt.h:1295
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1086
constexpr const char * local_stage
Mark local stage constraint on data copy.
Definition: stmt.h:1318
constexpr const char * meta_schedule_vectorize
Mark auto-vectorize setting on the block.
Definition: stmt.h:1274
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1131
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1179
constexpr const char * meta_schedule_parallel
Mark auto-parallel setting on the block.
Definition: stmt.h:1271
const char * ForKind2String(ForKind t)
Definition: stmt.h:1364
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:123
ForKind
The kind of the loop.
Definition: stmt.h:715
@ kThreadBinding
The loop variable is bound to a thread in an environment. In the final stage of lowering,...
@ kParallel
Parallel execution on CPU.
@ kSerial
default semantics – serial execution.
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
@ kVectorized
The loop is vectorized.
Definition: var.h:237
@ kUnrolled
The execution is unrolled.
Definition: var.h:233
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
TIR expressions.