tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/tir/expr.h>
28 
29 #include <string>
30 #include <type_traits>
31 #include <utility>
32 #include <vector>
33 
34 namespace tvm {
35 namespace tir {
36 
38 class StmtNode : public Object {
39  public:
44  mutable Span span;
45 
46  StmtNode() = default;
47  explicit StmtNode(Span span) : span(span) {}
48 
50 
51  static constexpr const char* _type_key = "tir.Stmt";
52  static constexpr const bool _type_has_method_sequal_reduce = true;
53  static constexpr const bool _type_has_method_shash_reduce = true;
54  static constexpr const uint32_t _type_child_slots = 15;
56 };
57 
59 class Stmt : public ObjectRef {
60  public:
62 };
63 
67 class LetStmtNode : public StmtNode {
68  public:
75 
77  v->Visit("var", &var);
78  v->Visit("value", &value);
79  v->Visit("body", &body);
80  v->Visit("span", &span);
81  }
82 
83  bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
84  return equal.DefEqual(var, other->var) && equal(value, other->value) &&
85  equal(body, other->body);
86  }
87 
88  void SHashReduce(SHashReducer hash_reduce) const {
89  hash_reduce.DefHash(var);
90  hash_reduce(value);
91  hash_reduce(body);
92  }
93 
94  static constexpr const char* _type_key = "tir.LetStmt";
96 };
97 
102 class LetStmt : public Stmt {
103  public:
104  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
105 
108 };
109 
120 class AttrStmtNode : public StmtNode {
121  public:
130 
132  v->Visit("node", &node);
133  v->Visit("attr_key", &attr_key);
134  v->Visit("value", &value);
135  v->Visit("body", &body);
136  v->Visit("span", &span);
137  }
138 
139  bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
140  return equal(node, other->node) && equal(attr_key, other->attr_key) &&
141  equal(value, other->value) && equal(body, other->body);
142  }
143 
144  void SHashReduce(SHashReducer hash_reduce) const {
145  hash_reduce(node);
146  hash_reduce(attr_key);
147  hash_reduce(value);
148  hash_reduce(body);
149  }
150 
151  static constexpr const char* _type_key = "tir.AttrStmt";
153 };
154 
159 class AttrStmt : public Stmt {
160  public:
161  TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
162 
165 };
166 
170 class AssertStmtNode : public StmtNode {
171  public:
181 
183  v->Visit("condition", &condition);
184  v->Visit("message", &message);
185  v->Visit("body", &body);
186  v->Visit("span", &span);
187  }
188 
189  bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
190  return equal(condition, other->condition) && equal(message, other->message) &&
191  equal(body, other->body);
192  }
193 
194  void SHashReduce(SHashReducer hash_reduce) const {
195  hash_reduce(condition);
196  hash_reduce(message);
197  hash_reduce(body);
198  }
199 
200  static constexpr const char* _type_key = "tir.AssertStmt";
202 };
203 
208 class AssertStmt : public Stmt {
209  public:
210  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
211 
214 };
215 
226 class BufferStoreNode : public StmtNode {
227  public:
234 
236  v->Visit("buffer", &buffer);
237  v->Visit("value", &value);
238  v->Visit("indices", &indices);
239  v->Visit("span", &span);
240  }
241 
242  bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
243  return equal(buffer, other->buffer) && equal(value, other->value) &&
244  equal(indices, other->indices);
245  }
246 
247  void SHashReduce(SHashReducer hash_reduce) const {
248  hash_reduce(buffer);
249  hash_reduce(value);
250  hash_reduce(indices);
251  }
252 
253  static constexpr const char* _type_key = "tir.BufferStore";
255 };
256 
261 class BufferStore : public Stmt {
262  public:
263  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
264  Span span = Span());
265 
268 };
269 
281 class BufferRealizeNode : public StmtNode {
282  public:
291 
293  v->Visit("buffer", &buffer);
294  v->Visit("bounds", &bounds);
295  v->Visit("condition", &condition);
296  v->Visit("body", &body);
297  v->Visit("span", &span);
298  }
299 
300  bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
301  return equal(buffer, other->buffer) && equal(bounds, other->bounds) &&
302  equal(condition, other->condition) && equal(body, other->body);
303  }
304 
305  void SHashReduce(SHashReducer hash_reduce) const {
306  hash_reduce(buffer);
307  hash_reduce(bounds);
308  hash_reduce(condition);
309  hash_reduce(body);
310  }
311 
312  BufferRealizeNode() = default;
314  Span span = Span())
316 
317  static constexpr const char* _type_key = "tir.BufferRealize";
319 };
320 
325 class BufferRealize : public Stmt {
326  public:
327  TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
328  Span span = Span());
329 
332 };
333 
344 class ProducerStoreNode : public StmtNode {
345  public:
352 
354  v->Visit("producer", &producer);
355  v->Visit("value", &value);
356  v->Visit("indices", &indices);
357  v->Visit("span", &span);
358  }
359 
360  bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
361  return equal(producer, other->producer) && equal(value, other->value) &&
362  equal(indices, other->indices);
363  }
364 
365  void SHashReduce(SHashReducer hash_reduce) const {
366  hash_reduce(producer);
367  hash_reduce(value);
368  hash_reduce(indices);
369  }
370 
371  static constexpr const char* _type_key = "tir.ProducerStore";
373 };
374 
379 class ProducerStore : public Stmt {
380  public:
381  TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
382  Span span = Span());
383 
386 };
387 
400  public:
411 
413  v->Visit("producer", &producer);
414  v->Visit("bounds", &bounds);
415  v->Visit("condition", &condition);
416  v->Visit("body", &body);
417  v->Visit("storage_scope", &storage_scope);
418  v->Visit("span", &span);
419  }
420 
422  return equal(producer, other->producer) && equal(bounds, other->bounds) &&
423  equal(condition, other->condition) && equal(body, other->body) &&
425  }
426 
427  void SHashReduce(SHashReducer hash_reduce) const {
428  hash_reduce(producer);
429  hash_reduce(bounds);
430  hash_reduce(condition);
431  hash_reduce(body);
432  hash_reduce(storage_scope);
433  }
434 
435  static constexpr const char* _type_key = "tir.ProducerRealize";
437 };
438 
443 class ProducerRealize : public Stmt {
444  public:
445  TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
446  String storage_scope = "", Span span = Span());
447 
450 };
451 
455 class AllocateNode : public StmtNode {
456  public:
474 
476  v->Visit("buffer_var", &buffer_var);
477  v->Visit("dtype", &dtype);
478  v->Visit("extents", &extents);
479  v->Visit("condition", &condition);
480  v->Visit("body", &body);
481  v->Visit("annotations", &annotations);
482  v->Visit("span", &span);
483  }
484 
485  bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
486  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
487  equal(extents, other->extents) && equal(condition, other->condition) &&
488  equal(body, other->body) && equal(annotations, other->annotations);
489  }
490 
491  void SHashReduce(SHashReducer hash_reduce) const {
492  hash_reduce.DefHash(buffer_var);
493  hash_reduce(dtype);
494  hash_reduce(extents);
495  hash_reduce(condition);
496  hash_reduce(body);
497  hash_reduce(annotations);
498  }
499 
512  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
513 
514  static constexpr const char* _type_key = "tir.Allocate";
515  static constexpr const bool _type_has_method_sequal_reduce = true;
516  static constexpr const bool _type_has_method_shash_reduce = true;
518 };
519 
524 class Allocate : public Stmt {
525  public:
526  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
527  Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
528  Span span = Span());
529 
532 };
533 
537 class AllocateConstNode : public StmtNode {
538  public:
562 
564  v->Visit("buffer_var", &buffer_var);
565  v->Visit("data", &data);
566  v->Visit("irmod_storage_idx", &irmod_storage_idx);
567  v->Visit("dtype", &dtype);
568  v->Visit("extents", &extents);
569  v->Visit("body", &body);
570  v->Visit("annotations", &annotations);
571  v->Visit("span", &span);
572  }
573 
// Structural-equality hook for AllocateConstNode.
// buffer_var is compared through equal.DefEqual, while dtype, extents,
// data, body and annotations go through the plain equal(...) call.
// The && chain short-circuits at the first field that differs.
// NOTE(review): irmod_storage_idx is visited by this class's attribute
// visitor code but is not compared here - confirm this is intentional.
574  bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
575  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
576  equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) &&
577  equal(annotations, other->annotations);
578  }
579 
// Structural-hash hook for AllocateConstNode.
// Feeds the fields into hash_reduce in this order: buffer_var (through
// DefHash), dtype, extents, body, annotations, data.
// NOTE(review): irmod_storage_idx is not hashed, and data is fed last
// here although SEqualReduce compares it before body - confirm both are
// intentional.
580  void SHashReduce(SHashReducer hash_reduce) const {
581  hash_reduce.DefHash(buffer_var);
582  hash_reduce(dtype);
583  hash_reduce(extents);
584  hash_reduce(body);
585  hash_reduce(annotations);
586  hash_reduce(data);
587  }
588 
601  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
602 
603  static constexpr const char* _type_key = "tir.AllocateConst";
604  static constexpr const bool _type_has_method_sequal_reduce = true;
605  static constexpr const bool _type_has_method_shash_reduce = true;
607 };
608 
613 class AllocateConst : public Stmt {
614  public:
615  /* Constructor that creates an AllocateConst node holding constant data.
616  * Depending on the runtime type of `data_or_idx`, the resulting
617  * AllocateConstNode stores either `irmod_storage_idx` or `data`.
618  */
619  TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
620  ObjectRef data_or_idx, Stmt body,
622  Span span = Span());
625 };
626 
628 class DeclBufferNode : public StmtNode {
629  public:
634 
636  v->Visit("buffer", &buffer);
637  v->Visit("body", &body);
638  v->Visit("span", &span);
639  }
640 
641  bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const {
642  return equal(buffer, other->buffer) && equal(body, other->body);
643  }
644 
645  void SHashReduce(SHashReducer hash_reduce) const {
646  hash_reduce(buffer);
647  hash_reduce(body);
648  }
649 
650  static constexpr const char* _type_key = "tir.DeclBuffer";
652 };
653 
655 class DeclBuffer : public Stmt {
656  public:
657  TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
660 };
661 
666 class SeqStmtNode : public StmtNode {
667  public:
670 
672  size_t size() const { return seq.size(); }
676  Stmt operator[](size_t index) const { return seq[index]; }
677 
679  v->Visit("seq", &seq);
680  v->Visit("span", &span);
681  }
682 
683  bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
684  return equal(seq, other->seq);
685  }
686 
687  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
688 
689  static constexpr const char* _type_key = "tir.SeqStmt";
691 };
692 
699 class EvaluateNode : public StmtNode {
700  public:
703 
705  v->Visit("value", &value);
706  v->Visit("span", &span);
707  }
708 
709  bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
710  return equal(value, other->value);
711  }
712 
713  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
714 
715  static constexpr const char* _type_key = "tir.Evaluate";
717 };
718 
723 class Evaluate : public Stmt {
724  public:
725  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
726 
727  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
728 
731 };
732 
734 class SeqStmt : public Stmt {
735  public:
741  TVM_DLL explicit SeqStmt(Array<Stmt> seq, Span span = Span());
742 
744  size_t size() const { return operator->()->size(); }
748  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
// Flatten the given arguments into a single statement.
//
// Every argument is handed to a Flattener, which appends the statements it
// finds to the local array `seq`. The return value is then chosen as:
//  - Evaluate(0) when nothing was collected;
//  - the lone statement when exactly one was collected;
//  - the original SeqStmt when there was a single argument that is a
//    SeqStmt whose elements are same_as the collected ones, position by
//    position (so no new node is built);
//  - otherwise a new SeqStmt wrapping `seq`.
769  template <typename... Args>
770  static Stmt Flatten(Args&&... seq_args) {
771  Array<Stmt> seq;
772  runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
773 
774  if (seq.empty()) {
775  return Evaluate(0);
776  } else if (seq.size() == 1) {
777  return seq[0];
778  }
779 
780  // If the argument is a single SeqStmt argument with no
781  // flattening or unwrapping required, then we may
782  // return the SeqStmt as-is.
783  if constexpr (sizeof...(seq_args) == 1) {
784  if (auto opt = Flattener::AsSeqStmt(std::forward<Args>(seq_args)...)) {
785  SeqStmt original = opt.value();
// all_same is true only if the flattened result has the same length as
// the original sequence and every element is the same object reference.
786  bool all_same = [&]() {
787  if (original->seq.size() != seq.size()) {
788  return false;
789  }
790  for (size_t i = 0; i < seq.size(); i++) {
791  if (!original->seq[i].same_as(seq[i])) {
792  return false;
793  }
794  }
795  return true;
796  }();
797  if (all_same) {
798  return original;
799  }
800  }
801  }
802 
803  return SeqStmt(seq);
804  }
// Helper functor used by SeqStmt::Flatten. It appends statements to the
// Array<Stmt> whose pointer it was constructed with; it stores only that
// pointer.
806  class Flattener {
807  public:
808  explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}
809 
// Try to view `t` as a SeqStmt.
//  - T is exactly SeqStmt: returns t.
//  - T is not a base class of SeqStmt: returns NullOpt (decided at
//    compile time).
//  - otherwise: returns the SeqStmt if t.as<SeqStmtNode>() succeeds,
//    else NullOpt.
810  template <typename T>
811  static Optional<SeqStmt> AsSeqStmt(const T& t) {
812  if constexpr (std::is_same_v<T, SeqStmt>) {
813  return t;
814  } else if constexpr (!std::is_base_of_v<T, SeqStmt>) {
815  return NullOpt;
816  } else if (auto* ptr = t.template as<SeqStmtNode>()) {
817  return GetRef<SeqStmt>(ptr);
818  } else {
819  return NullOpt;
820  }
821  }
822 
// Process one argument, appending zero or more statements to *seq_.
// The checks run in this order:
//  1. an ObjectRef-derived argument that is not defined() is dropped;
//  2. a SeqStmt (by static type, or by as<SeqStmtNode>() when T is a base
//     class of SeqStmt) has its `seq` processed recursively;
//  3. an EvaluateNode whose value is an IntImmNode equal to 0 is dropped;
//  4. any other Stmt-derived argument is appended;
//  5. anything else is iterated and each element processed recursively.
// The index parameter `i` is not read in this body; recursive calls pass 0.
// NOTE(review): `i` is presumably the position supplied by
// runtime::detail::for_each - confirm against that helper.
823  template <typename T>
824  void operator()(size_t i, const T& stmt_or_seq) const {
825  if constexpr (std::is_base_of_v<ObjectRef, T>) {
826  // Early bail-out, applicable to any ObjectRef
827  if (!stmt_or_seq.defined()) {
828  return;
829  }
830  }
831 
832  if constexpr (std::is_same_v<T, SeqStmt>) {
833  // Static type-checking for a SeqStmt that could be flattened.
834  (*this)(0, stmt_or_seq->seq);
835  return;
836  }
837 
838  if constexpr (std::is_base_of_v<T, SeqStmt>) {
839  // Dynamic type-checking for a SeqStmt that could be
840  // flattened.
841  if (auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
842  operator()(0, op->seq);
843  return;
844  }
845  }
846 
847  if constexpr (std::is_base_of_v<T, Evaluate>) {
848  // Evaluate(0) is used to represent a no-op, and may be
849  // generated by previous calls to SeqStmt::Flatten(). These
850  // should be removed to ensure that Flatten(a+b) is equivalent
851  // to Flatten(Flatten(a), Flatten(b)).
852  if (auto* op = stmt_or_seq.template as<EvaluateNode>()) {
853  if (auto* as_int = op->value.template as<IntImmNode>(); as_int && as_int->value == 0) {
854  return;
855  }
856  }
857  }
858 
859  if constexpr (std::is_base_of_v<Stmt, T>) {
860  // Any other Stmt type just gets appended.
861  seq_->push_back(stmt_or_seq);
862  } else {
863  // Anything else is treated as an iterable of Stmt.
864  for (auto v : stmt_or_seq) {
865  this->operator()(0, v);
866  }
867  }
868  }
869 
870  private:
// Destination array; supplied by the caller through the constructor.
871  Array<Stmt>* seq_;
872  };
873 
876 };
877 
881 class IfThenElseNode : public StmtNode {
882  public:
889 
891  v->Visit("condition", &condition);
892  v->Visit("then_case", &then_case);
893  v->Visit("else_case", &else_case);
894  v->Visit("span", &span);
895  }
896 
897  bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
898  return equal(condition, other->condition) && equal(then_case, other->then_case) &&
899  equal(else_case, other->else_case);
900  }
901 
902  void SHashReduce(SHashReducer hash_reduce) const {
903  hash_reduce(condition);
904  hash_reduce(then_case);
905  hash_reduce(else_case);
906  }
907 
908  static constexpr const char* _type_key = "tir.IfThenElse";
910 };
911 
916 class IfThenElse : public Stmt {
917  public:
918  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case = NullOpt,
919  Span span = Span());
920 
923 };
924 
// Kind of a For loop: the type of the `kind` argument of the For
// constructor and of the "kind" attribute of ForNode.
// Scoped enum backed by int with explicit values 0 through 4.
// ForKind2String (defined later in this file) converts a kind to text.
932 enum class ForKind : int {
934  kSerial = 0,
936  kParallel = 1,
941  kVectorized = 2,
943  kUnrolled = 3,
// NOTE(review): presumably used together with the For constructor's
// `thread_binding` argument - confirm.
950  kThreadBinding = 4
951 };
952 
963 class ForNode : public StmtNode {
964  public:
989 
991  v->Visit("loop_var", &loop_var);
992  v->Visit("min", &min);
993  v->Visit("extent", &extent);
994  v->Visit("kind", &kind);
995  v->Visit("body", &body);
996  v->Visit("thread_binding", &thread_binding);
997  v->Visit("annotations", &annotations);
998  v->Visit("span", &span);
999  }
1000 
1001  bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
1002  return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
1003  equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
1005  }
1006 
1007  void SHashReduce(SHashReducer hash_reduce) const {
1008  hash_reduce.DefHash(loop_var);
1009  hash_reduce(min);
1010  hash_reduce(extent);
1011  hash_reduce(kind);
1012  hash_reduce(body);
1013  hash_reduce(thread_binding);
1014  hash_reduce(annotations);
1015  }
1016 
1017  static constexpr const char* _type_key = "tir.For";
1019 };
1020 
1025 class For : public Stmt {
1026  public:
1027  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
1028  Optional<IterVar> thread_binding = NullOpt,
1029  Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());
1030 
1033 };
1034 
1045 class WhileNode : public StmtNode {
1046  public:
1051 
1053  v->Visit("condition", &condition);
1054  v->Visit("body", &body);
1055  v->Visit("span", &span);
1056  }
1057 
1058  bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
1059  return equal(condition, other->condition) && equal(body, other->body);
1060  }
1061 
1062  void SHashReduce(SHashReducer hash_reduce) const {
1063  hash_reduce(condition);
1064  hash_reduce(body);
1065  }
1066 
1067  static constexpr const char* _type_key = "tir.While";
1069 };
1070 
1075 class While : public Stmt {
1076  public:
1077  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
1078 
1081 };
1082 
1086 class PrefetchNode : public StmtNode {
1087  public:
1092 
1094  v->Visit("buffer", &buffer);
1095  v->Visit("bounds", &bounds);
1096  v->Visit("span", &span);
1097  }
1098 
1099  bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
1100  return equal(buffer, other->buffer) && equal(bounds, other->bounds);
1101  }
1102 
1103  void SHashReduce(SHashReducer hash_reduce) const {
1104  hash_reduce(buffer);
1105  hash_reduce(bounds);
1106  }
1107 
1108  PrefetchNode() = default;
1111 
1112  static constexpr const char* _type_key = "tir.Prefetch";
1114 };
1115 
1120 class Prefetch : public Stmt {
1121  public:
1122  TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
1123 
1126 };
1127 
1131 class BufferRegionNode : public Object {
1132  public:
1137 
1139  v->Visit("buffer", &buffer);
1140  v->Visit("region", &region);
1141  }
1142 
1143  bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
1144  return equal(buffer, other->buffer) && equal(region, other->region);
1145  }
1146 
1147  void SHashReduce(SHashReducer hash_reduce) const {
1148  hash_reduce(buffer);
1149  hash_reduce(region);
1150  }
1151 
1152  static constexpr const char* _type_key = "tir.BufferRegion";
1153  static constexpr const bool _type_has_method_sequal_reduce = true;
1154  static constexpr const bool _type_has_method_shash_reduce = true;
1156 };
1157 
1162 class BufferRegion : public ObjectRef {
1163  public:
1164  TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
1165 
1171  TVM_DLL static BufferRegion FullRegion(Buffer buffer);
1172 
1179  TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
1180 
1183 };
1184 
1195  public:
1200 
1202  v->Visit("buffer", &buffer);
1203  v->Visit("source", &source);
1204  }
1205 
1207  return equal(buffer, other->buffer) && equal(source, other->source);
1208  }
1209 
1210  void SHashReduce(SHashReducer hash_reduce) const {
1211  hash_reduce(buffer);
1212  hash_reduce(source);
1213  }
1214 
1215  static constexpr const char* _type_key = "tir.MatchBufferRegion";
1216  static constexpr const bool _type_has_method_sequal_reduce = true;
1217  static constexpr const bool _type_has_method_shash_reduce = true;
1219 };
1220 
1226  public:
1227  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
1228 
1231 };
1232 
1254 class BlockNode : public StmtNode {
1255  public:
1280 
1282  v->Visit("iter_vars", &iter_vars);
1283  v->Visit("reads", &reads);
1284  v->Visit("writes", &writes);
1285  v->Visit("name_hint", &name_hint);
1286  v->Visit("body", &body);
1287  v->Visit("init", &init);
1288  v->Visit("alloc_buffers", &alloc_buffers);
1289  v->Visit("match_buffers", &match_buffers);
1290  v->Visit("annotations", &annotations);
1291  }
1292 
// Structural-equality hook for BlockNode.
// iter_vars is compared through equal.DefEqual; alloc_buffers,
// match_buffers, reads, writes, body, init and annotations go through the
// plain equal(...) call, in that order. The && chain short-circuits at the
// first field that differs.
// NOTE(review): name_hint is visited as an attribute of this class but is
// not compared here - presumably a deliberate choice; confirm.
1293  bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
1294  // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
1295  return equal.DefEqual(iter_vars, other->iter_vars) &&
1296  equal(alloc_buffers, other->alloc_buffers) &&
1297  equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
1298  equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
1299  equal(annotations, other->annotations);
1300  }
1301 
// Structural-hash hook for BlockNode.
// Feeds the fields into hash_reduce in the same order SEqualReduce compares
// them: iter_vars (through DefHash), alloc_buffers, match_buffers, reads,
// writes, body, init, annotations.
// name_hint is not hashed, matching its absence from SEqualReduce.
1302  void SHashReduce(SHashReducer hash_reduce) const {
1303  hash_reduce.DefHash(iter_vars);
1304  hash_reduce(alloc_buffers);
1305  hash_reduce(match_buffers);
1306  hash_reduce(reads);
1307  hash_reduce(writes);
1308  hash_reduce(body);
1309  hash_reduce(init);
1310  hash_reduce(annotations);
1311  }
1312 
1313  static constexpr const char* _type_key = "tir.Block";
1315 };
1316 
1321 class Block : public Stmt {
1322  public:
1323  TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
1324  Array<BufferRegion> writes, String name_hint, Stmt body,
1325  Optional<Stmt> init = NullOpt,
1326  Array<Buffer> alloc_buffers = Array<Buffer>(),
1329  Span span = Span());
1330 
1333 };
1334 
1338 class BlockRealizeNode : public StmtNode {
1339  public:
1349 
1351  v->Visit("iter_values", &iter_values);
1352  v->Visit("predicate", &predicate);
1353  v->Visit("block", &block);
1354  }
1355 
1356  bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
1357  return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
1358  equal(block, other->block);
1359  }
1360 
1361  void SHashReduce(SHashReducer hash_reduce) const {
1362  hash_reduce(iter_values);
1363  hash_reduce(predicate);
1364  hash_reduce(block);
1365  }
1366 
1367  static constexpr const char* _type_key = "tir.BlockRealize";
1369 };
1370 
1375 class BlockRealize : public Stmt {
1376  public:
1377  TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
1378  Span span = Span());
1379 
1382 };
1383 
1385 namespace attr {
1386 // The above attr does not pass to ir stage.
1388 constexpr const char* thread_extent = "thread_extent";
1390 constexpr const char* virtual_thread = "virtual_thread";
1392 constexpr const char* coproc_scope = "coproc_scope";
1397 constexpr const char* coproc_uop_scope = "coproc_uop_scope";
1399 constexpr const char* volatile_scope = "volatile_scope";
1405 constexpr const char* extern_scope = "extern_scope";
1410 constexpr const char* compute_scope = "compute_scope";
1412 constexpr const char* storage_alignment = "storage_alignment";
1414 constexpr const char* realize_scope = "realize_scope";
1416 constexpr const char* device_id = "device_id";
1418 constexpr const char* device_type = "device_type";
1420 constexpr const char* loop_scope = "loop_scope";
1422 constexpr const char* reduce_scope = "reduce_scope";
1424 constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
1426 constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
1428 constexpr const char* pragma_scope_prefix = "pragma_";
1430 constexpr const char* pragma_import_c = "pragma_import_c";
1432 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
1434 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
1439 constexpr const char* prefetch_scope = "prefetch_scope";
1446 constexpr const char* layout_transforms = "layout_transforms";
1454 constexpr const char* axis_separators = "axis_separators";
1458 constexpr const char* double_buffer_scope = "double_buffer_scope";
1462 constexpr const char* double_buffer_write = "double_buffer_write";
1464 constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
1466 constexpr const char* scan_update_scope = "scan_update_scope";
1468 constexpr const char* scan_init_scope = "scan_init_scope";
1475 constexpr const char* buffer_dim_align = "buffer_dim_align";
1477 constexpr const char* buffer_bound = "buffer_bound";
1487 constexpr const char* buffer_bind_scope = "buffer_bind_scope";
1488 // Pipeline related attributes
1490 constexpr const char* channel_read_scope = "channel_read_scope";
1492 constexpr const char* channel_read_advance = "channel_read_advance";
1494 constexpr const char* channel_write_scope = "channel_write_scope";
1496 constexpr const char* channel_write_advance = "channel_write_advance";
1498 constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
1500 constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
1501 
1505 constexpr const char* device_scope = "device_scope";
1506 
1510 constexpr const char* async_scope = "async_scope";
1511 
1529 constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
1530 constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
1531 constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
1532 
1536 constexpr const char* fragment_shape = "fragment_shape";
1537 
1541 constexpr const char* fragment_layout = "fragment_layout";
1542 
1546 constexpr const char* hand_threaded = "hand_threaded";
1547 
1555 constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
1556 
1560 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
1561 
1563 constexpr const char* software_pipeline_stage = "software_pipeline_stage";
1564 
1566 constexpr const char* software_pipeline_order = "software_pipeline_order";
1567 
1572 constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";
1573 
1575 constexpr const char* layout_free_buffers = "layout_free_buffers";
1576 
1578 constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";
1579 
1581 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
1582 
1587 constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";
1588 
1591  "meta_schedule.thread_extent_low_inclusive";
1592 
1595  "meta_schedule.thread_extent_high_inclusive";
1596 
1599  "meta_schedule.random_compute_producer";
1600 
1602 constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
1603 
1605 constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
1606 
1608 constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
1609 
1611 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
1612 
1614 constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1615 
1617 constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
1621 constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
1622 
1626 constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";
1627 
1629 constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";
1630 
1637 constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
1638 
1640 constexpr const int meta_schedule_cache_type_read = 0;
1641 
1643 constexpr const int meta_schedule_cache_type_write = 1;
1644 
1646 constexpr const char* auto_copy = "auto_copy";
1647 
1649 constexpr const char* local_stage = "local_stage";
1650 
1652 constexpr const char* vector_bytes = "vector_bytes";
1653 
1658 constexpr const char* warp_execution = "warp_execution";
1659 
1661 constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
1662 
/*!
 * \brief Check whether an attribute key starts with the "pragma_" prefix.
 * \param attr_key The attribute key to test.
 * \return true if attr_key begins with "pragma_", false otherwise
 *         (including when attr_key is shorter than the prefix).
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // Derive the compared length from the literal so the two cannot drift apart.
  constexpr const char kPrefix[] = "pragma_";
  constexpr auto kPrefixLen = sizeof(kPrefix) - 1;  // exclude the trailing '\0'
  return attr_key.compare(0, kPrefixLen, kPrefix) == 0;
}
1671 
1672 } // namespace attr
1679 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1680 
1681 // Overload stream output (operator<<) for ForKind.
1682 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1683 
1684 // inline implementations
1685 inline const char* ForKind2String(ForKind t) {
1686  switch (t) {
1687  case ForKind::kSerial:
1688  return "serial";
1689  case ForKind::kParallel:
1690  return "parallel";
1691  case ForKind::kVectorized:
1692  return "vectorized";
1693  case ForKind::kUnrolled:
1694  return "unroll";
1696  return "thread_binding";
1697  }
1698  LOG(FATAL) << "Unknown ForKind" << t;
1699 }
1700 
1701 } // namespace tir
1702 } // namespace tvm
1703 #endif // TVM_TIR_STMT_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:114
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:187
Definition: source_map.h:120
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool empty() const
Definition: array.h:432
size_t size() const
Definition: array.h:420
Runtime primitive data type.
Definition: data_type.h:42
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:517
const Object * operator->() const
Definition: object.h:554
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:528
base class of all object containers.
Definition: object.h:169
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Allocate a buffer that can be used in body.
Definition: stmt.h:537
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:561
Optional< Integer > irmod_storage_idx
If the PrimFunc containing the Stmt is added to IRModule, this is an optional index to indicate the i...
Definition: stmt.h:548
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:552
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:605
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:580
static int64_t ConstantAllocationSize(const Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode)
Stmt body
The body to be executed.
Definition: stmt.h:554
DataType dtype
The type of the buffer.
Definition: stmt.h:550
Var buffer_var
The buffer variable.
Definition: stmt.h:540
static constexpr const char * _type_key
Definition: stmt.h:603
bool SEqualReduce(const AllocateConstNode *other, SEqualReducer equal) const
Definition: stmt.h:574
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:594
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:604
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:563
Optional< runtime::NDArray > data
The optional data associated to the constant.
Definition: stmt.h:543
Managed reference to AllocateConstNode.
Definition: stmt.h:613
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode)
AllocateConst(Var buffer_var, DataType dtype, Array< PrimExpr > extents, ObjectRef data_or_idx, Stmt body, Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode)
Allocate a buffer that can be used in body.
Definition: stmt.h:455
bool SEqualReduce(const AllocateNode *other, SEqualReducer equal) const
Definition: stmt.h:485
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:462
static int64_t ConstantAllocationSize(const Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:473
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:464
Stmt body
The body to be executed.
Definition: stmt.h:466
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode)
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:505
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:516
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:515
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:475
DataType dtype
The type of the buffer.
Definition: stmt.h:460
static constexpr const char * _type_key
Definition: stmt.h:514
Var buffer_var
The buffer variable.
Definition: stmt.h:458
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:491
Managed reference to AllocateNode.
Definition: stmt.h:524
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode)
Allocate(Var buffer_var, DataType dtype, Array< PrimExpr > extents, PrimExpr condition, Stmt body, Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode)
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:170
PrimExpr condition
Condition to be checked.
Definition: stmt.h:173
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:175
bool SEqualReduce(const AssertStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:189
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:200
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:194
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:182
Stmt body
Body in which this assertion holds true. Will be executed after the assertion.
Definition: stmt.h:180
Managed reference to AssertStmtNode.
Definition: stmt.h:208
TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode)
AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode)
Define certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:120
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:127
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:144
Stmt body
The body statement to be executed.
Definition: stmt.h:129
String attr_key
The type key of the attribute.
Definition: stmt.h:125
ObjectRef node
The node that this attribute is about.
Definition: stmt.h:123
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:131
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode)
bool SEqualReduce(const AttrStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:139
static constexpr const char * _type_key
Definition: stmt.h:151
Managed reference to AttrStmtNode.
Definition: stmt.h:159
TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode)
AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode)
A block is a basic schedule unit in TIR.
Definition: stmt.h:1254
static constexpr const char * _type_key
Definition: stmt.h:1313
Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:1259
Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:1277
Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:1257
Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:1261
Map< String, ObjectRef > annotations
The annotation of the block.
Definition: stmt.h:1279
Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block....
Definition: stmt.h:1273
String name_hint
The name_hint of the block.
Definition: stmt.h:1263
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1281
Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:1275
Stmt body
The body of the block.
Definition: stmt.h:1265
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1302
bool SEqualReduce(const BlockNode *other, SEqualReducer equal) const
Definition: stmt.h:1293
TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode)
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:1338
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1361
bool SEqualReduce(const BlockRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:1356
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:1367
Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:1341
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true.
Definition: stmt.h:1346
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1350
Block block
The block to be realized.
Definition: stmt.h:1348
Managed reference to BlockRealizeNode.
Definition: stmt.h:1375
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode)
BlockRealize(Array< PrimExpr > iter_values, PrimExpr predicate, Block block, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode)
Managed reference to BlockNode.
Definition: stmt.h:1321
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode)
Block(Array< IterVar > iter_vars, Array< BufferRegion > reads, Array< BufferRegion > writes, String name_hint, Stmt body, Optional< Stmt > init=NullOpt, Array< Buffer > alloc_buffers=Array< Buffer >(), Array< MatchBufferRegion > match_buffers=Array< MatchBufferRegion >(), Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode)
Annotate the region where the buffer needs to be read and written in the body. We only need to allocate ...
Definition: stmt.h:281
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:292
Buffer buffer
The buffer variable.
Definition: stmt.h:284
bool SEqualReduce(const BufferRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:300
BufferRealizeNode(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
Definition: stmt.h:313
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:288
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:305
Stmt body
The body of realization.
Definition: stmt.h:290
Array< Range > bounds
Bounds to be realized.
Definition: stmt.h:286
static constexpr const char * _type_key
Definition: stmt.h:317
Managed reference to BufferRealizeNode.
Definition: stmt.h:325
BufferRealize(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode)
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:1131
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:1134
bool SEqualReduce(const BufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1143
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object)
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:1154
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1138
static constexpr const char * _type_key
Definition: stmt.h:1152
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:1153
Array< Range > region
The region array of the buffer region.
Definition: stmt.h:1136
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1147
Managed reference to BufferRegionNode.
Definition: stmt.h:1162
static BufferRegion FullRegion(Buffer buffer)
Create a BufferRegion which is full region of the given buffer.
static BufferRegion FromPoint(Buffer buffer, Array< PrimExpr > indices)
Create a BufferRegion which is a single point of the given buffer.
BufferRegion(Buffer buffer, Array< Range > region)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode)
TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode)
Store value to the high dimension buffer.
Definition: stmt.h:226
Buffer buffer
The buffer variable.
Definition: stmt.h:229
Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:233
PrimExpr value
The value to be stored.
Definition: stmt.h:231
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:247
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:235
static constexpr const char * _type_key
Definition: stmt.h:253
bool SEqualReduce(const BufferStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:242
Managed reference to BufferStoreNode.
Definition: stmt.h:261
BufferStore(Buffer buffer, PrimExpr value, Array< PrimExpr > indices, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode)
Buffer is a symbolic n-d array structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:162
Managed reference to DataProducerNode.
Definition: buffer.h:295
Declare a buffer that can be used in the body.
Definition: stmt.h:628
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:645
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:635
static constexpr const char * _type_key
Definition: stmt.h:650
Buffer buffer
The buffer being declared.
Definition: stmt.h:631
TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode)
Stmt body
The body to be executed.
Definition: stmt.h:633
bool SEqualReduce(const DeclBufferNode *other, SEqualReducer equal) const
Definition: stmt.h:641
Managed reference to DeclBufferNode.
Definition: stmt.h:655
DeclBuffer(Buffer buffer, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode)
TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode)
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:699
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:713
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:704
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:715
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:702
bool SEqualReduce(const EvaluateNode *other, SEqualReducer equal) const
Definition: stmt.h:709
Managed reference to EvaluateNode.
Definition: stmt.h:723
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode)
Evaluate(int value, Span span=Span())
Definition: stmt.h:727
TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode)
Evaluate(PrimExpr value, Span span=Span())
A for loop, with possible type annotations.
Definition: stmt.h:963
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:990
Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to.
Definition: stmt.h:979
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:968
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1007
ForKind kind
The kind of the for loop.
Definition: stmt.h:972
static constexpr const char * _type_key
Definition: stmt.h:1017
Var loop_var
The loop variable.
Definition: stmt.h:966
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:970
Stmt body
The body of the for loop.
Definition: stmt.h:974
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode)
bool SEqualReduce(const ForNode *other, SEqualReducer equal) const
Definition: stmt.h:1001
Map< String, ObjectRef > annotations
Additional annotations about the loop.
Definition: stmt.h:988
Managed reference to ForNode.
Definition: stmt.h:1025
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode)
For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional< IterVar > thread_binding=NullOpt, Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode)
IfThenElse statement.
Definition: stmt.h:881
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:890
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:902
PrimExpr condition
The condition.
Definition: stmt.h:884
Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:888
bool SEqualReduce(const IfThenElseNode *other, SEqualReducer equal) const
Definition: stmt.h:897
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:908
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:886
Managed reference to IfThenElseNode.
Definition: stmt.h:916
IfThenElse(PrimExpr condition, Stmt then_case, Optional< Stmt > else_case=NullOpt, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode)
TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode)
Let binding, bind var to value, then run body.
Definition: stmt.h:67
PrimExpr value
The value to be bound.
Definition: stmt.h:72
static constexpr const char * _type_key
Definition: stmt.h:94
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:76
bool SEqualReduce(const LetStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:83
Stmt body
The body block.
Definition: stmt.h:74
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode)
Var var
The variable.
Definition: stmt.h:70
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:88
Managed reference to LetStmtNode.
Definition: stmt.h:102
TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode)
LetStmt(Var var, PrimExpr value, Stmt body, Span span=Span())
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:1194
TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object)
static constexpr const char * _type_key
Definition: stmt.h:1215
Buffer buffer
The target buffer.
Definition: stmt.h:1197
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:1217
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1201
bool SEqualReduce(const MatchBufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1206
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:1216
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1210
BufferRegion source
The source buffer region.
Definition: stmt.h:1199
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:1225
MatchBufferRegion(Buffer buffer, BufferRegion source)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode)
TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode)
A prefetch hint for a buffer.
Definition: stmt.h:1086
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode)
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1093
static constexpr const char * _type_key
Definition: stmt.h:1112
PrefetchNode(Buffer buffer, Array< Range > bounds, Span span=Span())
Definition: stmt.h:1109
Array< Range > bounds
Bounds to be prefetched.
Definition: stmt.h:1091
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1103
Buffer buffer
The buffer to be prefetched.
Definition: stmt.h:1089
bool SEqualReduce(const PrefetchNode *other, SEqualReducer equal) const
Definition: stmt.h:1099
Managed reference to PrefetchNode.
Definition: stmt.h:1120
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode)
Prefetch(Buffer buffer, Array< Range > bounds, Span span=Span())
Annotate the bounds where the data produced by the producer needs to be written and read in body....
Definition: stmt.h:399
Stmt body
The body of realization.
Definition: stmt.h:408
DataProducer producer
The producer that produces the data.
Definition: stmt.h:402
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:406
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:427
static constexpr const char * _type_key
Definition: stmt.h:435
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:412
String storage_scope
The storage scope associated with this realization.
Definition: stmt.h:410
bool SEqualReduce(const ProducerRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:421
Region bounds
Bounds to be realized.
Definition: stmt.h:404
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode)
Managed reference to ProducerRealizeNode.
Definition: stmt.h:443
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode)
TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode)
ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope="", Span span=Span())
Store value into a multi-dimensional array that will be read by the consumer of the producer.
Definition: stmt.h:344
PrimExpr value
The value to be stored.
Definition: stmt.h:349
DataProducer producer
The producer to store the results into.
Definition: stmt.h:347
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:365
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode)
Array< PrimExpr > indices
The index arguments of the function.
Definition: stmt.h:351
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:353
bool SEqualReduce(const ProducerStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:360
static constexpr const char * _type_key
Definition: stmt.h:371
Managed reference to ProducerStoreNode.
Definition: stmt.h:379
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode)
TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode)
ProducerStore(DataProducer producer, PrimExpr value, Array< PrimExpr > indices, Span span=Span())
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:666
size_t size() const
Definition: stmt.h:672
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:669
static constexpr const char * _type_key
Definition: stmt.h:689
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:678
bool SEqualReduce(const SeqStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:683
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:687
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:676
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:806
Flattener(Array< Stmt > *seq)
Definition: stmt.h:808
static Optional< SeqStmt > AsSeqStmt(const T &t)
Definition: stmt.h:811
void operator()(size_t i, const T &stmt_or_seq) const
Definition: stmt.h:824
Sequence statement.
Definition: stmt.h:734
TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode)
size_t size() const
Definition: stmt.h:744
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode)
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:748
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:770
SeqStmt(Array< Stmt > seq, Span span=Span())
Construct SeqStmt.
Base node of all statements.
Definition: stmt.h:38
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:54
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:52
StmtNode(Span span)
Definition: stmt.h:47
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:44
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object)
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:53
static constexpr const char * _type_key
Definition: stmt.h:51
Container of all statements.
Definition: stmt.h:59
TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode)
a named variable in TIR
Definition: var.h:88
A While loop.
Definition: stmt.h:1045
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1052
static constexpr const char * _type_key
Definition: stmt.h:1067
TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode)
Stmt body
The body of the while loop.
Definition: stmt.h:1050
PrimExpr condition
The termination condition.
Definition: stmt.h:1048
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1062
bool SEqualReduce(const WhileNode *other, SEqualReducer equal) const
Definition: stmt.h:1058
Managed reference to WhileNode.
Definition: stmt.h:1075
While(PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode)
tvm::Span Span
Definition: base.h:65
void Evaluate(PrimExpr value)
Evaluate the input expression.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * compute_scope
Mark the scope as where computation starts to happen. This can hint some code generator to create a new ...
Definition: stmt.h:1410
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op. When this scope occurs, the stmt....
Definition: stmt.h:1487
constexpr const char * software_pipeline_order
Mark the order of a statement in the software pipeline.
Definition: stmt.h:1566
constexpr const char * meta_schedule_unroll_explicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1608
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1546
constexpr const char * buffer_dim_align
Mark alignment of buffer dimension. stmt.node is Tensor; stmt.value is tvm_tuple(dim,...
Definition: stmt.h:1475
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1492
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:1399
constexpr const char * meta_schedule_auto_tensorize
Mark that a block should be further rewritten using tensorization.
Definition: stmt.h:1614
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1498
constexpr const char * manifest_shared_memory_local_stage
Mark the local stage for the shared memory access should be added.
Definition: stmt.h:1578
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1430
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:1426
constexpr const char * meta_schedule_tiling_structure
Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.
Definition: stmt.h:1581
constexpr const char * software_pipeline_stage
Mark the stage of a statement in the software pipeline.
Definition: stmt.h:1563
constexpr const char * meta_schedule_random_compute_producer
Mark the block whose producer needs to be applied by rule Random-Compute-Location.
Definition: stmt.h:1598
constexpr const char * warp_execution
Mark that a block is executed by a warp. This implies the extent of threadIdx.x is warp size.
Definition: stmt.h:1658
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1505
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1668
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1388
constexpr const char * script_parsing_detect_access
Mark whether the script-completer needs to fill in missing access region during script parsing.
Definition: stmt.h:1555
constexpr const char * meta_schedule_tensor_core_enabled
Mark that tensor core is enabled in the PrimExpr.
Definition: stmt.h:1629
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:1390
constexpr const char * extern_scope
Mark the scope as generated by extern primitive. Such scope can contain arbitrary IR program and we n...
Definition: stmt.h:1405
constexpr const char * meta_schedule_cache_type
Mark a block as generated by cache_read or cache_write block. 0 means cache_read; 1 means cache_write...
Definition: stmt.h:1637
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:1422
constexpr const char * meta_schedule_unroll_implicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1611
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1494
constexpr const char * auto_copy
Mark auto copy for memhammer.
Definition: stmt.h:1646
constexpr const char * rolling_buffer_scope
Mark realization for rolling buffer optimization.
Definition: stmt.h:1464
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1529
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1416
constexpr const char * meta_schedule_thread_extent_low_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1590
constexpr const char * meta_schedule_layout_rewrite_preproc
Mark that a block is a preprocessor block for layout rewrite.
Definition: stmt.h:1617
constexpr const char * vector_bytes
Mark vectorization length constraint on block.
Definition: stmt.h:1652
constexpr const char * device_type
The device type.
Definition: stmt.h:1418
constexpr const int meta_schedule_cache_type_read
Definition: stmt.h:1640
constexpr const char * software_pipeline_async_stages
List stages in the software pipeline that should run asynchronously.
Definition: stmt.h:1572
constexpr const char * meta_schedule_auto_tensorize_init
Mark that the init statement of a block should be further rewritten using tensorization.
Definition: stmt.h:1621
constexpr const char * meta_schedule_cooperative_fetch
Mark that the loop should be further split and bound to environment threads to enable cooperative fetc...
Definition: stmt.h:1587
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1466
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:1424
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:1420
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1458
constexpr const char * fragment_shape
Mark the shape of TensorCore fragment.
Definition: stmt.h:1536
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1434
constexpr const char * fragment_layout
Mark the layout of TensorCore fragment.
Definition: stmt.h:1541
constexpr const char * layout_free_buffers
Mark the buffers which are const-accessed and whose layout can be transformed.
Definition: stmt.h:1575
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1530
constexpr const int meta_schedule_cache_type_write
Definition: stmt.h:1643
constexpr const char * coproc_scope
Mark region is processed by a co-processor.
Definition: stmt.h:1392
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:1477
constexpr const char * prefetch_scope
Mark of prefetch scope, value=offset, run prefetch of Tensor on the current loop scope.
Definition: stmt.h:1439
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1531
constexpr const char * layout_transforms
Marks the layout transforms to be used for a tensor.
Definition: stmt.h:1446
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1454
constexpr const char * realize_scope
Mark storage scope of realization.
Definition: stmt.h:1414
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1490
constexpr const char * meta_schedule_thread_extent_high_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1594
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1496
constexpr const char * coproc_uop_scope
Mark region creates coprocessor micro ops, can be reused if corresponding variable is independent.
Definition: stmt.h:1397
constexpr const char * meta_schedule_inline_rule
Mark that a block is disallowed in auto inline.
Definition: stmt.h:1661
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1560
constexpr const char * pipeline_exec_scope
pipeline execution scope, implies the scope can be pipelined.
Definition: stmt.h:1500
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1432
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1428
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1468
constexpr const char * require_block_var_bound_predicate
Mark that the block needs to add a predicate for block var bounds during lowering.
Definition: stmt.h:1626
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1412
constexpr const char * local_stage
Mark local stage constraint on data copy.
Definition: stmt.h:1649
constexpr const char * meta_schedule_vectorize
Mark auto-vectorize setting on the block.
Definition: stmt.h:1605
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1462
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1510
constexpr const char * meta_schedule_parallel
Mark auto-parallel setting on the block.
Definition: stmt.h:1602
const char * ForKind2String(ForKind t)
Definition: stmt.h:1685
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:123
ForKind
The kind of the loop.
Definition: stmt.h:932
@ kThreadBinding
The loop variable is bound to a thread in an environment. In the final stage of lowering,...
@ kParallel
Parallel execution on CPU.
@ kSerial
default semantics – serial execution.
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
@ kVectorized
The loop is vectorized.
Definition: var.h:243
@ kUnrolled
The execution is unrolled.
Definition: var.h:239
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
TIR expressions.