tvm
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/ffi/reflection/registry.h>
28 #include <tvm/tir/expr.h>
29 
30 #include <string>
31 #include <type_traits>
32 #include <utility>
33 
34 namespace tvm {
35 namespace tir {
36 
38 class StmtNode : public Object {
39  public:
44  mutable Span span;
45 
46  StmtNode() = default;
47  explicit StmtNode(Span span) : span(span) {}
48 
49  static void RegisterReflection() {
50  namespace refl = tvm::ffi::reflection;
51  refl::ObjectDef<StmtNode>().def_ro("span", &StmtNode::span);
52  }
53 
55 
56  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
57 
58  static constexpr const uint32_t _type_child_slots = 15;
59  TVM_FFI_DECLARE_OBJECT_INFO("tir.Stmt", StmtNode, Object);
60 };
61 
63 class Stmt : public ObjectRef {
64  public:
66 };
67 
71 class LetStmtNode : public StmtNode {
72  public:
79 
80  static void RegisterReflection() {
81  namespace refl = tvm::ffi::reflection;
82  refl::ObjectDef<LetStmtNode>()
83  .def_ro("var", &LetStmtNode::var, refl::AttachFieldFlag::SEqHashDef())
84  .def_ro("value", &LetStmtNode::value)
85  .def_ro("body", &LetStmtNode::body);
86  }
88 };
89 
94 class LetStmt : public Stmt {
95  public:
96  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
97 
100 };
101 
112 class AttrStmtNode : public StmtNode {
113  public:
115  ffi::Any node;
117  ffi::String attr_key;
122 
123  static void RegisterReflection() {
124  namespace refl = tvm::ffi::reflection;
125  refl::ObjectDef<AttrStmtNode>()
126  .def_ro("node", &AttrStmtNode::node)
127  .def_ro("attr_key", &AttrStmtNode::attr_key)
128  .def_ro("value", &AttrStmtNode::value)
129  .def_ro("body", &AttrStmtNode::body);
130  }
132 };
133 
138 class AttrStmt : public Stmt {
139  public:
140  TVM_DLL AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body,
141  Span span = Span());
142 
145 };
146 
150 class AssertStmtNode : public StmtNode {
151  public:
161 
162  static void RegisterReflection() {
163  namespace refl = tvm::ffi::reflection;
164  refl::ObjectDef<AssertStmtNode>()
165  .def_ro("condition", &AssertStmtNode::condition)
166  .def_ro("message", &AssertStmtNode::message)
167  .def_ro("body", &AssertStmtNode::body);
168  }
170 };
171 
176 class AssertStmt : public Stmt {
177  public:
178  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
179 
182 };
183 
194 class BufferStoreNode : public StmtNode {
195  public:
201  ffi::Array<PrimExpr> indices;
203  ffi::Optional<PrimExpr> predicate;
204 
205  static void RegisterReflection() {
206  namespace refl = tvm::ffi::reflection;
207  refl::ObjectDef<BufferStoreNode>()
208  .def_ro("buffer", &BufferStoreNode::buffer)
209  .def_ro("value", &BufferStoreNode::value)
210  .def_ro("indices", &BufferStoreNode::indices)
211  .def_ro("predicate", &BufferStoreNode::predicate);
212  }
214 };
215 
220 class BufferStore : public Stmt {
221  public:
222  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, ffi::Array<PrimExpr> indices,
223  ffi::Optional<PrimExpr> predicate = std::nullopt,
224  Span span = Span());
225 
228 };
229 
241 class BufferRealizeNode : public StmtNode {
242  public:
246  ffi::Array<Range> bounds;
251 
252  static void RegisterReflection() {
253  namespace refl = tvm::ffi::reflection;
254  refl::ObjectDef<BufferRealizeNode>()
255  .def_ro("buffer", &BufferRealizeNode::buffer)
256  .def_ro("bounds", &BufferRealizeNode::bounds)
257  .def_ro("condition", &BufferRealizeNode::condition)
258  .def_ro("body", &BufferRealizeNode::body);
259  }
260 
261  BufferRealizeNode() = default;
263  Span span = Span())
266 };
267 
272 class BufferRealize : public Stmt {
273  public:
274  TVM_DLL explicit BufferRealize(Buffer buffer, ffi::Array<Range> bounds, PrimExpr condition,
275  Stmt body, Span span = Span());
276 
279 };
280 
284 class AllocateNode : public StmtNode {
285  public:
291  ffi::Array<PrimExpr> extents;
302  ffi::Map<ffi::String, ffi::Any> annotations;
303 
304  static void RegisterReflection() {
305  namespace refl = tvm::ffi::reflection;
306  refl::ObjectDef<AllocateNode>()
307  .def_ro("buffer_var", &AllocateNode::buffer_var, refl::AttachFieldFlag::SEqHashDef())
308  .def_ro("dtype", &AllocateNode::dtype)
309  .def_ro("extents", &AllocateNode::extents)
310  .def_ro("condition", &AllocateNode::condition)
311  .def_ro("body", &AllocateNode::body)
312  .def_ro("annotations", &AllocateNode::annotations);
313  }
314 
327  TVM_DLL static int64_t ConstantAllocationSize(const ffi::Array<PrimExpr>& extents);
329 };
330 
335 class Allocate : public Stmt {
336  public:
337  TVM_DLL Allocate(Var buffer_var, DataType dtype, ffi::Array<PrimExpr> extents, PrimExpr condition,
338  Stmt body,
339  ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
340  Span span = Span());
341 
344 };
345 
349 class AllocateConstNode : public StmtNode {
350  public:
355  ffi::Optional<runtime::Tensor> data;
360  ffi::Optional<Integer> irmod_storage_idx;
364  ffi::Array<PrimExpr> extents;
373  ffi::Map<ffi::String, ffi::Any> annotations;
374 
375  static void RegisterReflection() {
376  namespace refl = tvm::ffi::reflection;
377  refl::ObjectDef<AllocateConstNode>()
378  .def_ro("buffer_var", &AllocateConstNode::buffer_var, refl::AttachFieldFlag::SEqHashDef())
379  .def_ro("data", &AllocateConstNode::data)
380  .def_ro("irmod_storage_idx", &AllocateConstNode::irmod_storage_idx)
381  .def_ro("dtype", &AllocateConstNode::dtype)
382  .def_ro("extents", &AllocateConstNode::extents)
383  .def_ro("body", &AllocateConstNode::body)
384  .def_ro("annotations", &AllocateConstNode::annotations);
385  }
386 
399  TVM_DLL static int64_t ConstantAllocationSize(const ffi::Array<PrimExpr>& extents);
401 };
402 
407 class AllocateConst : public Stmt {
408  public:
409  /* The constructor to create a IRNode with constant data
410  * depending on the type of ObjectRef, it will either
411  * create AllocateConstNode with irmod_storage_idx or data
412  */
413  TVM_DLL AllocateConst(
414  Var buffer_var, DataType dtype, ffi::Array<PrimExpr> extents, ObjectRef data_or_idx,
415  Stmt body, ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
416  Span span = Span());
419 };
420 
422 class DeclBufferNode : public StmtNode {
423  public:
428 
429  static void RegisterReflection() {
430  namespace refl = tvm::ffi::reflection;
431  refl::ObjectDef<DeclBufferNode>()
432  .def_ro("buffer", &DeclBufferNode::buffer)
433  .def_ro("body", &DeclBufferNode::body);
434  }
436 };
437 
439 class DeclBuffer : public Stmt {
440  public:
441  TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
444 };
445 
450 class SeqStmtNode : public StmtNode {
451  public:
453  ffi::Array<Stmt> seq;
454 
456  size_t size() const { return seq.size(); }
460  Stmt operator[](size_t index) const { return seq[index]; }
461 
462  static void RegisterReflection() {
463  namespace refl = tvm::ffi::reflection;
464  refl::ObjectDef<SeqStmtNode>().def_ro("seq", &SeqStmtNode::seq);
465  }
467 };
468 
475 class EvaluateNode : public StmtNode {
476  public:
479 
480  static void RegisterReflection() {
481  namespace refl = tvm::ffi::reflection;
482  refl::ObjectDef<EvaluateNode>().def_ro("value", &EvaluateNode::value);
483  }
485 };
486 
491 class Evaluate : public Stmt {
492  public:
493  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
494 
495  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
496 
499 };
500 
502 class SeqStmt : public Stmt {
503  public:
509  TVM_DLL explicit SeqStmt(ffi::Array<Stmt> seq, Span span = Span());
510 
512  size_t size() const { return operator->()->size(); }
516  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
/*!
 * \brief Build a single statement out of an arbitrary list of arguments.
 *
 * Every argument is fed through a Flattener, which collects the resulting
 * statements into one flat array. The return value is then:
 *  - Evaluate(0) when nothing was collected,
 *  - the lone statement when exactly one was collected,
 *  - the original SeqStmt when the only argument already was a SeqStmt whose
 *    elements are the very same objects as the collected ones,
 *  - otherwise a newly constructed SeqStmt over the collected statements.
 *
 * \tparam Args Types of the arguments handed to the Flattener.
 * \param seq_args The statements or containers of statements to flatten.
 * \return The flattened statement.
 */
537  template <typename... Args>
538  static Stmt Flatten(Args&&... seq_args) {
539  ffi::Array<Stmt> seq;
540 
 // Run the Flattener over every argument; it appends into `seq`.
541  ffi::details::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
542 
 // Degenerate results do not need a SeqStmt wrapper.
543  if (seq.empty()) {
544  return Evaluate(0);
545  } else if (seq.size() == 1) {
546  return seq[0];
547  }
548 
549  // If the argument is a single SeqStmt argument with no
550  // flattening or unwrapping required, then we may
551  // return the SeqStmt as-is.
552  if constexpr (sizeof...(seq_args) == 1) {
553  if (auto opt = Flattener::AsSeqStmt(std::forward<Args>(seq_args)...)) {
554  SeqStmt original = opt.value();
 // Element-wise identity check (same_as), not structural equality.
555  bool all_same = [&]() {
556  if (original->seq.size() != seq.size()) {
557  return false;
558  }
559  for (size_t i = 0; i < seq.size(); i++) {
560  if (!original->seq[i].same_as(seq[i])) {
561  return false;
562  }
563  }
564  return true;
565  }();
566  if (all_same) {
567  return original;
568  }
569  }
570  }
571 
572  return SeqStmt(seq);
573  }
/*!
 * \brief Helper functor that appends flattened statements to an output array.
 *
 * The functor does not own the array; it writes through the pointer given
 * at construction.
 */
575  class Flattener {
576  public:
 /*! \brief Bind the functor to the array that receives the statements. */
577  explicit Flattener(ffi::Array<Stmt>* seq) : seq_(seq) {}
578 
 /*!
  * \brief Try to view `t` as a SeqStmt.
  * \return `t` itself when T is SeqStmt; std::nullopt when T is not a base
  *  class of SeqStmt; otherwise the result of a runtime check for SeqStmtNode.
  */
579  template <typename T>
580  static ffi::Optional<SeqStmt> AsSeqStmt(const T& t) {
581  if constexpr (std::is_same_v<T, SeqStmt>) {
582  return t;
583  }
 // A type that is not a base of SeqStmt cannot refer to a SeqStmt.
584  if constexpr (!std::is_base_of_v<T, SeqStmt>) {
585  return std::nullopt;
586  }
587  if constexpr (std::is_base_of_v<Stmt, T>) {
588  if (const SeqStmtNode* ptr = t.template as<SeqStmtNode>()) {
589  return ffi::GetRef<SeqStmt>(ptr);
590  } else {
591  return std::nullopt;
592  }
593  }
594  return std::nullopt;
595  }
596 
 /*!
  * \brief Append `stmt_or_seq` to the output array, flattening as needed.
  *
  * Undefined references and Evaluate(0) no-ops are skipped, SeqStmt
  * contents are expanded recursively, any other Stmt is appended, and any
  * non-Stmt argument is iterated element by element.
  *
  * \param i Not used by the body; recursive calls pass 0.
  * \param stmt_or_seq A statement or an iterable of statements.
  */
597  template <typename T>
598  void operator()(size_t i, const T& stmt_or_seq) const {
599  if constexpr (std::is_base_of_v<ObjectRef, T>) {
600  // Early bail-out, applicable to any ObjectRef
601  if (!stmt_or_seq.defined()) {
602  return;
603  }
604  }
605 
606  if constexpr (std::is_same_v<T, SeqStmt>) {
607  // Static type-checking for a SeqStmt that could be flattened.
608  (*this)(0, stmt_or_seq->seq);
609  return;
610  }
611 
612  if constexpr (std::is_base_of_v<T, SeqStmt>) {
613  // Dynamic type-checking for a SeqStmt that could be
614  // flattened.
615  if (auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
616  operator()(0, op->seq);
617  return;
618  }
619  }
620 
621  if constexpr (std::is_base_of_v<T, Evaluate>) {
622  // Evaluate(0) is used to represent a no-op, and may be
623  // generated by previous calls to SeqStmt::Flatten(). These
624  // should be removed to ensure that Flatten(a+b) is equivalent
625  // to Flatten(Flatten(a), Flatten(b)).
626  if (auto* op = stmt_or_seq.template as<EvaluateNode>()) {
627  if (auto* as_int = op->value.template as<IntImmNode>(); as_int && as_int->value == 0) {
628  return;
629  }
630  }
631  }
632 
633  if constexpr (std::is_base_of_v<Stmt, T>) {
634  // Any other Stmt type just gets appended.
635  seq_->push_back(stmt_or_seq);
636  } else {
637  // Anything else is treated as an iterable of Stmt.
638  for (auto v : stmt_or_seq) {
639  this->operator()(0, v);
640  }
641  }
642  }
643 
644  private:
 /*! \brief Output array (not owned) that collects the flattened statements. */
645  ffi::Array<Stmt>* seq_;
646  };
647 
650 };
651 
655 class IfThenElseNode : public StmtNode {
656  public:
662  ffi::Optional<Stmt> else_case;
663 
664  static void RegisterReflection() {
665  namespace refl = tvm::ffi::reflection;
666  refl::ObjectDef<IfThenElseNode>()
667  .def_ro("condition", &IfThenElseNode::condition)
668  .def_ro("then_case", &IfThenElseNode::then_case)
669  .def_ro("else_case", &IfThenElseNode::else_case);
670  }
672 };
673 
678 class IfThenElse : public Stmt {
679  public:
680  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case,
681  ffi::Optional<Stmt> else_case = std::nullopt, Span span = Span());
682 
685 };
686 
/*!
 * \brief The kind of a For loop.
 *
 * The textual name of each enumerator, as produced by ForKind2String, is
 * noted next to it.
 */
