tvm
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/tir/expr.h>
28 
29 #include <string>
30 #include <type_traits>
31 #include <utility>
32 #include <vector>
33 
34 namespace tvm {
35 namespace tir {
36 
38 class StmtNode : public Object {
39  public:
44  mutable Span span;
45 
46  StmtNode() = default;
47  explicit StmtNode(Span span) : span(span) {}
48 
50 
51  static constexpr const char* _type_key = "tir.Stmt";
52  static constexpr const bool _type_has_method_sequal_reduce = true;
53  static constexpr const bool _type_has_method_shash_reduce = true;
54  static constexpr const uint32_t _type_child_slots = 15;
56 };
57 
59 class Stmt : public ObjectRef {
60  public:
62 };
63 
67 class LetStmtNode : public StmtNode {
68  public:
75 
77  v->Visit("var", &var);
78  v->Visit("value", &value);
79  v->Visit("body", &body);
80  v->Visit("span", &span);
81  }
82 
83  bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
84  return equal.DefEqual(var, other->var) && equal(value, other->value) &&
85  equal(body, other->body);
86  }
87 
88  void SHashReduce(SHashReducer hash_reduce) const {
89  hash_reduce.DefHash(var);
90  hash_reduce(value);
91  hash_reduce(body);
92  }
93 
94  static constexpr const char* _type_key = "tir.LetStmt";
96 };
97 
102 class LetStmt : public Stmt {
103  public:
104  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
105 
108 };
109 
120 class AttrStmtNode : public StmtNode {
121  public:
130 
132  v->Visit("node", &node);
133  v->Visit("attr_key", &attr_key);
134  v->Visit("value", &value);
135  v->Visit("body", &body);
136  v->Visit("span", &span);
137  }
138 
139  bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
140  return equal(node, other->node) && equal(attr_key, other->attr_key) &&
141  equal(value, other->value) && equal(body, other->body);
142  }
143 
144  void SHashReduce(SHashReducer hash_reduce) const {
145  hash_reduce(node);
146  hash_reduce(attr_key);
147  hash_reduce(value);
148  hash_reduce(body);
149  }
150 
151  static constexpr const char* _type_key = "tir.AttrStmt";
153 };
154 
159 class AttrStmt : public Stmt {
160  public:
161  TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
162 
165 };
166 
170 class AssertStmtNode : public StmtNode {
171  public:
181 
183  v->Visit("condition", &condition);
184  v->Visit("message", &message);
185  v->Visit("body", &body);
186  v->Visit("span", &span);
187  }
188 
189  bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
190  return equal(condition, other->condition) && equal(message, other->message) &&
191  equal(body, other->body);
192  }
193 
194  void SHashReduce(SHashReducer hash_reduce) const {
195  hash_reduce(condition);
196  hash_reduce(message);
197  hash_reduce(body);
198  }
199 
200  static constexpr const char* _type_key = "tir.AssertStmt";
202 };
203 
208 class AssertStmt : public Stmt {
209  public:
210  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
211 
214 };
215 
226 class BufferStoreNode : public StmtNode {
227  public:
236 
238  v->Visit("buffer", &buffer);
239  v->Visit("value", &value);
240  v->Visit("indices", &indices);
241  v->Visit("predicate", &predicate);
242  v->Visit("span", &span);
243  }
244 
245  bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
246  return equal(buffer, other->buffer) && equal(value, other->value) &&
247  equal(indices, other->indices);
248  }
249 
250  void SHashReduce(SHashReducer hash_reduce) const {
251  hash_reduce(buffer);
252  hash_reduce(value);
253  hash_reduce(indices);
254  hash_reduce(predicate);
255  }
256 
257  static constexpr const char* _type_key = "tir.BufferStore";
259 };
260 
265 class BufferStore : public Stmt {
266  public:
267  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
268  Optional<PrimExpr> predicate = NullOpt, Span span = Span());
269 
272 };
273 
285 class BufferRealizeNode : public StmtNode {
286  public:
295 
297  v->Visit("buffer", &buffer);
298  v->Visit("bounds", &bounds);
299  v->Visit("condition", &condition);
300  v->Visit("body", &body);
301  v->Visit("span", &span);
302  }
303 
304  bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
305  return equal(buffer, other->buffer) && equal(bounds, other->bounds) &&
306  equal(condition, other->condition) && equal(body, other->body);
307  }
308 
309  void SHashReduce(SHashReducer hash_reduce) const {
310  hash_reduce(buffer);
311  hash_reduce(bounds);
312  hash_reduce(condition);
313  hash_reduce(body);
314  }
315 
316  BufferRealizeNode() = default;
318  Span span = Span())
320 
321  static constexpr const char* _type_key = "tir.BufferRealize";
323 };
324 
329 class BufferRealize : public Stmt {
330  public:
331  TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
332  Span span = Span());
333 
336 };
337 
348 class ProducerStoreNode : public StmtNode {
349  public:
356 
358  v->Visit("producer", &producer);
359  v->Visit("value", &value);
360  v->Visit("indices", &indices);
361  v->Visit("span", &span);
362  }
363 
364  bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
365  return equal(producer, other->producer) && equal(value, other->value) &&
366  equal(indices, other->indices);
367  }
368 
369  void SHashReduce(SHashReducer hash_reduce) const {
370  hash_reduce(producer);
371  hash_reduce(value);
372  hash_reduce(indices);
373  }
374 
375  static constexpr const char* _type_key = "tir.ProducerStore";
377 };
378 
383 class ProducerStore : public Stmt {
384  public:
385  TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
386  Span span = Span());
387 
390 };
391 
404  public:
415 
417  v->Visit("producer", &producer);
418  v->Visit("bounds", &bounds);
419  v->Visit("condition", &condition);
420  v->Visit("body", &body);
421  v->Visit("storage_scope", &storage_scope);
422  v->Visit("span", &span);
423  }
424 
426  return equal(producer, other->producer) && equal(bounds, other->bounds) &&
427  equal(condition, other->condition) && equal(body, other->body) &&
429  }
430 
431  void SHashReduce(SHashReducer hash_reduce) const {
432  hash_reduce(producer);
433  hash_reduce(bounds);
434  hash_reduce(condition);
435  hash_reduce(body);
436  hash_reduce(storage_scope);
437  }
438 
439  static constexpr const char* _type_key = "tir.ProducerRealize";
441 };
442 
447 class ProducerRealize : public Stmt {
448  public:
449  TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
450  String storage_scope = "", Span span = Span());
451 
454 };
455 
459 class AllocateNode : public StmtNode {
460  public:
478 
480  v->Visit("buffer_var", &buffer_var);
481  v->Visit("dtype", &dtype);
482  v->Visit("extents", &extents);
483  v->Visit("condition", &condition);
484  v->Visit("body", &body);
485  v->Visit("annotations", &annotations);
486  v->Visit("span", &span);
487  }
488 
489  bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
490  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
491  equal(extents, other->extents) && equal(condition, other->condition) &&
492  equal(body, other->body) && equal(annotations, other->annotations);
493  }
494 
495  void SHashReduce(SHashReducer hash_reduce) const {
496  hash_reduce.DefHash(buffer_var);
497  hash_reduce(dtype);
498  hash_reduce(extents);
499  hash_reduce(condition);
500  hash_reduce(body);
501  hash_reduce(annotations);
502  }
503 
516  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
517 
518  static constexpr const char* _type_key = "tir.Allocate";
519  static constexpr const bool _type_has_method_sequal_reduce = true;
520  static constexpr const bool _type_has_method_shash_reduce = true;
522 };
523 
528 class Allocate : public Stmt {
529  public:
530  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
531  Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
532  Span span = Span());
533 
536 };
537 
541 class AllocateConstNode : public StmtNode {
542  public:
566 
568  v->Visit("buffer_var", &buffer_var);
569  v->Visit("data", &data);
570  v->Visit("irmod_storage_idx", &irmod_storage_idx);
571  v->Visit("dtype", &dtype);
572  v->Visit("extents", &extents);
573  v->Visit("body", &body);
574  v->Visit("annotations", &annotations);
575  v->Visit("span", &span);
576  }
577 
578  bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
579  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
580  equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) &&
581  equal(annotations, other->annotations);
582  }
583 
584  void SHashReduce(SHashReducer hash_reduce) const {
585  hash_reduce.DefHash(buffer_var);
586  hash_reduce(dtype);
587  hash_reduce(extents);
588  hash_reduce(body);
589  hash_reduce(annotations);
590  hash_reduce(data);
591  }
592 
605  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
606 
607  static constexpr const char* _type_key = "tir.AllocateConst";
608  static constexpr const bool _type_has_method_sequal_reduce = true;
609  static constexpr const bool _type_has_method_shash_reduce = true;
611 };
612 
617 class AllocateConst : public Stmt {
618  public:
619  /* The constructor to create a IRNode with constant data
620  * depending on the type of ObjectRef, it will either
621  * create AllocateConstNode with irmod_storage_idx or data
622  */
623  TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
624  ObjectRef data_or_idx, Stmt body,
626  Span span = Span());
629 };
630 
632 class DeclBufferNode : public StmtNode {
633  public:
638 
640  v->Visit("buffer", &buffer);
641  v->Visit("body", &body);
642  v->Visit("span", &span);
643  }
644 
645  bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const {
646  return equal(buffer, other->buffer) && equal(body, other->body);
647  }
648 
649  void SHashReduce(SHashReducer hash_reduce) const {
650  hash_reduce(buffer);
651  hash_reduce(body);
652  }
653 
654  static constexpr const char* _type_key = "tir.DeclBuffer";
656 };
657 
659 class DeclBuffer : public Stmt {
660  public:
661  TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
664 };
665 
670 class SeqStmtNode : public StmtNode {
671  public:
674 
676  size_t size() const { return seq.size(); }
680  Stmt operator[](size_t index) const { return seq[index]; }
681 
683  v->Visit("seq", &seq);
684  v->Visit("span", &span);
685  }
686 
687  bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
688  return equal(seq, other->seq);
689  }
690 
691  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
692 
693  static constexpr const char* _type_key = "tir.SeqStmt";
695 };
696 
703 class EvaluateNode : public StmtNode {
704  public:
707 
709  v->Visit("value", &value);
710  v->Visit("span", &span);
711  }
712 
713  bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
714  return equal(value, other->value);
715  }
716 
717  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
718 
719  static constexpr const char* _type_key = "tir.Evaluate";
721 };
722 
727 class Evaluate : public Stmt {
728  public:
729  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
730 
731  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
732 
735 };
736 
738 class SeqStmt : public Stmt {
739  public:
745  TVM_DLL explicit SeqStmt(Array<Stmt> seq, Span span = Span());
746 
748  size_t size() const { return operator->()->size(); }
752  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
  // Flatten an arbitrary mix of arguments into a single Stmt.
  //
  // Each argument may be a Stmt (including a SeqStmt) or an iterable of Stmt;
  // Flattener decides how each one is handled: undefined ObjectRefs and
  // Evaluate(0) no-ops are skipped, and SeqStmts are spliced in element by element.
  //
  // Returns:
  //   - Evaluate(0) when no statements remain,
  //   - the single remaining statement when exactly one remains,
  //   - the original SeqStmt object when the only argument is a SeqStmt whose
  //     flattened form is exactly its own elements (no new node is allocated),
  //   - otherwise a new SeqStmt over the flattened statements.
773  template <typename... Args>
774  static Stmt Flatten(Args&&... seq_args) {
775  Array<Stmt> seq;
776  runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
  // NOTE(review): seq_args is std::forward-ed a second time below. That is only
  // safe because nothing moves from the arguments here: Flattener::operator()
  // binds them as const T&. This presumes for_each merely forwards each
  // argument to that operator — confirm if for_each changes.
777 
778  if (seq.empty()) {
779  return Evaluate(0);
780  } else if (seq.size() == 1) {
781  return seq[0];
782  }
783 
784  // If the argument is a single SeqStmt argument with no
785  // flattening or unwrapping required, then we may
786  // return the SeqStmt as-is.
787  if constexpr (sizeof...(seq_args) == 1) {
788  if (auto opt = Flattener::AsSeqStmt(std::forward<Args>(seq_args)...)) {
789  SeqStmt original = opt.value();
  // Reuse `original` only when flattening kept exactly the same objects in the
  // same order (identity via same_as, not structural equality).
790  bool all_same = [&]() {
791  if (original->seq.size() != seq.size()) {
792  return false;
793  }
794  for (size_t i = 0; i < seq.size(); i++) {
795  if (!original->seq[i].same_as(seq[i])) {
796  return false;
797  }
798  }
799  return true;
800  }();
801  if (all_same) {
802  return original;
803  }
804  }
805  }
806 
807  return SeqStmt(seq);
808  }
  // Callable used by SeqStmt::Flatten (via runtime::detail::for_each) to append
  // each argument's statements, recursively flattened, to an output Array<Stmt>.
810  class Flattener {
811  public:
  // `seq` is the output array; Flatten passes the address of its local array.
812  explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}
813 
  // Return `t` as a SeqStmt if it is one, else NullOpt.
  // - T == SeqStmt: returned directly (statically known).
  // - T not a base of SeqStmt: can never hold a SeqStmt, so NullOpt.
  // - T a base of SeqStmt (e.g. Stmt): checked at runtime via as<SeqStmtNode>().
814  template <typename T>
815  static Optional<SeqStmt> AsSeqStmt(const T& t) {
816  if constexpr (std::is_same_v<T, SeqStmt>) {
817  return t;
818  } else if constexpr (!std::is_base_of_v<T, SeqStmt>) {
819  return NullOpt;
820  } else if (auto* ptr = t.template as<SeqStmtNode>()) {
821  return GetRef<SeqStmt>(ptr);
822  } else {
823  return NullOpt;
824  }
825  }
826 
  // Append the statements in `stmt_or_seq` to *seq_.
  // The index `i` is not used; recursive calls simply pass 0.
  // Handling, in order:
  //   1. undefined ObjectRef        -> skipped
  //   2. SeqStmt (static or dynamic) -> each element flattened recursively
  //   3. Evaluate of IntImm 0       -> skipped as a no-op
  //   4. any other Stmt             -> appended as-is
  //   5. anything else              -> iterated, each element flattened
828  void operator()(size_t i, const T& stmt_or_seq) const {
827  template <typename T>
877 
880 };
881 
885 class IfThenElseNode : public StmtNode {
886  public:
893 
895  v->Visit("condition", &condition);
896  v->Visit("then_case", &then_case);
897  v->Visit("else_case", &else_case);
898  v->Visit("span", &span);
899  }
900 
901  bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
902  return equal(condition, other->condition) && equal(then_case, other->then_case) &&
903  equal(else_case, other->else_case);
904  }
905 
906  void SHashReduce(SHashReducer hash_reduce) const {
907  hash_reduce(condition);
908  hash_reduce(then_case);
909  hash_reduce(else_case);
910  }
911 
912  static constexpr const char* _type_key = "tir.IfThenElse";
914 };
915 
920 class IfThenElse : public Stmt {
921  public:
922  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case = NullOpt,
923  Span span = Span());
924 
927 };
928 
936 enum class ForKind : int {
938  kSerial = 0,
940  kParallel = 1,
945  kVectorized = 2,
947  kUnrolled = 3,
954  kThreadBinding = 4
955 };
956 
967 class ForNode : public StmtNode {
968  public:
993 
995  v->Visit("loop_var", &loop_var);
996  v->Visit("min", &min);
997  v->Visit("extent", &extent);
998  v->Visit("kind", &kind);
999  v->Visit("body", &body);
1000  v->Visit("thread_binding", &thread_binding);
1001  v->Visit("annotations", &annotations);
1002  v->Visit("span", &span);
1003  }
1004 
1005  bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
1006  return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
1007  equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
1009  }
1010 
1011  void SHashReduce(SHashReducer hash_reduce) const {
1012  hash_reduce.DefHash(loop_var);
1013  hash_reduce(min);
1014  hash_reduce(extent);
1015  hash_reduce(kind);
1016  hash_reduce(body);
1017  hash_reduce(thread_binding);
1018  hash_reduce(annotations);
1019  }
1020 
1021  static constexpr const char* _type_key = "tir.For";
1023 };
1024 
1029 class For : public Stmt {
1030  public:
1031  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
1032  Optional<IterVar> thread_binding = NullOpt,
1033  Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());
1034 
1037 };
1038 
1049 class WhileNode : public StmtNode {
1050  public:
1055 
1057  v->Visit("condition", &condition);
1058  v->Visit("body", &body);
1059  v->Visit("span", &span);
1060  }
1061 
1062  bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
1063  return equal(condition, other->condition) && equal(body, other->body);
1064  }
1065 
1066  void SHashReduce(SHashReducer hash_reduce) const {
1067  hash_reduce(condition);
1068  hash_reduce(body);
1069  }
1070 
1071  static constexpr const char* _type_key = "tir.While";
1073 };
1074 
1079 class While : public Stmt {
1080  public:
1081  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
1082 
1085 };
1086 
1090 class PrefetchNode : public StmtNode {
1091  public:
1096 
1098  v->Visit("buffer", &buffer);
1099  v->Visit("bounds", &bounds);
1100  v->Visit("span", &span);
1101  }
1102 
1103  bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
1104  return equal(buffer, other->buffer) && equal(bounds, other->bounds);
1105  }
1106 
1107  void SHashReduce(SHashReducer hash_reduce) const {
1108  hash_reduce(buffer);
1109  hash_reduce(bounds);
1110  }
1111 
1112  PrefetchNode() = default;
1115 
1116  static constexpr const char* _type_key = "tir.Prefetch";
1118 };
1119 
1124 class Prefetch : public Stmt {
1125  public:
1126  TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
1127 
1130 };
1131 
1135 class BufferRegionNode : public Object {
1136  public:
1141 
1143  v->Visit("buffer", &buffer);
1144  v->Visit("region", &region);
1145  }
1146 
1147  bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
1148  return equal(buffer, other->buffer) && equal(region, other->region);
1149  }
1150 
1151  void SHashReduce(SHashReducer hash_reduce) const {
1152  hash_reduce(buffer);
1153  hash_reduce(region);
1154  }
1155 
1156  static constexpr const char* _type_key = "tir.BufferRegion";
1157  static constexpr const bool _type_has_method_sequal_reduce = true;
1158  static constexpr const bool _type_has_method_shash_reduce = true;
1160 };
1161 
1166 class BufferRegion : public ObjectRef {
1167  public:
1168  TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
1169 
1175  TVM_DLL static BufferRegion FullRegion(Buffer buffer);
1176 
1183  TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
1184 
1187 };
1188 
1199  public:
1204 
1206  v->Visit("buffer", &buffer);
1207  v->Visit("source", &source);
1208  }
1209 
1211  return equal(buffer, other->buffer) && equal(source, other->source);
1212  }
1213 
1214  void SHashReduce(SHashReducer hash_reduce) const {
1215  hash_reduce(buffer);
1216  hash_reduce(source);
1217  }
1218 
1219  static constexpr const char* _type_key = "tir.MatchBufferRegion";
1220  static constexpr const bool _type_has_method_sequal_reduce = true;
1221  static constexpr const bool _type_has_method_shash_reduce = true;
1223 };
1224 
1230  public:
1231  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
1232 
1235 };
1236 
1258 class BlockNode : public StmtNode {
1259  public:
1284 
1286  v->Visit("iter_vars", &iter_vars);
1287  v->Visit("reads", &reads);
1288  v->Visit("writes", &writes);
1289  v->Visit("name_hint", &name_hint);
1290  v->Visit("body", &body);
1291  v->Visit("init", &init);
1292  v->Visit("alloc_buffers", &alloc_buffers);
1293  v->Visit("match_buffers", &match_buffers);
1294  v->Visit("annotations", &annotations);
1295  }
1296 
1297  bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
1298  // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
1299  return equal.DefEqual(iter_vars, other->iter_vars) &&
1300  equal(alloc_buffers, other->alloc_buffers) &&
1301  equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
1302  equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
1303  equal(annotations, other->annotations);
1304  }
1305 
1306  void SHashReduce(SHashReducer hash_reduce) const {
1307  hash_reduce.DefHash(iter_vars);
1308  hash_reduce(alloc_buffers);
1309  hash_reduce(match_buffers);
1310  hash_reduce(reads);
1311  hash_reduce(writes);
1312  hash_reduce(body);
1313  hash_reduce(init);
1314  hash_reduce(annotations);
1315  }
1316 
1317  static constexpr const char* _type_key = "tir.Block";
1319 };
1320 
1325 class Block : public Stmt {
1326  public:
1327  TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
1328  Array<BufferRegion> writes, String name_hint, Stmt body,
1329  Optional<Stmt> init = NullOpt,
1330  Array<Buffer> alloc_buffers = Array<Buffer>(),
1333  Span span = Span());
1334 
1337 };
1338 
1342 class BlockRealizeNode : public StmtNode {
1343  public:
1353 
1355  v->Visit("iter_values", &iter_values);
1356  v->Visit("predicate", &predicate);
1357  v->Visit("block", &block);
1358  }
1359 
1360  bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
1361  return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
1362  equal(block, other->block);
1363  }
1364 
1365  void SHashReduce(SHashReducer hash_reduce) const {
1366  hash_reduce(iter_values);
1367  hash_reduce(predicate);
1368  hash_reduce(block);
1369  }
1370 
1371  static constexpr const char* _type_key = "tir.BlockRealize";
1373 };
1374 
1379 class BlockRealize : public Stmt {
1380  public:
1381  TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
1382  Span span = Span());
1383 
1386 };
1387 
1389 namespace attr {
1390 // The above attr does not pass to ir stage.
1392 constexpr const char* thread_extent = "thread_extent";
1394 constexpr const char* virtual_thread = "virtual_thread";
1396 constexpr const char* coproc_scope = "coproc_scope";
1401 constexpr const char* coproc_uop_scope = "coproc_uop_scope";
1403 constexpr const char* volatile_scope = "volatile_scope";
1409 constexpr const char* extern_scope = "extern_scope";
1414 constexpr const char* compute_scope = "compute_scope";
1416 constexpr const char* storage_alignment = "storage_alignment";
1418 constexpr const char* realize_scope = "realize_scope";
1420 constexpr const char* device_id = "device_id";
1422 constexpr const char* device_type = "device_type";
1424 constexpr const char* loop_scope = "loop_scope";
1426 constexpr const char* reduce_scope = "reduce_scope";
1428 constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
1430 constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
1432 constexpr const char* pragma_scope_prefix = "pragma_";
1434 constexpr const char* pragma_import_c = "pragma_import_c";
1436 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
1438 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
1443 constexpr const char* prefetch_scope = "prefetch_scope";
1450 constexpr const char* layout_transforms = "layout_transforms";
1458 constexpr const char* axis_separators = "axis_separators";
1462 constexpr const char* double_buffer_scope = "double_buffer_scope";
1466 constexpr const char* double_buffer_write = "double_buffer_write";
1468 constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
1470 constexpr const char* scan_update_scope = "scan_update_scope";
1472 constexpr const char* scan_init_scope = "scan_init_scope";
1479 constexpr const char* buffer_dim_align = "buffer_dim_align";
1481 constexpr const char* buffer_bound = "buffer_bound";
1491 constexpr const char* buffer_bind_scope = "buffer_bind_scope";
1492 // Pipeline related attributes
1494 constexpr const char* channel_read_scope = "channel_read_scope";
1496 constexpr const char* channel_read_advance = "channel_read_advance";
1498 constexpr const char* channel_write_scope = "channel_write_scope";
1500 constexpr const char* channel_write_advance = "channel_write_advance";
1502 constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
1504 constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
1505 
1509 constexpr const char* device_scope = "device_scope";
1510 
1514 constexpr const char* async_scope = "async_scope";
1515 
1533 constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
1534 constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
1535 constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
1536 
1540 constexpr const char* fragment_shape = "fragment_shape";
1541 
1545 constexpr const char* fragment_layout = "fragment_layout";
1546 
1550 constexpr const char* hand_threaded = "hand_threaded";
1551 
1559 constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
1560 
1564 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
1565 
1567 constexpr const char* software_pipeline_stage = "software_pipeline_stage";
1568 
1570 constexpr const char* software_pipeline_order = "software_pipeline_order";
1571 
1576 constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";
1577 
1579 constexpr const char* layout_free_buffers = "layout_free_buffers";
1580 
1582 constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";
1583 
1585 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
1586 
1591 constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";
1592 
1595  "meta_schedule.thread_extent_low_inclusive";
1596 
1599  "meta_schedule.thread_extent_high_inclusive";
1600 
1603  "meta_schedule.random_compute_producer";
1604 
1606 constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
1607 
1609 constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
1610 
1612 constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
1613 
1615 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
1616 
1618 constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1619 
1621 constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
1625 constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
1626 
1630 constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";
1631 
1633 constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";
1634 
1641 constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
1642 
1644 constexpr const int meta_schedule_cache_type_read = 0;
1645 
1647 constexpr const int meta_schedule_cache_type_write = 1;
1648 
1650 constexpr const char* auto_copy = "auto_copy";
1651 
1653 constexpr const char* local_stage = "local_stage";
1654 
1656 constexpr const char* vector_bytes = "vector_bytes";
1657 
1662 constexpr const char* warp_execution = "warp_execution";
1663 
1665 constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
1666 
1670 constexpr const char* explicit_read_region = "explicit_read_region";
1671 
1675 constexpr const char* explicit_write_region = "explicit_write_region";
1676 
1682 inline bool IsPragmaKey(const std::string& attr_key) {
1683  return attr_key.compare(0, 7, "pragma_") == 0;
1684 }
1685 
1686 } // namespace attr
1693 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1694 
1695 // overload printing of for type.
1696 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1697 
1698 // inline implementations
1699 inline const char* ForKind2String(ForKind t) {
1700  switch (t) {
1701  case ForKind::kSerial:
1702  return "serial";
1703  case ForKind::kParallel:
1704  return "parallel";
1705  case ForKind::kVectorized:
1706  return "vectorized";
1707  case ForKind::kUnrolled:
1708  return "unroll";
1710  return "thread_binding";
1711  }
1712  LOG(FATAL) << "Unknown ForKind" << t;
1713 }
1714 
1715 } // namespace tir
1716 } // namespace tvm
1717 #endif // TVM_TIR_STMT_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:115
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool empty() const
Definition: array.h:432
size_t size() const
Definition: array.h:420
Runtime primitive data type.
Definition: data_type.h:43
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:519
const Object * operator->() const
Definition: object.h:556
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:530
base class of all object containers.
Definition: object.h:171
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Allocate a buffer that can be used in body.
Definition: stmt.h:541
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:565
Optional< Integer > irmod_storage_idx
If the PrimFunc containing the Stmt is added to IRModule, this is an optional index to indicate the i...
Definition: stmt.h:552
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:556
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:609
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:584
static int64_t ConstantAllocationSize(const Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode)
Stmt body
The body to be executed.
Definition: stmt.h:558
DataType dtype
The type of the buffer.
Definition: stmt.h:554
Var buffer_var
The buffer variable.
Definition: stmt.h:544
static constexpr const char * _type_key
Definition: stmt.h:607
bool SEqualReduce(const AllocateConstNode *other, SEqualReducer equal) const
Definition: stmt.h:578
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:598
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:608
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:567
Optional< runtime::NDArray > data
The optional data associated to the constant.
Definition: stmt.h:547
Managed reference to AllocateConstNode.
Definition: stmt.h:617
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode)
AllocateConst(Var buffer_var, DataType dtype, Array< PrimExpr > extents, ObjectRef data_or_idx, Stmt body, Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode)
Allocate a buffer that can be used in body.
Definition: stmt.h:459
bool SEqualReduce(const AllocateNode *other, SEqualReducer equal) const
Definition: stmt.h:489
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:466
static int64_t ConstantAllocationSize(const Array< PrimExpr > &extents)
If the buffer size is constant, return the size. Otherwise return 0.
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:477
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:468
Stmt body
The body to be executed.
Definition: stmt.h:470
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode)
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:509
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:520
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:519
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:479
DataType dtype
The type of the buffer.
Definition: stmt.h:464
static constexpr const char * _type_key
Definition: stmt.h:518
Var buffer_var
The buffer variable.
Definition: stmt.h:462
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:495
Managed reference to AllocateNode.
Definition: stmt.h:528
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode)
Allocate(Var buffer_var, DataType dtype, Array< PrimExpr > extents, PrimExpr condition, Stmt body, Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode)
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:170
PrimExpr condition
Condition to be checked.
Definition: stmt.h:173
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:175
bool SEqualReduce(const AssertStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:189
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:200
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:194
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:182
Stmt body
Body in which this assertion holds true. Will be executed after the assertion.
Definition: stmt.h:180
Managed reference to AssertStmtNode.
Definition: stmt.h:208
TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode)
AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode)
Define certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:120
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:127
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:144
Stmt body
The body statement to be executed.
Definition: stmt.h:129
String attr_key
the type key of the attribute
Definition: stmt.h:125
ObjectRef node
this is attribute about certain node
Definition: stmt.h:123
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:131
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode)
bool SEqualReduce(const AttrStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:139
static constexpr const char * _type_key
Definition: stmt.h:151
Managed reference to AttrStmtNode.
Definition: stmt.h:159
TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode)
AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode)
A block is a basic schedule unit in TIR.
Definition: stmt.h:1258
static constexpr const char * _type_key
Definition: stmt.h:1317
Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:1263
Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:1281
Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:1261
Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:1265
Map< String, ObjectRef > annotations
The annotation of the block.
Definition: stmt.h:1283
Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block....
Definition: stmt.h:1277
String name_hint
The name_hint of the block.
Definition: stmt.h:1267
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1285
Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:1279
Stmt body
The body of the block.
Definition: stmt.h:1269
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1306
bool SEqualReduce(const BlockNode *other, SEqualReducer equal) const
Definition: stmt.h:1297
TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode)
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:1342
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1365
bool SEqualReduce(const BlockRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:1360
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:1371
Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:1345
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true.
Definition: stmt.h:1350
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1354
Block block
The block to be realized.
Definition: stmt.h:1352
Managed reference to BlockRealizeNode.
Definition: stmt.h:1379
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode)
BlockRealize(Array< PrimExpr > iter_values, PrimExpr predicate, Block block, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode)
Managed reference to BlockNode.
Definition: stmt.h:1325
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode)
Block(Array< IterVar > iter_vars, Array< BufferRegion > reads, Array< BufferRegion > writes, String name_hint, Stmt body, Optional< Stmt > init=NullOpt, Array< Buffer > alloc_buffers=Array< Buffer >(), Array< MatchBufferRegion > match_buffers=Array< MatchBufferRegion >(), Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode)
Annotate the region where the buffer needs to be read and written in the body. We only need to allocate ...
Definition: stmt.h:285
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:296
Buffer buffer
The buffer variable.
Definition: stmt.h:288
bool SEqualReduce(const BufferRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:304
BufferRealizeNode(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
Definition: stmt.h:317
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:292
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:309
Stmt body
The body of realization.
Definition: stmt.h:294
Array< Range > bounds
Bounds to be realized.
Definition: stmt.h:290
static constexpr const char * _type_key
Definition: stmt.h:321
Managed reference to BufferRealizeNode.
Definition: stmt.h:329
BufferRealize(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode)
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:1135
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:1138
bool SEqualReduce(const BufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1147
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object)
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:1158
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1142
static constexpr const char * _type_key
Definition: stmt.h:1156
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:1157
Array< Range > region
The region array of the buffer region.
Definition: stmt.h:1140
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1151
Managed reference to BufferRegionNode.
Definition: stmt.h:1166
static BufferRegion FullRegion(Buffer buffer)
Create a BufferRegion which is full region of the given buffer.
static BufferRegion FromPoint(Buffer buffer, Array< PrimExpr > indices)
Create a BufferRegion which is a single point of the given buffer.
BufferRegion(Buffer buffer, Array< Range > region)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode)
TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode)
Store value to the high dimension buffer.
Definition: stmt.h:226
Buffer buffer
The buffer variable.
Definition: stmt.h:229
Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:233
Optional< PrimExpr > predicate
The predicate mask for storing values.
Definition: stmt.h:235
PrimExpr value
The value to be stored.
Definition: stmt.h:231
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:250
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:237
static constexpr const char * _type_key
Definition: stmt.h:257
bool SEqualReduce(const BufferStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:245
Managed reference to BufferStoreNode.
Definition: stmt.h:265
BufferStore(Buffer buffer, PrimExpr value, Array< PrimExpr > indices, Optional< PrimExpr > predicate=NullOpt, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode)
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:174
Managed reference to DataProducerNode.
Definition: buffer.h:313
Declare a buffer that can be used in the body.
Definition: stmt.h:632
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:649
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:639
static constexpr const char * _type_key
Definition: stmt.h:654
Buffer buffer
The buffer being declared.
Definition: stmt.h:635
TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode)
Stmt body
The body to be executed.
Definition: stmt.h:637
bool SEqualReduce(const DeclBufferNode *other, SEqualReducer equal) const
Definition: stmt.h:645
Managed reference to DeclBufferNode.
Definition: stmt.h:659
DeclBuffer(Buffer buffer, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode)
TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode)
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:703
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:717
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:708
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:719
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:706
bool SEqualReduce(const EvaluateNode *other, SEqualReducer equal) const
Definition: stmt.h:713
Managed reference to EvaluateNode.
Definition: stmt.h:727
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode)
Evaluate(int value, Span span=Span())
Definition: stmt.h:731
TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode)
Evaluate(PrimExpr value, Span span=Span())
A for loop, with possible type annotations.
Definition: stmt.h:967
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:994
Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to.
Definition: stmt.h:983
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:972
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1011
ForKind kind
The kind of the for loop.
Definition: stmt.h:976
static constexpr const char * _type_key
Definition: stmt.h:1021
Var loop_var
The loop variable.
Definition: stmt.h:970
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:974
Stmt body
The body of the for loop.
Definition: stmt.h:978
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode)
bool SEqualReduce(const ForNode *other, SEqualReducer equal) const
Definition: stmt.h:1005
Map< String, ObjectRef > annotations
Additional annotations about the loop.
Definition: stmt.h:992
Managed reference to ForNode.
Definition: stmt.h:1029
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode)
For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional< IterVar > thread_binding=NullOpt, Map< String, ObjectRef > annotations=Map< String, ObjectRef >(), Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode)
IfThenElse statement.
Definition: stmt.h:885
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:894
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:906
PrimExpr condition
The condition.
Definition: stmt.h:888
Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:892
bool SEqualReduce(const IfThenElseNode *other, SEqualReducer equal) const
Definition: stmt.h:901
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode)
static constexpr const char * _type_key
Definition: stmt.h:912
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:890
Managed reference to IfThenElseNode.
Definition: stmt.h:920
IfThenElse(PrimExpr condition, Stmt then_case, Optional< Stmt > else_case=NullOpt, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode)
TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode)
Let binding, bind var to value, then run body.
Definition: stmt.h:67
PrimExpr value
The value to be bound.
Definition: stmt.h:72
static constexpr const char * _type_key
Definition: stmt.h:94
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:76
bool SEqualReduce(const LetStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:83
Stmt body
The body block.
Definition: stmt.h:74
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode)
Var var
The variable.
Definition: stmt.h:70
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:88
Managed reference to LetStmtNode.
Definition: stmt.h:102
TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode)
LetStmt(Var var, PrimExpr value, Stmt body, Span span=Span())
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:1198
TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object)
static constexpr const char * _type_key
Definition: stmt.h:1219
Buffer buffer
The target buffer.
Definition: stmt.h:1201
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:1221
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1205
bool SEqualReduce(const MatchBufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1210
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:1220
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1214
BufferRegion source
The source buffer region.
Definition: stmt.h:1203
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:1229
MatchBufferRegion(Buffer buffer, BufferRegion source)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode)
TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode)
A prefetch hint for a buffer.
Definition: stmt.h:1090
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode)
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1097
static constexpr const char * _type_key
Definition: stmt.h:1116
PrefetchNode(Buffer buffer, Array< Range > bounds, Span span=Span())
Definition: stmt.h:1113
Array< Range > bounds
Bounds to be prefetched.
Definition: stmt.h:1095
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1107
Buffer buffer
The buffer to be prefetched.
Definition: stmt.h:1093
bool SEqualReduce(const PrefetchNode *other, SEqualReducer equal) const
Definition: stmt.h:1103
Managed reference to PrefetchNode.
Definition: stmt.h:1124
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode)
Prefetch(Buffer buffer, Array< Range > bounds, Span span=Span())
Annotate the bounds where the data produced by the producer needs to be written and read in the body....
Definition: stmt.h:403
Stmt body
The body of realization.
Definition: stmt.h:412
DataProducer producer
The producer that produces the data.
Definition: stmt.h:406
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:410
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:431
static constexpr const char * _type_key
Definition: stmt.h:439
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:416
String storage_scope
The storage scope associated with this realization.
Definition: stmt.h:414
bool SEqualReduce(const ProducerRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:425
Region bounds
Bounds to be realized.
Definition: stmt.h:408
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode)
Managed reference to ProducerRealizeNode.
Definition: stmt.h:447
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode)
TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode)
ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope="", Span span=Span())
Store value into a multi-dimensional array that will be read by the consumer of the producer.
Definition: stmt.h:348
PrimExpr value
The value to be stored.
Definition: stmt.h:353
DataProducer producer
The producer to store the results into.
Definition: stmt.h:351
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:369
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode)
Array< PrimExpr > indices
The index arguments of the function.
Definition: stmt.h:355
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:357
bool SEqualReduce(const ProducerStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:364
static constexpr const char * _type_key
Definition: stmt.h:375
Managed reference to ProducerStoreNode.
Definition: stmt.h:383
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode)
TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode)
ProducerStore(DataProducer producer, PrimExpr value, Array< PrimExpr > indices, Span span=Span())
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:670
size_t size() const
Definition: stmt.h:676
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:673
static constexpr const char * _type_key
Definition: stmt.h:693
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:682
bool SEqualReduce(const SeqStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:687
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:691
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:680
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:810
Flattener(Array< Stmt > *seq)
Definition: stmt.h:812
static Optional< SeqStmt > AsSeqStmt(const T &t)
Definition: stmt.h:815
void operator()(size_t i, const T &stmt_or_seq) const
Definition: stmt.h:828
Sequence statement.
Definition: stmt.h:738
TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode)
size_t size() const
Definition: stmt.h:748
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode)
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:752
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:774
SeqStmt(Array< Stmt > seq, Span span=Span())
Construct SeqStmt.
Base node of all statements.
Definition: stmt.h:38
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:54
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:52
StmtNode(Span span)
Definition: stmt.h:47
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:44
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object)
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:53
static constexpr const char * _type_key
Definition: stmt.h:51
Container of all statements.
Definition: stmt.h:59
TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode)
a named variable in TIR
Definition: var.h:89
A While loop.
Definition: stmt.h:1049
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1056
static constexpr const char * _type_key
Definition: stmt.h:1071
TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode)
Stmt body
The body of the while loop.
Definition: stmt.h:1054
PrimExpr condition
The termination condition.
Definition: stmt.h:1052
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1066
bool SEqualReduce(const WhileNode *other, SEqualReducer equal) const
Definition: stmt.h:1062
Managed reference to WhileNode.
Definition: stmt.h:1079
While(PrimExpr condition, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode)
tvm::Span Span
Definition: base.h:65
void Evaluate(PrimExpr value)
Evaluate the input expression.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * compute_scope
Mark the scope as where computation starts to happen. This can hint some code generators to create a new ...
Definition: stmt.h:1414
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op When this scope occurs, the stmt....
Definition: stmt.h:1491
constexpr const char * software_pipeline_order
Mark the order of a statement in the software pipeline.
Definition: stmt.h:1570
constexpr const char * meta_schedule_unroll_explicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1612
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1550
constexpr const char * buffer_dim_align
Mark alignment of buffer dimension. stmt.node is Tensor, stmt.value is tvm_tuple(dim,...
Definition: stmt.h:1479
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1496
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:1403
constexpr const char * meta_schedule_auto_tensorize
Mark that a block should be further rewritten using tensorization.
Definition: stmt.h:1618
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1502
constexpr const char * manifest_shared_memory_local_stage
Mark that the local stage for the shared memory access should be added.
Definition: stmt.h:1582
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1434
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:1430
constexpr const char * meta_schedule_tiling_structure
Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.
Definition: stmt.h:1585
constexpr const char * software_pipeline_stage
Mark the stage of a statement in the software pipeline.
Definition: stmt.h:1567
constexpr const char * meta_schedule_random_compute_producer
Mark the block whose producer needs to be applied by rule Random-Compute-Location.
Definition: stmt.h:1602
constexpr const char * warp_execution
Mark that a block is executed by a warp. This implies the extent of threadIdx.x is the warp size.
Definition: stmt.h:1662
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1509
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1682
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1392
constexpr const char * script_parsing_detect_access
Mark whether the script-completer need to fill in missing access region during script parsing.
Definition: stmt.h:1559
constexpr const char * meta_schedule_tensor_core_enabled
Mark that tensor core is enabled in the PrimExpr.
Definition: stmt.h:1633
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:1394
constexpr const char * extern_scope
Mark the scope as generated by an extern primitive. Such a scope can contain an arbitrary IR program and we n...
Definition: stmt.h:1409
constexpr const char * meta_schedule_cache_type
Mark a block as generated by cache_read or cache_write block. 0 means cache_read; 1 means cache_write...
Definition: stmt.h:1641
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:1426
constexpr const char * meta_schedule_unroll_implicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1615
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1498
constexpr const char * auto_copy
Mark auto copy for memhammer.
Definition: stmt.h:1650
constexpr const char * rolling_buffer_scope
Mark realization for rolling buffer optimization.
Definition: stmt.h:1468
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1533
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1420
constexpr const char * meta_schedule_thread_extent_low_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1594
constexpr const char * meta_schedule_layout_rewrite_preproc
Mark that a block is a preprocessor block for layout rewrite.
Definition: stmt.h:1621
constexpr const char * vector_bytes
Mark vectorization length constraint on block.
Definition: stmt.h:1656
constexpr const char * device_type
The device type.
Definition: stmt.h:1422
constexpr const char * explicit_write_region
Mark that a block has an explicitly specified write region. This is used to override the default writ...
Definition: stmt.h:1675
constexpr const int meta_schedule_cache_type_read
Definition: stmt.h:1644
constexpr const char * software_pipeline_async_stages
List stages in the software pipeline that should run asynchronously.
Definition: stmt.h:1576
constexpr const char * meta_schedule_auto_tensorize_init
Mark that the init statement of a block should be further rewritten using tensorization.
Definition: stmt.h:1625
constexpr const char * meta_schedule_cooperative_fetch
Mark that the loop should be further split and bound to environment threads to enable cooperative fetc...
Definition: stmt.h:1591
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1470
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:1428
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:1424
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1462
constexpr const char * fragment_shape
Mark the shape of a TensorCore fragment.
Definition: stmt.h:1540
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1438
constexpr const char * fragment_layout
Mark the layout of a TensorCore fragment.
Definition: stmt.h:1545
constexpr const char * layout_free_buffers
Mark the buffers that are only accessed as constants and whose layout can therefore be transformed.
Definition: stmt.h:1579
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1534
constexpr const char * explicit_read_region
Mark that a block has an explicitly specified read region. This is used to override the default read ...
Definition: stmt.h:1670
constexpr const int meta_schedule_cache_type_write
Definition: stmt.h:1647
constexpr const char * coproc_scope
Mark region is processed by a co-processor.
Definition: stmt.h:1396
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:1481
constexpr const char * prefetch_scope
Mark of prefetch scope, value=offset, run prefetch of Tensor on the current loop scope.
Definition: stmt.h:1443
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1535
constexpr const char * layout_transforms
Marks the layout transforms to be used for a tensor.
Definition: stmt.h:1450
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1458
constexpr const char * realize_scope
Mark storage scope of realization.
Definition: stmt.h:1418
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1494
constexpr const char * meta_schedule_thread_extent_high_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1598
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1500
constexpr const char * coproc_uop_scope
Mark that the region creates coprocessor micro ops, which can be reused if the corresponding variable is independent.
Definition: stmt.h:1401
constexpr const char * meta_schedule_inline_rule
Mark that a block is disallowed in auto inline.
Definition: stmt.h:1665
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1564
constexpr const char * pipeline_exec_scope
pipeline execution scope, implies the scope can be pipelined.
Definition: stmt.h:1504
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1436
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1432
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1472
constexpr const char * require_block_var_bound_predicate
Mark that the block needs to add a predicate for block var bounds during lowering.
Definition: stmt.h:1630
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1416
constexpr const char * local_stage
Mark local stage constraint on data copy.
Definition: stmt.h:1653
constexpr const char * meta_schedule_vectorize
Mark auto-vectorize setting on the block.
Definition: stmt.h:1609
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1466
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1514
constexpr const char * meta_schedule_parallel
Mark auto-parallel setting on the block.
Definition: stmt.h:1606
const char * ForKind2String(ForKind t)
Definition: stmt.h:1699
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:123
ForKind
The kind of the loop.
Definition: stmt.h:936
@ kThreadBinding
The loop variable is bound to a thread in an environment. In the final stage of lowering,...
@ kParallel
Parallel execution on CPU.
@ kSerial
default semantics – serial execution.
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
@ kVectorized
The loop is vectorized.
Definition: var.h:244
@ kUnrolled
The execution is unrolled.
Definition: var.h:240
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
TIR expressions.