tvm
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIRX_STMT_H_
25 #define TVM_TIRX_STMT_H_
26 
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tirx/exec_scope.h>
#include <tvm/tirx/expr.h>
#include <tvm/tirx/layout.h>

#include <limits>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
37 
38 namespace tvm {
39 namespace tirx {
40 
// Base class of the statement nodes declared in this header (BindNode, AttrStmtNode,
// ForNode, ... all derive from it).
// NOTE(review): several original lines (e.g. 44-47, 58, 69) are not present in this listing.
42 class StmtNode : public ffi::Object {
43  public:
// Source span of this statement. Declared `mutable`, so it can be assigned even
// through a const StmtNode.
48  mutable Span span;
49 
50  StmtNode() = default;
51  explicit StmtNode(Span span) : span(span) {}
52 
// Expose `span` as a read-only reflected field.
53  static void RegisterReflection() {
54  namespace refl = tvm::ffi::reflection;
55  refl::ObjectDef<StmtNode>().def_ro("span", &StmtNode::span);
56  }
57 
59 
// Statements use the TreeNode kind for structural equality / hashing.
60  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
61 
// NOTE(review): presumably reserves type-index slots for the derived statement node
// types — confirm against the TVM FFI object-system documentation.
62  static constexpr const uint32_t _type_child_slots = 15;
63  TVM_FFI_DECLARE_OBJECT_INFO("tirx.Stmt", StmtNode, ffi::Object);
64 };
65 
// Reference type for statements; Bind, AttrStmt, For, SeqStmt, ... derive from it.
67 class Stmt : public ffi::ObjectRef {
68  public:
70 };
71 
// Bind a variable to a value in the enclosing scope.
//   var   - the variable being bound.
//   value - the value bound to the variable.
// NOTE(review): the declarations of `var` and `value` (original lines 81-84) are not
// present in this listing; RegisterReflection below still refers to them.
79 class BindNode : public StmtNode {
80  public:
85 
// Reflect both fields read-only. `var` carries the SEqHashDef flag — presumably marking it
// as a definition site for structural equality/hashing; confirm in the TVM FFI docs.
86  static void RegisterReflection() {
87  namespace refl = tvm::ffi::reflection;
88  refl::ObjectDef<BindNode>()
89  .def_ro("var", &BindNode::var, refl::AttachFieldFlag::SEqHashDef())
90  .def_ro("value", &BindNode::value);
91  }
93 };
94 
// Managed reference to BindNode.
99 class Bind : public Stmt {
100  public:
// Bind `var` to `value`; `span` defaults to an empty Span().
101  TVM_DLL Bind(Var var, PrimExpr value, Span span = Span());
102 
105 };
106 
// Define an auxiliary attribute for the body, whose value is a symbolic PrimExpr.
// NOTE(review): the declarations of `value` (the attribute value, well defined at the
// current scope) and `body` (the body statement to be executed) are not present in this
// listing; RegisterReflection below still refers to them.
117 class AttrStmtNode : public StmtNode {
118  public:
// The node this attribute is about (any FFI value).
120  ffi::Any node;
// The type key of the attribute.
122  ffi::String attr_key;
127 
// Reflect node, attr_key, value and body as read-only fields.
128  static void RegisterReflection() {
129  namespace refl = tvm::ffi::reflection;
130  refl::ObjectDef<AttrStmtNode>()
131  .def_ro("node", &AttrStmtNode::node)
132  .def_ro("attr_key", &AttrStmtNode::attr_key)
133  .def_ro("value", &AttrStmtNode::value)
134  .def_ro("body", &AttrStmtNode::body);
135  }
137 };
138 
// Managed reference to AttrStmtNode.
143 class AttrStmt : public Stmt {
144  public:
// Construct an AttrStmt; `span` defaults to an empty Span().
145  TVM_DLL AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body,
146  Span span = Span());
147 
150 };
151 
// Assert a condition; if it fails, the error message is reported.
//   condition     - the condition to be checked.
//   error_kind    - the error kind, e.g. "RuntimeError", "TypeError", "ValueError".
//   message_parts - error message fragments, concatenated at runtime when the assertion fails.
// NOTE(review): the `condition` / `error_kind` declarations are not present in this listing.
161 class AssertStmtNode : public StmtNode {
162  public:
168  ffi::Array<StringImm> message_parts;
169 
// Reflect condition, error_kind and message_parts as read-only fields.
170  static void RegisterReflection() {
171  namespace refl = tvm::ffi::reflection;
172  refl::ObjectDef<AssertStmtNode>()
173  .def_ro("condition", &AssertStmtNode::condition)
174  .def_ro("error_kind", &AssertStmtNode::error_kind)
175  .def_ro("message_parts", &AssertStmtNode::message_parts);
176  }
178 };
179 
// Managed reference to AssertStmtNode.
184 class AssertStmt : public Stmt {
185  public:
186  TVM_DLL AssertStmt(PrimExpr condition, StringImm error_kind, ffi::Array<StringImm> message_parts,
187  Span span = Span());
188 
191 };
192 
// Store value to the high-dimension buffer.
//   buffer    - the buffer variable.
//   value     - the value to be stored.
//   indices   - the index location to be stored.
//   predicate - the predicate mask for storing values (optional).
// NOTE(review): the `buffer` / `value` declarations are not present in this listing.
203 class BufferStoreNode : public StmtNode {
204  public:
210  ffi::Array<PrimExpr> indices;
212  ffi::Optional<PrimExpr> predicate;
213 
// Reflect all four fields read-only.
214  static void RegisterReflection() {
215  namespace refl = tvm::ffi::reflection;
216  refl::ObjectDef<BufferStoreNode>()
217  .def_ro("buffer", &BufferStoreNode::buffer)
218  .def_ro("value", &BufferStoreNode::value)
219  .def_ro("indices", &BufferStoreNode::indices)
220  .def_ro("predicate", &BufferStoreNode::predicate);
221  }
223 };
224 
// Managed reference to BufferStoreNode. `predicate` defaults to std::nullopt (no mask).
229 class BufferStore : public Stmt {
230  public:
231  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, ffi::Array<PrimExpr> indices,
232  ffi::Optional<PrimExpr> predicate = std::nullopt,
233  Span span = Span());
234 
237 };
238 
// Declare a buffer that can be used in the body.
// NOTE(review): the declaration of `buffer` (the buffer being declared) is not present in
// this listing; RegisterReflection below still refers to it.
240 class DeclBufferNode : public StmtNode {
241  public:
244 
// Reflect `buffer` as a read-only field.
245  static void RegisterReflection() {
246  namespace refl = tvm::ffi::reflection;
247  refl::ObjectDef<DeclBufferNode>().def_ro("buffer", &DeclBufferNode::buffer);
248  }
250 };
251 
// Managed reference to DeclBufferNode.
253 class DeclBuffer : public Stmt {
254  public:
255  TVM_DLL DeclBuffer(Buffer buffer, Span span = Span());
258 };
259 
// Allocate a buffer and declare it in scope.
// NOTE(review): the declaration of `buffer` (the buffer being allocated and declared) is
// not present in this listing; RegisterReflection below still refers to it.
261 class AllocBufferNode : public StmtNode {
262  public:
// Additional annotations about the allocation.
271  ffi::Map<ffi::String, ffi::Any> annotations;
272 
// `buffer` carries the SEqHashDef flag — presumably marking it as a definition site for
// structural equality/hashing; confirm in the TVM FFI docs.
273  static void RegisterReflection() {
274  namespace refl = tvm::ffi::reflection;
275  refl::ObjectDef<AllocBufferNode>()
276  .def_ro("buffer", &AllocBufferNode::buffer, refl::AttachFieldFlag::SEqHashDef())
277  .def_ro("annotations", &AllocBufferNode::annotations);
278  }
280 };
281 
283 class AllocBuffer : public Stmt {
284  public:
285  TVM_DLL AllocBuffer(
286  Buffer buffer,
287  ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
288  Span span = Span());
293  std::optional<int64_t> ConstantAllocationSize() const {
294  int64_t result = 1;
295  for (const PrimExpr& extent : (*this)->buffer->shape) {
296  if (const auto* int_size = extent.as<IntImmNode>()) {
297  result *= int_size->value;
298  } else {
299  return std::nullopt;
300  }
301  }
302  return result;
303  }
304 
307 };
308 
// Node holding an ordered array of statements (`seq`).
313 class SeqStmtNode : public StmtNode {
314  public:
// The statements of the sequence.
316  ffi::Array<Stmt> seq;
317 
// Number of statements in `seq`.
319  size_t size() const { return seq.size(); }
// Statement at position `index` of `seq` (delegates to ffi::Array indexing).
323  Stmt operator[](size_t index) const { return seq[index]; }
324 
// Reflect `seq` as a read-only field.
325  static void RegisterReflection() {
326  namespace refl = tvm::ffi::reflection;
327  refl::ObjectDef<SeqStmtNode>().def_ro("seq", &SeqStmtNode::seq);
328  }
330 };
331 
// Evaluates an expression. This is mostly used for putting a Call node into a Stmt.
// NOTE(review): the declaration of `value` (the expression to be evaluated) is not present
// in this listing; RegisterReflection below still refers to it.
338 class EvaluateNode : public StmtNode {
339  public:
342 
// Reflect `value` as a read-only field.
343  static void RegisterReflection() {
344  namespace refl = tvm::ffi::reflection;
345  refl::ObjectDef<EvaluateNode>().def_ro("value", &EvaluateNode::value);
346  }
348 };
349 
// Managed reference to EvaluateNode.
354 class Evaluate : public Stmt {
355  public:
356  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
357 
// Convenience overload wrapping an int in a PrimExpr. SeqStmt::Flatten uses Evaluate(0)
// to represent a no-op.
358  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
359 
362 };
363 
365 class SeqStmt : public Stmt {
366  public:
372  TVM_DLL explicit SeqStmt(ffi::Array<Stmt> seq, Span span = Span());
373 
375  size_t size() const { return operator->()->size(); }
379  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
400  template <typename... Args>
401  static Stmt Flatten(Args&&... seq_args) {
402  ffi::Array<Stmt> seq;
403 
404  ffi::details::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
405 
406  if (seq.empty()) {
407  return Evaluate(0);
408  } else if (seq.size() == 1) {
409  return seq[0];
410  }
411 
412  // If the argument is a single SeqStmt argument with no
413  // flattening or unwrapping required, then we may
414  // return the SeqStmt as-is.
415  if constexpr (sizeof...(seq_args) == 1) {
416  if (auto opt = Flattener::AsSeqStmt(std::forward<Args>(seq_args)...)) {
417  SeqStmt original = opt.value();
418  bool all_same = [&]() {
419  if (original->seq.size() != seq.size()) {
420  return false;
421  }
422  for (size_t i = 0; i < seq.size(); i++) {
423  if (!original->seq[i].same_as(seq[i])) {
424  return false;
425  }
426  }
427  return true;
428  }();
429  if (all_same) {
430  return original;
431  }
432  }
433  }
434 
435  return SeqStmt(seq);
436  }
438  class Flattener {
439  public:
440  explicit Flattener(ffi::Array<Stmt>* seq) : seq_(seq) {}
441 
442  template <typename T>
443  static ffi::Optional<SeqStmt> AsSeqStmt(const T& t) {
444  if constexpr (std::is_same_v<T, SeqStmt>) {
445  return t;
446  }
447  if constexpr (!std::is_base_of_v<T, SeqStmt>) {
448  return std::nullopt;
449  }
450  if constexpr (std::is_base_of_v<Stmt, T>) {
451  if (const SeqStmtNode* ptr = t.template as<SeqStmtNode>()) {
452  return ffi::GetRef<SeqStmt>(ptr);
453  } else {
454  return std::nullopt;
455  }
456  }
457  return std::nullopt;
458  }
459 
460  template <typename T>
461  void operator()(size_t i, const T& stmt_or_seq) const {
462  if constexpr (std::is_base_of_v<ObjectRef, T>) {
463  // Early bail-out, applicable to any ObjectRef
464  if (!stmt_or_seq.defined()) {
465  return;
466  }
467  }
468 
469  if constexpr (std::is_same_v<T, SeqStmt>) {
470  // Static type-checking for a SeqStmt that could be flattened.
471  (*this)(0, stmt_or_seq->seq);
472  return;
473  }
474 
475  if constexpr (std::is_base_of_v<T, SeqStmt>) {
476  // Dynamic type-checking for a SeqStmt that could be
477  // flattened.
478  if (auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
479  operator()(0, op->seq);
480  return;
481  }
482  }
483 
484  if constexpr (std::is_base_of_v<T, Evaluate>) {
485  // Evaluate(0) is used to represent a no-op, and may be
486  // generated by previous calls to SeqStmt::Flatten(). These
487  // should be removed to ensure that Flatten(a+b) is equivalent
488  // to Flatten(Flatten(a), Flatten(b)).
489  if (auto* op = stmt_or_seq.template as<EvaluateNode>()) {
490  if (auto* as_int = op->value.template as<IntImmNode>(); as_int && as_int->value == 0) {
491  return;
492  }
493  }
494  }
495 
496  if constexpr (std::is_base_of_v<Stmt, T>) {
497  // Any other Stmt type just gets appended.
498  seq_->push_back(stmt_or_seq);
499  } else {
500  // Anything else is treated as an iterable of Stmt.
501  for (auto v : stmt_or_seq) {
502  this->operator()(0, v);
503  }
504  }
505  }
506 
507  private:
508  ffi::Array<Stmt>* seq_;
509  };
510 
513 };
514 
// IfThenElse statement.
//   condition - the condition.
//   then_case - the branch to be executed when the condition is true.
//   else_case - the branch to be executed when the condition is false; may be null.
// NOTE(review): the `condition` / `then_case` declarations are not present in this listing.
518 class IfThenElseNode : public StmtNode {
519  public:
525  ffi::Optional<Stmt> else_case;
526 
// Reflect condition, then_case and else_case as read-only fields.
527  static void RegisterReflection() {
528  namespace refl = tvm::ffi::reflection;
529  refl::ObjectDef<IfThenElseNode>()
530  .def_ro("condition", &IfThenElseNode::condition)
531  .def_ro("then_case", &IfThenElseNode::then_case)
532  .def_ro("else_case", &IfThenElseNode::else_case);
533  }
535 };
536 
// Managed reference to IfThenElseNode. `else_case` defaults to std::nullopt (no else branch).
541 class IfThenElse : public Stmt {
542  public:
543  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case,
544  ffi::Optional<Stmt> else_case = std::nullopt, Span span = Span());
545 
548 };
549 
/*!
 * \brief The kind of a for loop.
 *
 * The short names returned by ForKind2String are given in parentheses.
 */