694 enum class ForKind : int {
 // Printed as "serial".
696  kSerial = 0,
 // Printed as "parallel".
698  kParallel = 1,
 // Printed as "vectorized".
703  kVectorized = 2,
 // Printed as "unroll".
705  kUnrolled = 3,
 // Printed as "thread_binding".
712  kThreadBinding = 4
713 };
714 
725 class ForNode : public StmtNode {
726  public:
741  ffi::Optional<IterVar> thread_binding;
750  ffi::Map<ffi::String, ffi::Any> annotations;
754  ffi::Optional<PrimExpr> step;
755 
756  static void RegisterReflection() {
757  namespace refl = tvm::ffi::reflection;
758  refl::ObjectDef<ForNode>()
759  .def_ro("loop_var", &ForNode::loop_var, refl::AttachFieldFlag::SEqHashDef())
760  .def_ro("min", &ForNode::min)
761  .def_ro("extent", &ForNode::extent)
762  .def_ro("kind", &ForNode::kind)
763  .def_ro("body", &ForNode::body)
764  .def_ro("thread_binding", &ForNode::thread_binding)
765  .def_ro("annotations", &ForNode::annotations)
766  .def_ro("step", &ForNode::step);
767  }
768 
770  bool HasTrivialStep() const;
771 
773 };
774 
779 class For : public Stmt {
780  public:
781  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
782  ffi::Optional<IterVar> thread_binding = std::nullopt,
783  ffi::Map<ffi::String, ffi::Any> annotations = {},
784  ffi::Optional<PrimExpr> step = std::nullopt, Span span = Span());
785 
788 };
789 
800 class WhileNode : public StmtNode {
801  public:
806 
807  static void RegisterReflection() {
808  namespace refl = tvm::ffi::reflection;
809  refl::ObjectDef<WhileNode>()
810  .def_ro("condition", &WhileNode::condition)
811  .def_ro("body", &WhileNode::body);
812  }
814 };
815 
820 class While : public Stmt {
821  public:
822  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
823 
826 };
827 
832  public:
836  ffi::Array<Range> region;
837 
838  static void RegisterReflection() {
839  namespace refl = tvm::ffi::reflection;
840  refl::ObjectDef<BufferRegionNode>()
841  .def_ro("buffer", &BufferRegionNode::buffer)
842  .def_ro("region", &BufferRegionNode::region);
843  }
844 
845  TVM_DLL PrimExpr ToPrimExpr() const final;
846 
847  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
849 };
850 
856  public:
857  TVM_DLL explicit BufferRegion(Buffer buffer, ffi::Array<Range> region);
858 
865 
872  TVM_DLL static BufferRegion FromPoint(Buffer buffer, ffi::Array<PrimExpr> indices);
873 
876 };
877 
887 class MatchBufferRegionNode : public Object {
888  public:
893 
894  static void RegisterReflection() {
895  namespace refl = tvm::ffi::reflection;
896  refl::ObjectDef<MatchBufferRegionNode>()
897  .def_ro("buffer", &MatchBufferRegionNode::buffer)
898  .def_ro("source", &MatchBufferRegionNode::source);
899  }
900 
901  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
902  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.MatchBufferRegion", MatchBufferRegionNode, Object);
903 };
904 
909 class MatchBufferRegion : public ObjectRef {
910  public:
911  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
912 
915 };
916 
938 class BlockNode : public StmtNode {
939  public:
941  ffi::Array<IterVar> iter_vars;
943  ffi::Array<BufferRegion> reads;
945  ffi::Array<BufferRegion> writes;
947  ffi::String name_hint;
949  ffi::Array<Buffer> alloc_buffers;
951  ffi::Array<MatchBufferRegion> match_buffers;
953  ffi::Map<ffi::String, ffi::Any> annotations;
961  ffi::Optional<Stmt> init;
964 
965  static void RegisterReflection() {
966  namespace refl = tvm::ffi::reflection;
967  refl::ObjectDef<BlockNode>()
968  .def_ro("iter_vars", &BlockNode::iter_vars, refl::AttachFieldFlag::SEqHashDef())
969  .def_ro("reads", &BlockNode::reads)
970  .def_ro("writes", &BlockNode::writes)
971  .def_ro("name_hint", &BlockNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore())
972  .def_ro("alloc_buffers", &BlockNode::alloc_buffers)
973  .def_ro("match_buffers", &BlockNode::match_buffers)
974  .def_ro("annotations", &BlockNode::annotations)
975  .def_ro("init", &BlockNode::init)
976  .def_ro("body", &BlockNode::body);
977  }
979 };
980 
985 class Block : public Stmt {
986  public:
987  TVM_DLL explicit Block(
988  ffi::Array<IterVar> iter_vars, ffi::Array<BufferRegion> reads,
989  ffi::Array<BufferRegion> writes, ffi::String name_hint, Stmt body,
990  ffi::Optional<Stmt> init = std::nullopt,
991  ffi::Array<Buffer> alloc_buffers = ffi::Array<Buffer>(),
992  ffi::Array<MatchBufferRegion> match_buffers = ffi::Array<MatchBufferRegion>(),
993  ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
994  Span span = Span());
995 
998 };
999 
1003 class BlockRealizeNode : public StmtNode {
1004  public:
1006  ffi::Array<PrimExpr> iter_values;
1014 
1015  static void RegisterReflection() {
1016  namespace refl = tvm::ffi::reflection;
1017  refl::ObjectDef<BlockRealizeNode>()
1018  .def_ro("iter_values", &BlockRealizeNode::iter_values)
1019  .def_ro("predicate", &BlockRealizeNode::predicate)
1020  .def_ro("block", &BlockRealizeNode::block);
1021  }
1023 };
1024 
1029 class BlockRealize : public Stmt {
1030  public:
1031  TVM_DLL explicit BlockRealize(ffi::Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
1032  Span span = Span());
1033 
1036 };
1037 
1039 namespace attr {
1040 // The above attr does not pass to ir stage.
1042 constexpr const char* thread_extent = "thread_extent";
1044 constexpr const char* virtual_thread = "virtual_thread";
1046 constexpr const char* coproc_scope = "coproc_scope";
1051 constexpr const char* coproc_uop_scope = "coproc_uop_scope";
1053 constexpr const char* volatile_scope = "volatile_scope";
1059 constexpr const char* extern_scope = "extern_scope";
1064 constexpr const char* compute_scope = "compute_scope";
1066 constexpr const char* storage_alignment = "storage_alignment";
1068 constexpr const char* realize_scope = "realize_scope";
1070 constexpr const char* device_id = "device_id";
1072 constexpr const char* device_type = "device_type";
1074 constexpr const char* loop_scope = "loop_scope";
1076 constexpr const char* reduce_scope = "reduce_scope";
1078 constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
1080 constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
1082 constexpr const char* pragma_scope_prefix = "pragma_";
1084 constexpr const char* pragma_import_c = "pragma_import_c";
1086 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
1088 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
1095 constexpr const char* layout_transforms = "layout_transforms";
1103 constexpr const char* axis_separators = "axis_separators";
1107 constexpr const char* double_buffer_scope = "double_buffer_scope";
1111 constexpr const char* double_buffer_write = "double_buffer_write";
1113 constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
1115 constexpr const char* scan_update_scope = "scan_update_scope";
1117 constexpr const char* scan_init_scope = "scan_init_scope";
1124 constexpr const char* buffer_dim_align = "buffer_dim_align";
1126 constexpr const char* buffer_bound = "buffer_bound";
1136 constexpr const char* buffer_bind_scope = "buffer_bind_scope";
1137 // Pipeline related attributes
1139 constexpr const char* channel_read_scope = "channel_read_scope";
1141 constexpr const char* channel_read_advance = "channel_read_advance";
1143 constexpr const char* channel_write_scope = "channel_write_scope";
1145 constexpr const char* channel_write_advance = "channel_write_advance";
1147 constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
1149 constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
1150 
1154 constexpr const char* device_scope = "device_scope";
1155 
1159 constexpr const char* async_scope = "async_scope";
1160 
1178 constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
1179 constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
1180 constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
1181 
1185 constexpr const char* fragment_shape = "fragment_shape";
1186 
1190 constexpr const char* fragment_layout = "fragment_layout";
1191 
1195 constexpr const char* hand_threaded = "hand_threaded";
1196 
1204 constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
1205 
1209 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
1210 
1212 constexpr const char* software_pipeline_stage = "software_pipeline_stage";
1213 
1215 constexpr const char* software_pipeline_order = "software_pipeline_order";
1216 
1221 constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";
1222 
1224 constexpr const char* layout_free_buffers = "layout_free_buffers";
1225 
1227 constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";
1228 
1230 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
1231 
1236 constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";
1237 
/*!
 * \brief Attribute key "meta_schedule.thread_extent_low_inclusive".
 * NOTE(review): the declarator was missing; the identifier is restored from the
 * naming pattern of the sibling meta_schedule_* keys -- confirm against upstream.
 */
constexpr const char* meta_schedule_thread_extent_low_inclusive =
    "meta_schedule.thread_extent_low_inclusive";
1241 
/*!
 * \brief Attribute key "meta_schedule.thread_extent_high_inclusive".
 * NOTE(review): the declarator was missing; the identifier is restored from the
 * naming pattern of the sibling meta_schedule_* keys -- confirm against upstream.
 */
constexpr const char* meta_schedule_thread_extent_high_inclusive =
    "meta_schedule.thread_extent_high_inclusive";
1245 
/*!
 * \brief Attribute key "meta_schedule.random_compute_producer".
 * NOTE(review): the declarator was missing; the identifier is restored from the
 * naming pattern of the sibling meta_schedule_* keys -- confirm against upstream.
 */
constexpr const char* meta_schedule_random_compute_producer =
    "meta_schedule.random_compute_producer";
1249 
1251 constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
1252 
1254 constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
1255 
1257 constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
1258 
1260 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
1261 
1263 constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1264 
1266 constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
1270 constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
1271 
1275 constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";
1276 
1278 constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";
1279 
1286 constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
1287 
1289 constexpr const int meta_schedule_cache_type_read = 0;
1290 
1292 constexpr const int meta_schedule_cache_type_write = 1;
1293 
1295 constexpr const char* auto_copy = "auto_copy";
1296 
1298 constexpr const char* local_stage = "local_stage";
1299 
1301 constexpr const char* vector_bytes = "vector_bytes";
1302 
1307 constexpr const char* warp_execution = "warp_execution";
1308 
1310 constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
1311 
1315 constexpr const char* explicit_read_region = "explicit_read_region";
1316 
1320 constexpr const char* explicit_write_region = "explicit_write_region";
1321 
1323 constexpr const char* irregular_loop_mark = "irregular_loop_mark";
1324 
/*!
 * \brief Check if attr_key is a pragma key extension.
 * \param attr_key The attr key to be compared.
 * \return true when attr_key begins with "pragma_", false otherwise
 *  (including for keys shorter than the prefix and for the empty string).
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // rfind anchored at position 0 only tests the very start of the string,
  // so no hand-counted prefix length has to be kept in sync with the literal.
  return attr_key.rfind("pragma_", 0) == 0;
}
1333 
1334 } // namespace attr
1341 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1342 
1343 // overload printing of for type.
1344 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1345 
1346 // inline implementations
1347 inline const char* ForKind2String(ForKind t) {
1348  switch (t) {
1349  case ForKind::kSerial:
1350  return "serial";
1351  case ForKind::kParallel:
1352  return "parallel";
1353  case ForKind::kVectorized:
1354  return "vectorized";
1355  case ForKind::kUnrolled:
1356  return "unroll";
1358  return "thread_binding";
1359  }
1360  LOG(FATAL) << "Unknown ForKind" << t;
1361 }
1362 
1363 } // namespace tir
1364 } // namespace tvm
1365 #endif // TVM_TIR_STMT_H_
Base class for other IR constructs that can be converted to PrimExpr. This is useful for the FFI to c...
Definition: expr.h:154
Managed reference to PrimExprConvertibleNode.
Definition: expr.h:165
Reference to PrimExprNode.
Definition: expr.h:124
Definition: source_map.h:111
Runtime primitive data type.
Definition: data_type.h:47
Allocate a buffer that can be used in body.
Definition: stmt.h:349
static int64_t ConstantAllocationSize(const ffi::Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AllocateConst", AllocateConstNode, StmtNode)
Stmt body
The body to be executed.
Definition: stmt.h:366
DataType dtype
The type of the buffer.
Definition: stmt.h:362
Var buffer_var
The buffer variable.
Definition: stmt.h:352
ffi::Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:364
ffi::Map< ffi::String, ffi::Any > annotations
Additional annotations about the allocation.
Definition: stmt.h:373
ffi::Optional< Integer > irmod_storage_idx
If the PrimFunc containing the Stmt is added to IRModule, this is an optional index to indicate the i...
Definition: stmt.h:360
ffi::Optional< runtime::Tensor > data
The optional data associated to the constant.
Definition: stmt.h:355
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:392
static void RegisterReflection()
Definition: stmt.h:375
Managed reference to AllocateConstNode.
Definition: stmt.h:407
AllocateConst(Var buffer_var, DataType dtype, ffi::Array< PrimExpr > extents, ObjectRef data_or_idx, Stmt body, ffi::Map< ffi::String, ffi::Any > annotations=ffi::Map< ffi::String, ffi::Any >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AllocateConst, Stmt, AllocateConstNode)
Allocate a buffer that can be used in body.
Definition: stmt.h:284
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Allocate", AllocateNode, StmtNode)
ffi::Map< ffi::String, ffi::Any > annotations
Additional annotations about the allocation.
Definition: stmt.h:302
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:293
Stmt body
The body to be executed.
Definition: stmt.h:295
static int64_t ConstantAllocationSize(const ffi::Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:320
static void RegisterReflection()
Definition: stmt.h:304
DataType dtype
The type of the buffer.
Definition: stmt.h:289
Var buffer_var
The buffer variable.
Definition: stmt.h:287
ffi::Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:291
Managed reference to AllocateNode.
Definition: stmt.h:335
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Allocate, Stmt, AllocateNode)
Allocate(Var buffer_var, DataType dtype, ffi::Array< PrimExpr > extents, PrimExpr condition, Stmt body, ffi::Map< ffi::String, ffi::Any > annotations=ffi::Map< ffi::String, ffi::Any >(), Span span=Span())
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:150
PrimExpr condition
Condition to be checked.
Definition: stmt.h:153
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:155
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AssertStmt", AssertStmtNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:162
Stmt body
Body which this assertion holds true. Will be executed after the assertion.
Definition: stmt.h:160
Managed reference to AssertStmtNode.
Definition: stmt.h:176
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AssertStmt, Stmt, AssertStmtNode)
AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode)
Define certain auxiliary attribute for the body to be a symbolic value. This provide auxiliary inform...
Definition: stmt.h:112
ffi::Any node
this is attribute about certain node
Definition: stmt.h:115
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:119
Stmt body
The body statement to be executed.
Definition: stmt.h:121
static void RegisterReflection()
Definition: stmt.h:123
ffi::String attr_key
the type key of the attribute
Definition: stmt.h:117
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AttrStmt", AttrStmtNode, StmtNode)
Managed reference to AttrStmtNode.
Definition: stmt.h:138
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrStmt, Stmt, AttrStmtNode)
AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode)
A block is a basic schedule unit in TIR.
Definition: stmt.h:938
ffi::String name_hint
The name_hint of the block.
Definition: stmt.h:947
ffi::Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:951
ffi::Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:945
ffi::Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block....
Definition: stmt.h:961
ffi::Map< ffi::String, ffi::Any > annotations
The annotation of the block.
Definition: stmt.h:953
ffi::Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:941
static void RegisterReflection()
Definition: stmt.h:965
ffi::Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:949
Stmt body
The body of the block.
Definition: stmt.h:963
ffi::Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:943
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Block", BlockNode, StmtNode)
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:1003
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockRealize", BlockRealizeNode, StmtNode)
ffi::Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:1006
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true.
Definition: stmt.h:1011
static void RegisterReflection()
Definition: stmt.h:1015
Block block
The block to be realized.
Definition: stmt.h:1013
Managed reference to BlockRealizeNode.
Definition: stmt.h:1029
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode)
BlockRealize(ffi::Array< PrimExpr > iter_values, PrimExpr predicate, Block block, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BlockRealize, Stmt, BlockRealizeNode)
Managed reference to BlockNode.
Definition: stmt.h:985
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode)
Block(ffi::Array< IterVar > iter_vars, ffi::Array< BufferRegion > reads, ffi::Array< BufferRegion > writes, ffi::String name_hint, Stmt body, ffi::Optional< Stmt > init=std::nullopt, ffi::Array< Buffer > alloc_buffers=ffi::Array< Buffer >(), ffi::Array< MatchBufferRegion > match_buffers=ffi::Array< MatchBufferRegion >(), ffi::Map< ffi::String, ffi::Any > annotations=ffi::Map< ffi::String, ffi::Any >(), Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Block, Stmt, BlockNode)
Annotate the region where the buffer need to be read and write in the body. We only need to allocate ...
Definition: stmt.h:241
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferRealize", BufferRealizeNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:252
Buffer buffer
The buffer variable.
Definition: stmt.h:244
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:248
Stmt body
The body of realization.
Definition: stmt.h:250
ffi::Array< Range > bounds
Bounds to be realized.
Definition: stmt.h:246
BufferRealizeNode(Buffer buffer, ffi::Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
Definition: stmt.h:262
Managed reference to BufferRealizeNode.
Definition: stmt.h:272
BufferRealize(Buffer buffer, ffi::Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BufferRealize, Stmt, BufferRealizeNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode)
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:831
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferRegion", BufferRegionNode, PrimExprConvertibleNode)
PrimExpr ToPrimExpr() const final
static void RegisterReflection()
Definition: stmt.h:838
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:834
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:847
ffi::Array< Range > region
The region array of the buffer region.
Definition: stmt.h:836
Managed reference to BufferRegionNode.
Definition: stmt.h:855
static BufferRegion FullRegion(Buffer buffer)
Create a BufferRegion which is full region of the given buffer.
static BufferRegion FromPoint(Buffer buffer, ffi::Array< PrimExpr > indices)
Create a BufferRegion which is a single point of the given buffer.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferRegion, PrimExprConvertible, BufferRegionNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode)
BufferRegion(Buffer buffer, ffi::Array< Range > region)
Store value to the high dimension buffer.
Definition: stmt.h:194
Buffer buffer
The buffer variable.
Definition: stmt.h:197
ffi::Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:201
PrimExpr value
The value to be stored.
Definition: stmt.h:199
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferStore", BufferStoreNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:205
ffi::Optional< PrimExpr > predicate
The predicate mask for storing values.
Definition: stmt.h:203
Managed reference to BufferStoreNode.
Definition: stmt.h:220
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferStore, Stmt, BufferStoreNode)
BufferStore(Buffer buffer, PrimExpr value, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode)
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
Declare a buffer that can be used in the body.
Definition: stmt.h:422
static void RegisterReflection()
Definition: stmt.h:429
Buffer buffer
The buffer being declared.
Definition: stmt.h:425
Stmt body
The body to be executed.
Definition: stmt.h:427
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.DeclBuffer", DeclBufferNode, StmtNode)
Managed reference to DeclBufferNode.
Definition: stmt.h:439
DeclBuffer(Buffer buffer, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DeclBuffer, Stmt, DeclBufferNode)
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:475
static void RegisterReflection()
Definition: stmt.h:480
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:478
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Evaluate", EvaluateNode, StmtNode)
Managed reference to EvaluateNode.
Definition: stmt.h:491
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode)
Evaluate(int value, Span span=Span())
Definition: stmt.h:495
Evaluate(PrimExpr value, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Evaluate, Stmt, EvaluateNode)
A for loop, with possible type annotations.
Definition: stmt.h:725
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:730
ForKind kind
The kind of the for loop.
Definition: stmt.h:734
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:756
Var loop_var
The loop variable.
Definition: stmt.h:728
bool HasTrivialStep() const
Check it is a loop without nontrivial loop step.
ffi::Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to.
Definition: stmt.h:741
ffi::Map< ffi::String, ffi::Any > annotations
Additional annotations about the loop.
Definition: stmt.h:750
ffi::Optional< PrimExpr > step
The loop step. It is one if not specified.
Definition: stmt.h:754
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:732
Stmt body
The body of the for loop.
Definition: stmt.h:736
Managed reference to ForNode.
Definition: stmt.h:779
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode)
For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ffi::Optional< IterVar > thread_binding=std::nullopt, ffi::Map< ffi::String, ffi::Any > annotations={}, ffi::Optional< PrimExpr > step=std::nullopt, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode)
IfThenElse statement.
Definition: stmt.h:655
PrimExpr condition
The condition.
Definition: stmt.h:658
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.IfThenElse", IfThenElseNode, StmtNode)
ffi::Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:662
static void RegisterReflection()
Definition: stmt.h:664
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:660
Managed reference to IfThenElseNode.
Definition: stmt.h:678
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IfThenElse, Stmt, IfThenElseNode)
IfThenElse(PrimExpr condition, Stmt then_case, ffi::Optional< Stmt > else_case=std::nullopt, Span span=Span())
Let binding, bind var to value, then run body.
Definition: stmt.h:71
PrimExpr value
The value to be bound.
Definition: stmt.h:76
Stmt body
The body block.
Definition: stmt.h:78
static void RegisterReflection()
Definition: stmt.h:80
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.LetStmt", LetStmtNode, StmtNode)
Var var
The variable.
Definition: stmt.h:74
Managed reference to LetStmtNode.
Definition: stmt.h:94
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode)
LetStmt(Var var, PrimExpr value, Stmt body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LetStmt, Stmt, LetStmtNode)
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:887
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.MatchBufferRegion", MatchBufferRegionNode, Object)
Buffer buffer
The target buffer.
Definition: stmt.h:890
static void RegisterReflection()
Definition: stmt.h:894
BufferRegion source
The source buffer region.
Definition: stmt.h:892
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:909
MatchBufferRegion(Buffer buffer, BufferRegion source)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchBufferRegion, ObjectRef, MatchBufferRegionNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode)
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:450
size_t size() const
Definition: stmt.h:456
ffi::Array< Stmt > seq
internal sequence content.
Definition: stmt.h:453
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SeqStmt", SeqStmtNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:462
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:460
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:575
Flattener(ffi::Array< Stmt > *seq)
Definition: stmt.h:577
static ffi::Optional< SeqStmt > AsSeqStmt(const T &t)
Definition: stmt.h:580
void operator()(size_t i, const T &stmt_or_seq) const
Definition: stmt.h:598
Sequence statement.
Definition: stmt.h:502
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SeqStmt, Stmt, SeqStmtNode)
size_t size() const
Definition: stmt.h:512
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode)
SeqStmt(ffi::Array< Stmt > seq, Span span=Span())
Construct SeqStmt.
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:516
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:538
Base node of all statements.
Definition: stmt.h:38
TVM_FFI_DECLARE_OBJECT_INFO("tir.Stmt", StmtNode, Object)
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:58
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:56
StmtNode(Span span)
Definition: stmt.h:47
static void RegisterReflection()
Definition: stmt.h:49
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:44
Container of all statements.
Definition: stmt.h:63
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Stmt, ObjectRef, StmtNode)
a named variable in TIR
Definition: var.h:77
A While loop.
Definition: stmt.h:800
Stmt body
The body of the while loop.
Definition: stmt.h:805
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.While", WhileNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:807
PrimExpr condition
The termination condition.
Definition: stmt.h:803
Managed reference to WhileNode.
Definition: stmt.h:820
While(PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(While, Stmt, WhileNode)
Definition: repr_printer.h:91
void Evaluate(PrimExpr value)
Evaluate the input expression.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * compute_scope
Mark the scope as when computation starts to happen. This can hint some code generator to create a new ...
Definition: stmt.h:1064
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op. When this scope occurs, the stmt....
Definition: stmt.h:1136
constexpr const char * software_pipeline_order
Mark the order of a statement in the software pipeline.
Definition: stmt.h:1215
constexpr const char * meta_schedule_unroll_explicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1257
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1195
constexpr const char * buffer_dim_align
Mark alignment of buffer dimension. stmt.node is Tensor, stmt.value is tvm_tuple(dim,...
Definition: stmt.h:1124
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1141
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:1053
constexpr const char * meta_schedule_auto_tensorize
Mark that a block should be further rewritten using tensorization.
Definition: stmt.h:1263
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1147
constexpr const char * manifest_shared_memory_local_stage
Mark the local stage for the shared memory access should be added.
Definition: stmt.h:1227
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1084
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:1080
constexpr const char * meta_schedule_tiling_structure
Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.
Definition: stmt.h:1230
constexpr const char * software_pipeline_stage
Mark the stage of a statement in the software pipeline.
Definition: stmt.h:1212
constexpr const char * meta_schedule_random_compute_producer
Mark the block whose producer needs to be applied by rule Random-Compute-Location.
Definition: stmt.h:1247
constexpr const char * warp_execution
Mark that a block is executed by a warp. This implies the extent of threadIdx.x is warp size.
Definition: stmt.h:1307
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1154
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1330
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1042
constexpr const char * script_parsing_detect_access
Mark whether the script-completer needs to fill in missing access regions during script parsing.
Definition: stmt.h:1204
constexpr const char * meta_schedule_tensor_core_enabled
Mark that tensor core is enabled in the PrimExpr.
Definition: stmt.h:1278
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:1044
constexpr const char * extern_scope
Mark the scope as generated by extern primitive. Such scope can contain arbitrary ir program and we n...
Definition: stmt.h:1059
constexpr const char * meta_schedule_cache_type
Mark a block as generated by cache_read or cache_write block. 0 means cache_read; 1 means cache_write...
Definition: stmt.h:1286
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:1076
constexpr const char * meta_schedule_unroll_implicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1260
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1143
constexpr const char * auto_copy
Mark auto copy for memhammer.
Definition: stmt.h:1295
constexpr const char * rolling_buffer_scope
Mark realization for rolling buffer optimization.
Definition: stmt.h:1113
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1178
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1070
constexpr const char * meta_schedule_thread_extent_low_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1239
constexpr const char * meta_schedule_layout_rewrite_preproc
Mark that a block is a preprocessor block for layout rewrite.
Definition: stmt.h:1266
constexpr const char * vector_bytes
Mark vectorization length constraint on block.
Definition: stmt.h:1301
constexpr const char * device_type
The device type.
Definition: stmt.h:1072
constexpr const char * explicit_write_region
Mark that a block has an explicitly specified write region. This is used to override the default writ...
Definition: stmt.h:1320
constexpr const int meta_schedule_cache_type_read
Definition: stmt.h:1289
constexpr const char * software_pipeline_async_stages
List stages in the software pipeline that should run asynchronously.
Definition: stmt.h:1221
constexpr const char * meta_schedule_auto_tensorize_init
Mark that the init statement of a block should be further rewritten using tensorization.
Definition: stmt.h:1270
constexpr const char * meta_schedule_cooperative_fetch
Mark that the loop should be further split and bound to environment threads to enable cooperative fetc...
Definition: stmt.h:1236
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1115
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:1078
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:1074
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1107
constexpr const char * fragment_shape
Mark the shape of the TensorCore fragment.
Definition: stmt.h:1185
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1088
constexpr const char * fragment_layout
Mark the layout of the TensorCore fragment.
Definition: stmt.h:1190
constexpr const char * layout_free_buffers
Mark the buffers which are only const-accessed and whose layout can be transformed.
Definition: stmt.h:1224
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1179
constexpr const char * explicit_read_region
Mark that a block has an explicitly specified read region. This is used to override the default read ...
Definition: stmt.h:1315
constexpr const int meta_schedule_cache_type_write
Definition: stmt.h:1292
constexpr const char * coproc_scope
Mark region is processed by a co-processor.
Definition: stmt.h:1046
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:1126
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1180
constexpr const char * layout_transforms
Marks the layout transforms to be used for a tensor.
Definition: stmt.h:1095
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1103
constexpr const char * realize_scope
Mark storage scope of realization.
Definition: stmt.h:1068
constexpr const char * irregular_loop_mark
Mark a ForNode as representing an irregular loop of non-structural control flow edges.
Definition: stmt.h:1323
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1139
constexpr const char * meta_schedule_thread_extent_high_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1243
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1145
constexpr const char * coproc_uop_scope
Mark region creates coprocessor micro ops, can be reused if corresponding variable is independent.
Definition: stmt.h:1051
constexpr const char * meta_schedule_inline_rule
Mark that a block is disallowed in auto inline.
Definition: stmt.h:1310
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1209
constexpr const char * pipeline_exec_scope
pipeline execution scope, implies the scope can be pipelined.
Definition: stmt.h:1149
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1086
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1082
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1117
constexpr const char * require_block_var_bound_predicate
Mark that the block needs to add a predicate for block var bounds during lowering.
Definition: stmt.h:1275
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1066
constexpr const char * local_stage
Mark local stage constraint on data copy.
Definition: stmt.h:1298
constexpr const char * meta_schedule_vectorize
Mark auto-vectorize setting on the block.
Definition: stmt.h:1254
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1111
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1159
constexpr const char * meta_schedule_parallel
Mark auto-parallel setting on the block.
Definition: stmt.h:1251
const char * ForKind2String(ForKind t)
Definition: stmt.h:1347
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:123
ForKind
The kind of the loop.
Definition: stmt.h:694
@ kThreadBinding
The loop variable is bound to a thread in an environment. In the final stage of lowering,...
@ kParallel
Parallel execution on CPU.
@ kSerial
default semantics – serial execution.
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
@ kVectorized
The loop is vectorized.
Definition: var.h:237
@ kUnrolled
The execution is unrolled.
Definition: var.h:233
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
TIR expressions.