enum class ForKind : int {
  /*! \brief Serial loop ("serial"). */
  kSerial = 0,
  /*! \brief Parallel loop ("parallel"). */
  kParallel = 1,
  /*! \brief Vectorized loop ("vectorized"). */
  kVectorized = 2,
  /*! \brief Unrolled loop ("unroll"). */
  kUnrolled = 3,
  /*!
   * \brief Loop bound to a thread ("thread_binding"); ForNode::thread_binding is only
   *        valid for this kind.
   */
  kThreadBinding = 4,
};
577 
// A for loop, with possible type annotations.
// NOTE(review): the declarations of loop_var (the loop variable), min (the minimum value of
// iteration), extent (the extent of the iteration), kind (the ForKind of the loop) and body
// (the body of the loop) are not present in this listing; RegisterReflection refers to them.
588 class ForNode : public StmtNode {
589  public:
// Only valid when kind == ForKind::kThreadBinding: the context thread that this loop
// variable is bound to.
604  ffi::Optional<IterVar> thread_binding;
// Additional annotations about the loop.
613  ffi::Map<ffi::String, ffi::Any> annotations;
// The loop step; it is one if not specified.
617  ffi::Optional<PrimExpr> step;
618 
// Reflect all fields read-only. `loop_var` carries the SEqHashDef flag — presumably marking
// it as a definition site for structural equality/hashing; confirm in the TVM FFI docs.
619  static void RegisterReflection() {
620  namespace refl = tvm::ffi::reflection;
621  refl::ObjectDef<ForNode>()
622  .def_ro("loop_var", &ForNode::loop_var, refl::AttachFieldFlag::SEqHashDef())
623  .def_ro("min", &ForNode::min)
624  .def_ro("extent", &ForNode::extent)
625  .def_ro("kind", &ForNode::kind)
626  .def_ro("body", &ForNode::body)
627  .def_ro("thread_binding", &ForNode::thread_binding)
628  .def_ro("annotations", &ForNode::annotations)
629  .def_ro("step", &ForNode::step);
630  }
631 
// Check whether this is a loop without a nontrivial loop step (defined out of line).
633  bool HasTrivialStep() const;
634 
636 };
637 
// Managed reference to ForNode. thread_binding and step default to std::nullopt,
// annotations to an empty map, and span to an empty Span().
642 class For : public Stmt {
643  public:
644  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
645  ffi::Optional<IterVar> thread_binding = std::nullopt,
646  ffi::Map<ffi::String, ffi::Any> annotations = {},
647  ffi::Optional<PrimExpr> step = std::nullopt, Span span = Span());
648 
651 };
652 
// While-loop statement node with a `condition` and a `body` (see RegisterReflection).
// NOTE(review): presumably `body` is executed while `condition` holds — confirm against the
// lowering / codegen. The field declarations (original lines 665-668) are not present in
// this listing.
663 class WhileNode : public StmtNode {
664  public:
669 
// Reflect condition and body as read-only fields.
670  static void RegisterReflection() {
671  namespace refl = tvm::ffi::reflection;
672  refl::ObjectDef<WhileNode>()
673  .def_ro("condition", &WhileNode::condition)
674  .def_ro("body", &WhileNode::body);
675  }
677 };
678 
// Managed reference to WhileNode.
683 class While : public Stmt {
684  public:
685  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
686 
689 };
690 
// A Break in control flow. The node has no fields of its own beyond StmtNode::span.
694 class BreakNode : public StmtNode {
695  public:
// Registers the type with reflection; there are no fields to expose.
696  static void RegisterReflection() {
697  namespace refl = tvm::ffi::reflection;
698  refl::ObjectDef<BreakNode>();
699  }
700 
702 };
703 
// Managed reference to BreakNode.
708 class Break : public Stmt {
709  public:
// Unlike most constructors in this header, `span` has no default here.
710  TVM_DLL explicit Break(Span span);
711 
714 };
715 
// A Continue in control flow. The node has no fields of its own beyond StmtNode::span.
719 class ContinueNode : public StmtNode {
720  public:
// Registers the type with reflection; there are no fields to expose.
721  static void RegisterReflection() {
722  namespace refl = tvm::ffi::reflection;
723  refl::ObjectDef<ContinueNode>();
724  }
725 
727 };
728 
// Managed reference to ContinueNode.
733 class Continue : public Stmt {
734  public:
// Unlike most constructors in this header, `span` has no default here.
735  TVM_DLL explicit Continue(Span span);
736 
739 };
740 
745  public:
749  ffi::Array<Range> region;
750 
751  static void RegisterReflection() {
752  namespace refl = tvm::ffi::reflection;
753  refl::ObjectDef<BufferRegionNode>()
754  .def_ro("buffer", &BufferRegionNode::buffer)
755  .def_ro("region", &BufferRegionNode::region);
756  }
757 
758  TVM_DLL PrimExpr ToPrimExpr() const final;
759 
760  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
762 };
763 
769  public:
770  TVM_DLL explicit BufferRegion(Buffer buffer, ffi::Array<Range> region);
771 
778 
785  TVM_DLL static BufferRegion FromPoint(Buffer buffer, ffi::Array<PrimExpr> indices);
786 
789 };
790 
// Match: a constraint that the `source` buffer region can be remapped to the data layout
// of the target `buffer` (NOTE(review): wording reconstructed from a truncated description).
// Unlike the statement nodes in this header, it derives directly from ffi::Object.
//   buffer - the target buffer.
//   source - the source buffer region.
// NOTE(review): the field declarations (original lines 802-805) are not present in this listing.
800 class MatchBufferRegionNode : public ffi::Object {
801  public:
806 
// Reflect buffer and source as read-only fields.
807  static void RegisterReflection() {
808  namespace refl = tvm::ffi::reflection;
809  refl::ObjectDef<MatchBufferRegionNode>()
810  .def_ro("buffer", &MatchBufferRegionNode::buffer)
811  .def_ro("source", &MatchBufferRegionNode::source);
812  }
813 
814  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
815  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.MatchBufferRegion", MatchBufferRegionNode, ffi::Object);
816 };
817 
// Managed reference to MatchBufferRegionNode.
822 class MatchBufferRegion : public ffi::ObjectRef {
823  public:
824  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
825 
829 };
830 
// A block is a basic schedule unit in TIR.
// NOTE(review): the declaration of `body` (reflected below) is not present in this listing.
852 class SBlockNode : public StmtNode {
853  public:
855  ffi::Array<IterVar> iter_vars;
857  ffi::Array<BufferRegion> reads;
859  ffi::Array<BufferRegion> writes;
861  ffi::String name_hint;
// The buffers allocated in the block.
863  ffi::Array<Buffer> alloc_buffers;
865  ffi::Array<MatchBufferRegion> match_buffers;
867  ffi::Map<ffi::String, ffi::Any> annotations;
875  ffi::Optional<Stmt> init;
878 
// Reflect all fields read-only. `iter_vars` carries SEqHashDef and `name_hint` carries
// SEqHashIgnore — presumably excluding the name from structural equality/hashing; confirm
// in the TVM FFI docs.
879  static void RegisterReflection() {
880  namespace refl = tvm::ffi::reflection;
881  refl::ObjectDef<SBlockNode>()
882  .def_ro("iter_vars", &SBlockNode::iter_vars, refl::AttachFieldFlag::SEqHashDef())
883  .def_ro("reads", &SBlockNode::reads)
884  .def_ro("writes", &SBlockNode::writes)
885  .def_ro("name_hint", &SBlockNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore())
886  .def_ro("alloc_buffers", &SBlockNode::alloc_buffers)
887  .def_ro("match_buffers", &SBlockNode::match_buffers)
888  .def_ro("annotations", &SBlockNode::annotations)
889  .def_ro("init", &SBlockNode::init)
890  .def_ro("body", &SBlockNode::body);
891  }
893 };
894 
// Managed reference to SBlockNode.
899 class SBlock : public Stmt {
900  public:
// Full constructor; init defaults to std::nullopt and the containers default to empty.
901  TVM_DLL explicit SBlock(
902  ffi::Array<IterVar> iter_vars, ffi::Array<BufferRegion> reads,
903  ffi::Array<BufferRegion> writes, ffi::String name_hint, Stmt body,
904  ffi::Optional<Stmt> init = std::nullopt,
905  ffi::Array<Buffer> alloc_buffers = ffi::Array<Buffer>(),
906  ffi::Array<MatchBufferRegion> match_buffers = ffi::Array<MatchBufferRegion>(),
907  ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
908  Span span = Span());
909 
// Shorter constructor taking only a name, a body and (optionally) allocated buffers.
910  TVM_DLL explicit SBlock(ffi::String name_hint, Stmt body,
911  ffi::Array<Buffer> alloc_buffers = ffi::Array<Buffer>(),
912  Span span = Span());
913 
916 };
917 
// Statement node that realizes an SBlock with concrete `iter_values` under a `predicate`
// (fields per RegisterReflection and the SBlockRealize constructor).
// NOTE(review): the `predicate` / `block` declarations are not present in this listing.
921 class SBlockRealizeNode : public StmtNode {
922  public:
924  ffi::Array<PrimExpr> iter_values;
932 
// Reflect iter_values, predicate and block as read-only fields.
933  static void RegisterReflection() {
934  namespace refl = tvm::ffi::reflection;
935  refl::ObjectDef<SBlockRealizeNode>()
936  .def_ro("iter_values", &SBlockRealizeNode::iter_values)
937  .def_ro("predicate", &SBlockRealizeNode::predicate)
938  .def_ro("block", &SBlockRealizeNode::block);
939  }
941 };
942 
// Managed reference to SBlockRealizeNode.
947 class SBlockRealize : public Stmt {
948  public:
949  TVM_DLL explicit SBlockRealize(ffi::Array<PrimExpr> iter_values, PrimExpr predicate, SBlock block,
950  Span span = Span());
951 
954 };
955 
// A statement that annotates the execution scope for its body.
//   exec_scope - the execution scope.
//   body       - the body statement under this execution scope.
// NOTE(review): the field declarations (original lines 971-974) are not present in this listing.
969 class ExecScopeStmtNode : public StmtNode {
970  public:
975 
// Reflect exec_scope and body as read-only fields.
976  static void RegisterReflection() {
977  namespace refl = tvm::ffi::reflection;
978  refl::ObjectDef<ExecScopeStmtNode>()
979  .def_ro("exec_scope", &ExecScopeStmtNode::exec_scope)
980  .def_ro("body", &ExecScopeStmtNode::body);
981  }
983 };
984 
// Managed reference to ExecScopeStmtNode.
989 class ExecScopeStmt : public Stmt {
990  public:
991  TVM_DLL ExecScopeStmt(ExecScope exec_scope, Stmt body, Span span = Span());
992 
995 };
996 
/*!
 * \brief String constants used as attribute / annotation keys.
 *
 * NOTE(review): these presumably serve as AttrStmtNode::attr_key values and as keys of the
 * `annotations` maps on For / SBlock / AllocBuffer nodes — confirm per key.
 */
namespace attr {
constexpr const char* buffer_bound = "buffer_bound";
constexpr const char* compute_scope = "compute_scope";
constexpr const char* device_id = "device_id";
constexpr const char* device_scope = "device_scope";
constexpr const char* device_type = "device_type";
constexpr const char* extern_scope = "extern_scope";
constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
constexpr const char* pragma_import_c = "pragma_import_c";
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*! \brief Prefix shared by all pragma keys; IsPragmaKey tests for it. */
constexpr const char* pragma_scope_prefix = "pragma_";
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
constexpr const char* storage_alignment = "storage_alignment";
constexpr const char* thread_extent = "thread_extent";
constexpr const char* kVolatile = "tirx.volatile";
constexpr const char* layout_transforms = "layout_transforms";
constexpr const char* axis_separators = "axis_separators";
constexpr const char* double_buffer_scope = "double_buffer_scope";
constexpr const char* double_buffer_write = "double_buffer_write";
constexpr const char* scan_update_scope = "scan_update_scope";
constexpr const char* scan_init_scope = "scan_init_scope";
constexpr const char* buffer_dim_align = "buffer_dim_align";
constexpr const char* buffer_data_alignment = "buffer_data_alignment";
constexpr const char* buffer_allocated_addr = "buffer_allocated_addr";
constexpr const char* buffer_bind_scope = "buffer_bind_scope";

// Pipeline related attributes
constexpr const char* channel_read_scope = "channel_read_scope";
constexpr const char* channel_read_advance = "channel_read_advance";
constexpr const char* channel_write_scope = "channel_write_scope";
constexpr const char* channel_write_advance = "channel_write_advance";
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";

constexpr const char* async_scope = "async_scope";

constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";

constexpr const char* fragment_shape = "fragment_shape";

constexpr const char* fragment_layout = "fragment_layout";

constexpr const char* hand_threaded = "hand_threaded";

constexpr const char* script_parsing_detect_access = "tirx.script_parsing_detect_access";

constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

constexpr const char* software_pipeline_stage = "software_pipeline_stage";

constexpr const char* software_pipeline_order = "software_pipeline_order";

constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";

constexpr const char* layout_free_buffers = "layout_free_buffers";

constexpr const char* manifest_shared_memory_local_stage =
    "tirx.manifest_shared_memory_local_stage";

constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";

// Allowed thread-extent range in thread bindings (low / high bound, both inclusive per the
// key names).
constexpr const char* meta_schedule_thread_extent_low_inclusive =
    "meta_schedule.thread_extent_low_inclusive";

constexpr const char* meta_schedule_thread_extent_high_inclusive =
    "meta_schedule.thread_extent_high_inclusive";

constexpr const char* meta_schedule_random_compute_producer =
    "meta_schedule.random_compute_producer";

constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";

constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";

constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";

constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";

constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";

constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";

// Note: the C++ name differs from the key string ("require_bound_predicate").
constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";

constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";

constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";

// Integer values associated with meta_schedule_cache_type (read / write).
constexpr const int meta_schedule_cache_type_read = 0;

constexpr const int meta_schedule_cache_type_write = 1;

constexpr const char* auto_copy = "auto_copy";

constexpr const char* local_stage = "local_stage";

constexpr const char* vector_bytes = "vector_bytes";

constexpr const char* warp_execution = "warp_execution";

constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";

constexpr const char* explicit_read_region = "explicit_read_region";

constexpr const char* explicit_write_region = "explicit_write_region";
constexpr const char* tensorized_nki_instruction = "tensorized_nki_instruction";

constexpr const char* irregular_loop_mark = "irregular_loop_mark";

constexpr const char* kPersistentKernel = "tirx.persistent_kernel";

/*!
 * \brief Check whether an attribute key is a pragma key.
 * \param attr_key The attribute key.
 * \return True iff `attr_key` starts with pragma_scope_prefix ("pragma_").
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // Compare against the shared prefix constant rather than a duplicated literal with a
  // hand-counted length, so the two cannot drift apart.
  return attr_key.compare(0, std::char_traits<char>::length(pragma_scope_prefix),
                          pragma_scope_prefix) == 0;
}

}  // namespace attr
// Create a PrimExpr carrying `dtype` as a type annotation.
// NOTE(review): exact representation is defined out of line — confirm before relying on it.
1290 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1291 
1292 // Print a ForKind (the kind of a for loop) to an output stream.
1293 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1294 
1295 // inline implementations
1296 inline const char* ForKind2String(ForKind t) {
1297  switch (t) {
1298  case ForKind::kSerial:
1299  return "serial";
1300  case ForKind::kParallel:
1301  return "parallel";
1302  case ForKind::kVectorized:
1303  return "vectorized";
1304  case ForKind::kUnrolled:
1305  return "unroll";
1307  return "thread_binding";
1308  }
1309  TVM_FFI_THROW(InternalError) << "Unknown ForKind" << t;
1310  TVM_FFI_UNREACHABLE();
1311 }
1312 
1313 } // namespace tirx
1314 } // namespace tvm
1315 #endif // TVM_TIRX_STMT_H_
Constant integer literals in the program.
Definition: expr.h:494
Base class for other IR constructs that can be converted to PrimExpr. This is useful for the FFI to c...
Definition: expr.h:156
Managed reference to PrimExprConvertibleNode.
Definition: expr.h:167
Reference to PrimExprNode.
Definition: expr.h:126
Definition: source_map.h:111
Runtime primitive data type.
Definition: data_type.h:45
Allocate a buffer and declare it in scope.
Definition: stmt.h:261
ffi::Map< ffi::String, ffi::Any > annotations
Additional annotations about the allocation.
Definition: stmt.h:271
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AllocBuffer", AllocBufferNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:273
Buffer buffer
The buffer being allocated and declared.
Definition: stmt.h:264
Managed reference to AllocBufferNode.
Definition: stmt.h:283
AllocBuffer(Buffer buffer, ffi::Map< ffi::String, ffi::Any > annotations=ffi::Map< ffi::String, ffi::Any >(), Span span=Span())
std::optional< int64_t > ConstantAllocationSize() const
If the buffer's shape is constant, return the total number of elements.
Definition: stmt.h:293
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocBufferNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AllocBuffer, Stmt, AllocBufferNode)
Assert a condition; if it fails, return the error message.
Definition: stmt.h:161
PrimExpr condition
Condition to be checked.
Definition: stmt.h:164
StringImm error_kind
The error kind, e.g. "RuntimeError", "TypeError", "ValueError".
Definition: stmt.h:166
static void RegisterReflection()
Definition: stmt.h:170
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AssertStmt", AssertStmtNode, StmtNode)
ffi::Array< StringImm > message_parts
Error message fragments, concatenated at runtime when assertion fails.
Definition: stmt.h:168
Managed reference to AssertStmtNode.
Definition: stmt.h:184
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AssertStmt, Stmt, AssertStmtNode)
AssertStmt(PrimExpr condition, StringImm error_kind, ffi::Array< StringImm > message_parts, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode)
Define a certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:117
static void RegisterReflection()
Definition: stmt.h:128
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:124
ffi::Any node
The node this attribute is about.
Definition: stmt.h:120
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AttrStmt", AttrStmtNode, StmtNode)
ffi::String attr_key
the type key of the attribute
Definition: stmt.h:122
Stmt body
The body statement to be executed.
Definition: stmt.h:126
Managed reference to AttrStmtNode.
Definition: stmt.h:143
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrStmt, Stmt, AttrStmtNode)
AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode)
Bind a variable to a value in the enclosing scope.
Definition: stmt.h:79
PrimExpr value
The value to bind to the variable.
Definition: stmt.h:84
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Bind", BindNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:86
Var var
The variable being bound.
Definition: stmt.h:82
Managed reference to BindNode.
Definition: stmt.h:99
Bind(Var var, PrimExpr value, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BindNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Bind, Stmt, BindNode)
A Break in control flow.
Definition: stmt.h:694
static void RegisterReflection()
Definition: stmt.h:696
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Break", BreakNode, StmtNode)
Managed reference to BreakNode.
Definition: stmt.h:708
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Break, Stmt, BreakNode)
Break(Span span)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BreakNode)
Represents the region of a multi-dimensional buffer access.
Definition: stmt.h:744
ffi::Array< Range > region
The region array of the buffer region.
Definition: stmt.h:749
PrimExpr ToPrimExpr() const final
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferRegion", BufferRegionNode, PrimExprConvertibleNode)
static void RegisterReflection()
Definition: stmt.h:751
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:760
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:747
Managed reference to BufferRegionNode.
Definition: stmt.h:768
static BufferRegion FullRegion(Buffer buffer)
Create a BufferRegion which is full region of the given buffer.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferRegion, PrimExprConvertible, BufferRegionNode)
BufferRegion(Buffer buffer, ffi::Array< Range > region)
static BufferRegion FromPoint(Buffer buffer, ffi::Array< PrimExpr > indices)
Create a BufferRegion which is a single point of the given buffer.
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode)
Store value to the high dimension buffer.
Definition: stmt.h:203
Buffer buffer
The buffer variable.
Definition: stmt.h:206
ffi::Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:210
static void RegisterReflection()
Definition: stmt.h:214
PrimExpr value
The value to be stored.
Definition: stmt.h:208
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferStore", BufferStoreNode, StmtNode)
ffi::Optional< PrimExpr > predicate
The predicate mask for storing values.
Definition: stmt.h:212
Managed reference to BufferStoreNode.
Definition: stmt.h:229
BufferStore(Buffer buffer, PrimExpr value, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferStore, Stmt, BufferStoreNode)
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:172
A Continue in control flow.
Definition: stmt.h:719
static void RegisterReflection()
Definition: stmt.h:721
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Continue", ContinueNode, StmtNode)
Managed reference to ContinueNode.
Definition: stmt.h:733
TVM_DEFINE_OBJECT_REF_COW_METHOD(ContinueNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Continue, Stmt, ContinueNode)
Declare a buffer that can be used in the body.
Definition: stmt.h:240
Buffer buffer
The buffer being declared.
Definition: stmt.h:243
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.DeclBuffer", DeclBufferNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:245
Managed reference to DeclBufferNode.
Definition: stmt.h:253
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DeclBuffer, Stmt, DeclBufferNode)
DeclBuffer(Buffer buffer, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode)
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:338
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:341
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Evaluate", EvaluateNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:343
Managed reference to EvaluateNode.
Definition: stmt.h:354
Evaluate(PrimExpr value, Span span=Span())
Evaluate(int value, Span span=Span())
Definition: stmt.h:358
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Evaluate, Stmt, EvaluateNode)
A statement that annotates the execution scope for its body.
Definition: stmt.h:969
static void RegisterReflection()
Definition: stmt.h:976
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ExecScopeStmt", ExecScopeStmtNode, StmtNode)
Stmt body
The body statement under this execution scope.
Definition: stmt.h:974
ExecScope exec_scope
The execution scope.
Definition: stmt.h:972
Managed reference to ExecScopeStmtNode.
Definition: stmt.h:989
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExecScopeStmt, Stmt, ExecScopeStmtNode)
ExecScopeStmt(ExecScope exec_scope, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecScopeStmtNode)
Definition: exec_scope.h:234
A for loop, with possible type annotations.
Definition: stmt.h:588
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:593
ffi::Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to.
Definition: stmt.h:604
ffi::Optional< PrimExpr > step
The loop step. It is one if not specified.
Definition: stmt.h:617
Var loop_var
The loop variable.
Definition: stmt.h:591
bool HasTrivialStep() const
Check whether this is a loop without a nontrivial loop step.
Stmt body
The body of the for loop.
Definition: stmt.h:599
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:595
ForKind kind
The kind of the for loop.
Definition: stmt.h:597
ffi::Map< ffi::String, ffi::Any > annotations
Additional annotations about the loop.
Definition: stmt.h:613
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.For", ForNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:619
Managed reference to ForNode.
Definition: stmt.h:642
For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ffi::Optional< IterVar > thread_binding=std::nullopt, ffi::Map< ffi::String, ffi::Any > annotations={}, ffi::Optional< PrimExpr > step=std::nullopt, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode)
IfThenElse statement.
Definition: stmt.h:518
static void RegisterReflection()
Definition: stmt.h:527
ffi::Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:525
PrimExpr condition
The condition.
Definition: stmt.h:521
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:523
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.IfThenElse", IfThenElseNode, StmtNode)
Managed reference to IfThenElseNode.
Definition: stmt.h:541
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IfThenElse, Stmt, IfThenElseNode)
IfThenElse(PrimExpr condition, Stmt then_case, ffi::Optional< Stmt > else_case=std::nullopt, Span span=Span())
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:800
Buffer buffer
The target buffer.
Definition: stmt.h:803
BufferRegion source
The source buffer region.
Definition: stmt.h:805
static void RegisterReflection()
Definition: stmt.h:807
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.MatchBufferRegion", MatchBufferRegionNode, ffi::Object)
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:822
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode)
MatchBufferRegion(Buffer buffer, BufferRegion source)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchBufferRegion, ffi::ObjectRef, MatchBufferRegionNode)
A block is a basic schedule unit in TIR.
Definition: stmt.h:852
ffi::Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:863
Stmt body
The body of the block.
Definition: stmt.h:877
ffi::Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:857
static void RegisterReflection()
Definition: stmt.h:879
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SBlock", SBlockNode, StmtNode)
ffi::Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:865
ffi::String name_hint
The name_hint of the block.
Definition: stmt.h:861
ffi::Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block....
Definition: stmt.h:875
ffi::Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:855
ffi::Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:859
ffi::Map< ffi::String, ffi::Any > annotations
The annotation of the block.
Definition: stmt.h:867
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:921
static void RegisterReflection()
Definition: stmt.h:933
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true.
Definition: stmt.h:929
ffi::Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:924
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SBlockRealize", SBlockRealizeNode, StmtNode)
SBlock block
The block to be realized.
Definition: stmt.h:931
Managed reference to SBlockRealizeNode.
Definition: stmt.h:947
TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockRealizeNode)
SBlockRealize(ffi::Array< PrimExpr > iter_values, PrimExpr predicate, SBlock block, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SBlockRealize, Stmt, SBlockRealizeNode)
Managed reference to SBlockNode.
Definition: stmt.h:899
SBlock(ffi::String name_hint, Stmt body, ffi::Array< Buffer > alloc_buffers=ffi::Array< Buffer >(), Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SBlock, Stmt, SBlockNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockNode)
SBlock(ffi::Array< IterVar > iter_vars, ffi::Array< BufferRegion > reads, ffi::Array< BufferRegion > writes, ffi::String name_hint, Stmt body, ffi::Optional< Stmt > init=std::nullopt, ffi::Array< Buffer > alloc_buffers=ffi::Array< Buffer >(), ffi::Array< MatchBufferRegion > match_buffers=ffi::Array< MatchBufferRegion >(), ffi::Map< ffi::String, ffi::Any > annotations=ffi::Map< ffi::String, ffi::Any >(), Span span=Span())
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:313
size_t size() const
Definition: stmt.h:319
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:323
ffi::Array< Stmt > seq
internal sequence content.
Definition: stmt.h:316
static void RegisterReflection()
Definition: stmt.h:325
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SeqStmt", SeqStmtNode, StmtNode)
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:438
static ffi::Optional< SeqStmt > AsSeqStmt(const T &t)
Definition: stmt.h:443
void operator()(size_t i, const T &stmt_or_seq) const
Definition: stmt.h:461
Flattener(ffi::Array< Stmt > *seq)
Definition: stmt.h:440
Sequence statement.
Definition: stmt.h:365
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode)
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:379
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:401
size_t size() const
Definition: stmt.h:375
SeqStmt(ffi::Array< Stmt > seq, Span span=Span())
Construct SeqStmt.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SeqStmt, Stmt, SeqStmtNode)
Base node of all statements.
Definition: stmt.h:42
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:62
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:60
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:48
StmtNode(Span span)
Definition: stmt.h:51
TVM_FFI_DECLARE_OBJECT_INFO("tirx.Stmt", StmtNode, ffi::Object)
static void RegisterReflection()
Definition: stmt.h:53
Container of all statements.
Definition: stmt.h:67
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Stmt, ffi::ObjectRef, StmtNode)
Managed reference to StringImmNode.
Definition: expr.h:69
a named variable in TIR
Definition: var.h:77
A While loop.
Definition: stmt.h:663
static void RegisterReflection()
Definition: stmt.h:670
PrimExpr condition
The termination condition.
Definition: stmt.h:666
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.While", WhileNode, StmtNode)
Stmt body
The body of the while loop.
Definition: stmt.h:668
Managed reference to WhileNode.
Definition: stmt.h:683
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode)
While(PrimExpr condition, Stmt body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(While, Stmt, WhileNode)
Printer class to print repr string of each AST/IR nodes.
Definition of layout.
void Evaluate(PrimExpr value)
Evaluate the input expression.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1137
constexpr const char * script_parsing_detect_access
Mark whether the script-completer needs to fill in missing access regions during script parsing.
Definition: stmt.h:1146
constexpr const char * software_pipeline_stage
Mark the stage of a statement in the software pipeline.
Definition: stmt.h:1154
constexpr const char * irregular_loop_mark
Mark a ForNode as representing an irregular loop of non-structural control flow edges.
Definition: stmt.h:1267
constexpr const char * fragment_layout
Mark the layout of a TensorCore fragment.
Definition: stmt.h:1132
constexpr const char * manifest_shared_memory_local_stage
Mark the local stage for the shared memory access should be added.
Definition: stmt.h:1169
constexpr const char * meta_schedule_vectorize
Mark auto-vectorize setting on the block.
Definition: stmt.h:1197
constexpr const char * explicit_write_region
Mark that a block has an explicitly specified write region. This is used to override the default writ...
Definition: stmt.h:1263
constexpr const int meta_schedule_cache_type_read
Definition: stmt.h:1232
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1090
constexpr const char * auto_copy
Mark auto copy for memhammer.
Definition: stmt.h:1238
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1033
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1279
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1054
constexpr const char * software_pipeline_async_stages
List stages in the software pipeline that should run asynchronously.
Definition: stmt.h:1163
constexpr const char * meta_schedule_cache_type
Mark a block as generated by cache_read or cache_write block. 0 means cache_read; 1 means cache_write...
Definition: stmt.h:1229
constexpr const char * meta_schedule_unroll_implicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1203
constexpr const char * meta_schedule_thread_extent_low_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1182
constexpr const char * warp_execution
Mark that a block is executed by a warp. This implies the extent of threadIdx.x is the warp size.
Definition: stmt.h:1250
constexpr const char * explicit_read_region
Mark that a block has an explicitly specified read region. This is used to override the default read ...
Definition: stmt.h:1258
constexpr const char * software_pipeline_order
Mark the order of a statement in the software pipeline.
Definition: stmt.h:1157
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:1029
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1031
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:1000
constexpr const char * meta_schedule_cooperative_fetch
Mark that the loop should be further split and bound to environment threads to enable cooperative fetching.
Definition: stmt.h:1179
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1062
constexpr const char * layout_free_buffers
Mark the buffers that are only accessed as constants, so their layout can be transformed.
Definition: stmt.h:1166
constexpr const char * buffer_data_alignment
Mark buffer initial addr alignment in bytes.
Definition: stmt.h:1071
constexpr const char * meta_schedule_inline_rule
Mark that a block is disallowed in auto inline.
Definition: stmt.h:1253
constexpr const char * kPersistentKernel
Mark the kernel as persistent.
Definition: stmt.h:1272
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1007
constexpr const char * meta_schedule_layout_rewrite_preproc
Mark that a block is a preprocessor block for layout rewrite.
Definition: stmt.h:1209
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1060
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:1019
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1120
constexpr const char * tensorized_nki_instruction
Definition: stmt.h:1264
constexpr const char * layout_transforms
Marks the layout transforms to be used for a tensor.
Definition: stmt.h:1042
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1023
constexpr const char * meta_schedule_auto_tensorize_init
Mark that the init statement of a block should be further rewritten using tensorization.
Definition: stmt.h:1213
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1021
constexpr const char * meta_schedule_tensor_core_enabled
Mark that tensor core is enabled in the PrimExpr.
Definition: stmt.h:1221
constexpr const char * compute_scope
Mark the scope where computation starts to happen. This can hint some code generators to create a new...
Definition: stmt.h:1005
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1101
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1027
constexpr const char * buffer_allocated_addr
Mark buffer allocated addr in bytes.
Definition: stmt.h:1073
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op When this scope occurs, the stmt....
Definition: stmt.h:1083
constexpr const char * meta_schedule_unroll_explicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1200
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1121
constexpr const char * meta_schedule_auto_tensorize
Mark that a block should be further rewritten using tensorization.
Definition: stmt.h:1206
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1086
constexpr const char * meta_schedule_parallel
Mark auto-parallel setting on the block.
Definition: stmt.h:1194
constexpr const char * device_type
The device type.
Definition: stmt.h:1011
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1050
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1122
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1088
constexpr const char * kVolatile
Annotation key on AllocBuffer marking the allocation as volatile.
Definition: stmt.h:1035
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1009
constexpr const char * extern_scope
Mark the scope as generated by extern primitive. Such scope can contain arbitrary ir program and we n...
Definition: stmt.h:1017
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1058
constexpr const char * pipeline_exec_scope
pipeline execution scope, implies the scope can be pipelined.
Definition: stmt.h:1096
constexpr const char * local_stage
Mark local stage constraint on data copy.
Definition: stmt.h:1241
constexpr const char * fragment_shape
Mark the shape of a TensorCore fragment.
Definition: stmt.h:1127
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1094
constexpr const int meta_schedule_cache_type_write
Definition: stmt.h:1235
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1092
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1025
constexpr const char * buffer_dim_align
Mark alignment of buffer dimension; stmt.node is Tensor, stmt.value is tvm_tuple(dim,...
Definition: stmt.h:1069
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1151
constexpr const char * meta_schedule_tiling_structure
Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.
Definition: stmt.h:1173
constexpr const char * meta_schedule_random_compute_producer
Mark the block whose producer needs to be applied by rule Random-Compute-Location.
Definition: stmt.h:1190
constexpr const char * meta_schedule_thread_extent_high_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1186
constexpr const char * vector_bytes
Mark vectorization length constraint on block.
Definition: stmt.h:1244
constexpr const char * require_block_var_bound_predicate
Mark that the block needs to add predicates for block var bounds during lowering.
Definition: stmt.h:1218
@ kUnrolled
The execution is unrolled.
Definition: var.h:233
@ kVectorized
The loop is vectorized.
Definition: var.h:237
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:123
ForKind
The kind of the loop.
Definition: stmt.h:557
@ kThreadBinding
The loop variable is bound to a thread in an environment. In the final stage of lowering,...
@ kParallel
Parallel execution on CPU.
@ kSerial
default semantics – serial execution.
const char * ForKind2String(ForKind t)
Definition: stmt.h:1296
const Op & min()
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
TIR expressions.