tvm
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/tir/expr.h>
28 
29 #include <string>
30 #include <type_traits>
31 #include <utility>
32 #include <vector>
33 
34 namespace tvm {
35 namespace tir {
36 
38 class StmtNode : public Object {
39  public:
44  mutable Span span;
45 
46  StmtNode() = default;
47  explicit StmtNode(Span span) : span(span) {}
48 
49  static constexpr const char* _type_key = "tir.Stmt";
50  static constexpr const bool _type_has_method_sequal_reduce = true;
51  static constexpr const bool _type_has_method_shash_reduce = true;
52  static constexpr const uint32_t _type_child_slots = 15;
54 };
55 
57 class Stmt : public ObjectRef {
58  public:
60 };
61 
65 class LetStmtNode : public StmtNode {
66  public:
73 
75  v->Visit("var", &var);
76  v->Visit("value", &value);
77  v->Visit("body", &body);
78  v->Visit("span", &span);
79  }
80 
81  bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
82  return equal.DefEqual(var, other->var) && equal(value, other->value) &&
83  equal(body, other->body);
84  }
85 
86  void SHashReduce(SHashReducer hash_reduce) const {
87  hash_reduce.DefHash(var);
88  hash_reduce(value);
89  hash_reduce(body);
90  }
91 
92  static constexpr const char* _type_key = "tir.LetStmt";
94 };
95 
100 class LetStmt : public Stmt {
101  public:
102  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
103 
106 };
107 
118 class AttrStmtNode : public StmtNode {
119  public:
128 
130  v->Visit("node", &node);
131  v->Visit("attr_key", &attr_key);
132  v->Visit("value", &value);
133  v->Visit("body", &body);
134  v->Visit("span", &span);
135  }
136 
137  bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
138  return equal(node, other->node) && equal(attr_key, other->attr_key) &&
139  equal(value, other->value) && equal(body, other->body);
140  }
141 
142  void SHashReduce(SHashReducer hash_reduce) const {
143  hash_reduce(node);
144  hash_reduce(attr_key);
145  hash_reduce(value);
146  hash_reduce(body);
147  }
148 
149  static constexpr const char* _type_key = "tir.AttrStmt";
151 };
152 
157 class AttrStmt : public Stmt {
158  public:
159  TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
160 
163 };
164 
168 class AssertStmtNode : public StmtNode {
169  public:
179 
181  v->Visit("condition", &condition);
182  v->Visit("message", &message);
183  v->Visit("body", &body);
184  v->Visit("span", &span);
185  }
186 
187  bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
188  return equal(condition, other->condition) && equal(message, other->message) &&
189  equal(body, other->body);
190  }
191 
192  void SHashReduce(SHashReducer hash_reduce) const {
193  hash_reduce(condition);
194  hash_reduce(message);
195  hash_reduce(body);
196  }
197 
198  static constexpr const char* _type_key = "tir.AssertStmt";
200 };
201 
206 class AssertStmt : public Stmt {
207  public:
208  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
209 
212 };
213 
232 class StoreNode : public StmtNode {
233  public:
242 
244  v->Visit("buffer_var", &buffer_var);
245  v->Visit("value", &value);
246  v->Visit("index", &index);
247  v->Visit("predicate", &predicate);
248  v->Visit("span", &span);
249  }
250 
251  bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
252  return equal(buffer_var, other->buffer_var) && equal(value, other->value) &&
253  equal(index, other->index) && equal(predicate, other->predicate);
254  }
255 
256  void SHashReduce(SHashReducer hash_reduce) const {
257  hash_reduce(buffer_var);
258  hash_reduce(value);
259  hash_reduce(index);
260  hash_reduce(predicate);
261  }
262 
263  static constexpr const char* _type_key = "tir.Store";
265 };
266 
271 class Store : public Stmt {
272  public:
273  TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
274  Span span = Span());
275 
278 };
279 
290 class BufferStoreNode : public StmtNode {
291  public:
298 
300  v->Visit("buffer", &buffer);
301  v->Visit("value", &value);
302  v->Visit("indices", &indices);
303  v->Visit("span", &span);
304  }
305 
306  bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
307  return equal(buffer, other->buffer) && equal(value, other->value) &&
308  equal(indices, other->indices);
309  }
310 
311  void SHashReduce(SHashReducer hash_reduce) const {
312  hash_reduce(buffer);
313  hash_reduce(value);
314  hash_reduce(indices);
315  }
316 
317  static constexpr const char* _type_key = "tir.BufferStore";
319 };
320 
325 class BufferStore : public Stmt {
326  public:
327  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
328  Span span = Span());
329 
332 };
333 
345 class BufferRealizeNode : public StmtNode {
346  public:
355 
357  v->Visit("buffer", &buffer);
358  v->Visit("bounds", &bounds);
359  v->Visit("condition", &condition);
360  v->Visit("body", &body);
361  v->Visit("span", &span);
362  }
363 
364  bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
365  return equal(buffer, other->buffer) && equal(bounds, other->bounds) &&
366  equal(condition, other->condition) && equal(body, other->body);
367  }
368 
369  void SHashReduce(SHashReducer hash_reduce) const {
370  hash_reduce(buffer);
371  hash_reduce(bounds);
372  hash_reduce(condition);
373  hash_reduce(body);
374  }
375 
376  BufferRealizeNode() = default;
377  BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
378  Span span = Span())
379  : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition), body(body) {}
380 
381  static constexpr const char* _type_key = "tir.BufferRealize";
383 };
384 
389 class BufferRealize : public Stmt {
390  public:
391  TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
392  Span span = Span());
393 
396 };
397 
408 class ProducerStoreNode : public StmtNode {
409  public:
416 
418  v->Visit("producer", &producer);
419  v->Visit("value", &value);
420  v->Visit("indices", &indices);
421  v->Visit("span", &span);
422  }
423 
424  bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
425  return equal(producer, other->producer) && equal(value, other->value) &&
426  equal(indices, other->indices);
427  }
428 
429  void SHashReduce(SHashReducer hash_reduce) const {
430  hash_reduce(producer);
431  hash_reduce(value);
432  hash_reduce(indices);
433  }
434 
435  static constexpr const char* _type_key = "tir.ProducerStore";
437 };
438 
443 class ProducerStore : public Stmt {
444  public:
445  TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
446  Span span = Span());
447 
450 };
451 
464  public:
475 
477  v->Visit("producer", &producer);
478  v->Visit("bounds", &bounds);
479  v->Visit("condition", &condition);
480  v->Visit("body", &body);
481  v->Visit("storage_scope", &storage_scope);
482  v->Visit("span", &span);
483  }
484 
486  return equal(producer, other->producer) && equal(bounds, other->bounds) &&
487  equal(condition, other->condition) && equal(body, other->body) &&
488  equal(storage_scope, other->storage_scope);
489  }
490 
491  void SHashReduce(SHashReducer hash_reduce) const {
492  hash_reduce(producer);
493  hash_reduce(bounds);
494  hash_reduce(condition);
495  hash_reduce(body);
496  hash_reduce(storage_scope);
497  }
498 
499  static constexpr const char* _type_key = "tir.ProducerRealize";
501 };
502 
507 class ProducerRealize : public Stmt {
508  public:
509  TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
510  String storage_scope = "", Span span = Span());
511 
514 };
515 
519 class AllocateNode : public StmtNode {
520  public:
538 
540  v->Visit("buffer_var", &buffer_var);
541  v->Visit("dtype", &dtype);
542  v->Visit("extents", &extents);
543  v->Visit("condition", &condition);
544  v->Visit("body", &body);
545  v->Visit("annotations", &annotations);
546  v->Visit("span", &span);
547  }
548 
549  bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
550  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
551  equal(extents, other->extents) && equal(condition, other->condition) &&
552  equal(body, other->body) && equal(annotations, other->annotations);
553  }
554 
555  void SHashReduce(SHashReducer hash_reduce) const {
556  hash_reduce.DefHash(buffer_var);
557  hash_reduce(dtype);
558  hash_reduce(extents);
559  hash_reduce(condition);
560  hash_reduce(body);
561  hash_reduce(annotations);
562  }
563 
569  int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
576  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
577 
578  static constexpr const char* _type_key = "tir.Allocate";
579  static constexpr const bool _type_has_method_sequal_reduce = true;
580  static constexpr const bool _type_has_method_shash_reduce = true;
582 };
583 
588 class Allocate : public Stmt {
589  public:
590  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
591  Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
592  Span span = Span());
593 
596 };
597 
601 class AllocateConstNode : public StmtNode {
602  public:
626 
628  v->Visit("buffer_var", &buffer_var);
629  v->Visit("data", &data);
630  v->Visit("irmod_storage_idx", &irmod_storage_idx);
631  v->Visit("dtype", &dtype);
632  v->Visit("extents", &extents);
633  v->Visit("body", &body);
634  v->Visit("annotations", &annotations);
635  v->Visit("span", &span);
636  }
637 
638  bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
639  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
640  equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) &&
641  equal(annotations, other->annotations);
642  }
643 
644  void SHashReduce(SHashReducer hash_reduce) const {
645  hash_reduce.DefHash(buffer_var);
646  hash_reduce(dtype);
647  hash_reduce(extents);
648  hash_reduce(body);
649  hash_reduce(annotations);
650  hash_reduce(data);
651  }
652 
658  int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
665  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
666 
667  static constexpr const char* _type_key = "tir.AllocateConst";
668  static constexpr const bool _type_has_method_sequal_reduce = true;
669  static constexpr const bool _type_has_method_shash_reduce = true;
671 };
672 
677 class AllocateConst : public Stmt {
678  public:
679  /* The constructor to create a IRNode with constant data
680  * depending on the type of ObjectRef, it will either
681  * create AllocateConstNode with irmod_storage_idx or data
682  */
683  TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
684  ObjectRef data_or_idx, Stmt body,
686  Span span = Span());
689 };
690 
692 class DeclBufferNode : public StmtNode {
693  public:
698 
700  v->Visit("buffer", &buffer);
701  v->Visit("body", &body);
702  v->Visit("span", &span);
703  }
704 
705  bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const {
706  return equal(buffer, other->buffer) && equal(body, other->body);
707  }
708 
709  void SHashReduce(SHashReducer hash_reduce) const {
710  hash_reduce(buffer);
711  hash_reduce(body);
712  }
713 
714  static constexpr const char* _type_key = "tir.DeclBuffer";
716 };
717 
719 class DeclBuffer : public Stmt {
720  public:
721  TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
724 };
725 
730 class SeqStmtNode : public StmtNode {
731  public:
734 
736  size_t size() const { return seq.size(); }
740  Stmt operator[](size_t index) const { return seq[index]; }
741 
743  v->Visit("seq", &seq);
744  v->Visit("span", &span);
745  }
746 
747  bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
748  return equal(seq, other->seq);
749  }
750 
751  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
752 
753  static constexpr const char* _type_key = "tir.SeqStmt";
755 };
756 
758 class SeqStmt : public Stmt {
759  public:
765  TVM_DLL explicit SeqStmt(Array<Stmt> seq, Span span = Span());
766 
768  size_t size() const { return operator->()->size(); }
772  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
789  template <typename... Args>
790  static Stmt Flatten(Args&&... seq_args) {
791  Array<Stmt> seq;
792  runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
793  if (seq.size() == 1) return seq[0];
794  return SeqStmt(seq);
795  }
797  class Flattener {
798  public:
799  explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}
800 
801  void operator()(size_t i, const Stmt& stmt) const {
802  if (!stmt.defined()) return;
803  if (auto* op = stmt.as<SeqStmtNode>()) {
804  operator()(0, op->seq);
805  } else {
806  seq_->push_back(stmt);
807  }
808  }
809 
810  template <typename T>
811  void operator()(size_t i, const T& seq) const {
812  for (auto v : seq) {
813  this->operator()(0, v);
814  }
815  }
816 
817  private:
818  Array<Stmt>* seq_;
819  };
820 
823 };
824 
828 class IfThenElseNode : public StmtNode {
829  public:
836 
838  v->Visit("condition", &condition);
839  v->Visit("then_case", &then_case);
840  v->Visit("else_case", &else_case);
841  v->Visit("span", &span);
842  }
843 
844  bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
845  return equal(condition, other->condition) && equal(then_case, other->then_case) &&
846  equal(else_case, other->else_case);
847  }
848 
849  void SHashReduce(SHashReducer hash_reduce) const {
850  hash_reduce(condition);
851  hash_reduce(then_case);
852  hash_reduce(else_case);
853  }
854 
855  static constexpr const char* _type_key = "tir.IfThenElse";
857 };
858 
863 class IfThenElse : public Stmt {
864  public:
865  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case = NullOpt,
866  Span span = Span());
867 
870 };
871 
878 class EvaluateNode : public StmtNode {
879  public:
882 
884  v->Visit("value", &value);
885  v->Visit("span", &span);
886  }
887 
888  bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
889  return equal(value, other->value);
890  }
891 
892  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
893 
894  static constexpr const char* _type_key = "tir.Evaluate";
896 };
897 
902 class Evaluate : public Stmt {
903  public:
904  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
905 
906  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
907 
910 };
911 
/*!
 * \brief The kind of a for loop (see ForNode::kind).
 *
 * The underlying integer values are spelled out explicitly so they remain stable
 * if enumerators are ever reordered.
 */
enum class ForKind : int {
  /*! \brief Plain loop; ForKind2String names it "serial". */
  kSerial = 0,
  /*! \brief ForKind2String names it "parallel". */
  kParallel = 1,
  /*! \brief ForKind2String names it "vectorized". */
  kVectorized = 2,
  /*! \brief ForKind2String names it "unroll". */
  kUnrolled = 3,
  /*!
   * \brief Loop bound to a thread; ForNode::thread_binding is only valid for this kind.
   *        ForKind2String names it "thread_binding".
   */
  kThreadBinding = 4
};
939 
950 class ForNode : public StmtNode {
951  public:
976 
978  v->Visit("loop_var", &loop_var);
979  v->Visit("min", &min);
980  v->Visit("extent", &extent);
981  v->Visit("kind", &kind);
982  v->Visit("body", &body);
983  v->Visit("thread_binding", &thread_binding);
984  v->Visit("annotations", &annotations);
985  v->Visit("span", &span);
986  }
987 
988  bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
989  return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
990  equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
991  equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations);
992  }
993 
994  void SHashReduce(SHashReducer hash_reduce) const {
995  hash_reduce.DefHash(loop_var);
996  hash_reduce(min);
997  hash_reduce(extent);
998  hash_reduce(kind);
999  hash_reduce(body);
1000  hash_reduce(thread_binding);
1001  hash_reduce(annotations);
1002  }
1003 
1004  static constexpr const char* _type_key = "tir.For";
1006 };
1007 
1012 class For : public Stmt {
1013  public:
1014  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
1015  Optional<IterVar> thread_binding = NullOpt,
1017 
1020 };
1021 
1032 class WhileNode : public StmtNode {
1033  public:
1038 
1040  v->Visit("condition", &condition);
1041  v->Visit("body", &body);
1042  v->Visit("span", &span);
1043  }
1044 
1045  bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
1046  return equal(condition, other->condition) && equal(body, other->body);
1047  }
1048 
1049  void SHashReduce(SHashReducer hash_reduce) const {
1050  hash_reduce(condition);
1051  hash_reduce(body);
1052  }
1053 
1054  static constexpr const char* _type_key = "tir.While";
1056 };
1057 
1062 class While : public Stmt {
1063  public:
1064  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
1065 
1068 };
1069 
1073 class PrefetchNode : public StmtNode {
1074  public:
1079 
1081  v->Visit("buffer", &buffer);
1082  v->Visit("bounds", &bounds);
1083  v->Visit("span", &span);
1084  }
1085 
1086  bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
1087  return equal(buffer, other->buffer) && equal(bounds, other->bounds);
1088  }
1089 
1090  void SHashReduce(SHashReducer hash_reduce) const {
1091  hash_reduce(buffer);
1092  hash_reduce(bounds);
1093  }
1094 
1095  PrefetchNode() = default;
1097  : StmtNode(span), buffer(buffer), bounds(bounds) {}
1098 
1099  static constexpr const char* _type_key = "tir.Prefetch";
1101 };
1102 
1107 class Prefetch : public Stmt {
1108  public:
1109  TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
1110 
1113 };
1114 
1118 class BufferRegionNode : public Object {
1119  public:
1124 
1126  v->Visit("buffer", &buffer);
1127  v->Visit("region", &region);
1128  }
1129 
1130  bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
1131  return equal(buffer, other->buffer) && equal(region, other->region);
1132  }
1133 
1134  void SHashReduce(SHashReducer hash_reduce) const {
1135  hash_reduce(buffer);
1136  hash_reduce(region);
1137  }
1138 
1139  static constexpr const char* _type_key = "tir.BufferRegion";
1140  static constexpr const bool _type_has_method_sequal_reduce = true;
1141  static constexpr const bool _type_has_method_shash_reduce = true;
1143 };
1144 
1149 class BufferRegion : public ObjectRef {
1150  public:
1151  TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
1152 
1158  TVM_DLL static BufferRegion FullRegion(Buffer buffer);
1159 
1166  TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
1167 
1170 };
1171 
1182  public:
1187 
1189  v->Visit("buffer", &buffer);
1190  v->Visit("source", &source);
1191  }
1192 
1194  return equal(buffer, other->buffer) && equal(source, other->source);
1195  }
1196 
1197  void SHashReduce(SHashReducer hash_reduce) const {
1198  hash_reduce(buffer);
1199  hash_reduce(source);
1200  }
1201 
1202  static constexpr const char* _type_key = "tir.MatchBufferRegion";
1203  static constexpr const bool _type_has_method_sequal_reduce = true;
1204  static constexpr const bool _type_has_method_shash_reduce = true;
1206 };
1207 
1213  public:
1214  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
1215 
1218 };
1219 
1241 class BlockNode : public StmtNode {
1242  public:
1267 
1269  v->Visit("iter_vars", &iter_vars);
1270  v->Visit("reads", &reads);
1271  v->Visit("writes", &writes);
1272  v->Visit("name_hint", &name_hint);
1273  v->Visit("body", &body);
1274  v->Visit("init", &init);
1275  v->Visit("alloc_buffers", &alloc_buffers);
1276  v->Visit("match_buffers", &match_buffers);
1277  v->Visit("annotations", &annotations);
1278  }
1279 
1280  bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
1281  // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
1282  return equal.DefEqual(iter_vars, other->iter_vars) &&
1283  equal(alloc_buffers, other->alloc_buffers) &&
1284  equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
1285  equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
1286  equal(annotations, other->annotations);
1287  }
1288 
1289  void SHashReduce(SHashReducer hash_reduce) const {
1290  hash_reduce.DefHash(iter_vars);
1291  hash_reduce(alloc_buffers);
1292  hash_reduce(match_buffers);
1293  hash_reduce(reads);
1294  hash_reduce(writes);
1295  hash_reduce(body);
1296  hash_reduce(init);
1297  hash_reduce(annotations);
1298  }
1299 
1300  static constexpr const char* _type_key = "tir.Block";
1302 };
1303 
1308 class Block : public Stmt {
1309  public:
1310  TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
1311  Array<BufferRegion> writes, String name_hint, Stmt body,
1312  Optional<Stmt> init = NullOpt,
1313  Array<Buffer> alloc_buffers = Array<Buffer>(),
1316  Span span = Span());
1317 
1320 };
1321 
1325 class BlockRealizeNode : public StmtNode {
1326  public:
1336 
1338  v->Visit("iter_values", &iter_values);
1339  v->Visit("predicate", &predicate);
1340  v->Visit("block", &block);
1341  }
1342 
1343  bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
1344  return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
1345  equal(block, other->block);
1346  }
1347 
1348  void SHashReduce(SHashReducer hash_reduce) const {
1349  hash_reduce(iter_values);
1350  hash_reduce(predicate);
1351  hash_reduce(block);
1352  }
1353 
1354  static constexpr const char* _type_key = "tir.BlockRealize";
1356 };
1357 
1362 class BlockRealize : public Stmt {
1363  public:
1364  TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
1365  Span span = Span());
1366 
1369 };
1370 
1372 namespace attr {
1373 // The above attr does not pass to ir stage.
1375 constexpr const char* thread_extent = "thread_extent";
1377 constexpr const char* virtual_thread = "virtual_thread";
1379 constexpr const char* coproc_scope = "coproc_scope";
1384 constexpr const char* coproc_uop_scope = "coproc_uop_scope";
1386 constexpr const char* volatile_scope = "volatile_scope";
1392 constexpr const char* extern_scope = "extern_scope";
1397 constexpr const char* compute_scope = "compute_scope";
1399 constexpr const char* storage_alignment = "storage_alignment";
1401 constexpr const char* realize_scope = "realize_scope";
1403 constexpr const char* device_id = "device_id";
1405 constexpr const char* device_type = "device_type";
1407 constexpr const char* loop_scope = "loop_scope";
1409 constexpr const char* reduce_scope = "reduce_scope";
1411 constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
1413 constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
1415 constexpr const char* pragma_scope_prefix = "pragma_";
1417 constexpr const char* pragma_import_c = "pragma_import_c";
1419 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
1421 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
1426 constexpr const char* prefetch_scope = "prefetch_scope";
1433 constexpr const char* layout_transforms = "layout_transforms";
1441 constexpr const char* axis_separators = "axis_separators";
1445 constexpr const char* double_buffer_scope = "double_buffer_scope";
1449 constexpr const char* double_buffer_write = "double_buffer_write";
1451 constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
1453 constexpr const char* scan_update_scope = "scan_update_scope";
1455 constexpr const char* scan_init_scope = "scan_init_scope";
1462 constexpr const char* buffer_dim_align = "buffer_dim_align";
1464 constexpr const char* buffer_bound = "buffer_bound";
1474 constexpr const char* buffer_bind_scope = "buffer_bind_scope";
1475 // Pipeline related attributes
1477 constexpr const char* channel_read_scope = "channel_read_scope";
1479 constexpr const char* channel_read_advance = "channel_read_advance";
1481 constexpr const char* channel_write_scope = "channel_write_scope";
1483 constexpr const char* channel_write_advance = "channel_write_advance";
1485 constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
1487 constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
1488 
1492 constexpr const char* device_scope = "device_scope";
1493 
1497 constexpr const char* async_scope = "async_scope";
1498 
1516 constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
1517 constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
1518 constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
1519 
1523 constexpr const char* fragment_shape = "fragment_shape";
1524 
1528 constexpr const char* fragment_layout = "fragment_layout";
1529 
1533 constexpr const char* hand_threaded = "hand_threaded";
1534 
1542 constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
1543 
1547 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
1548 
1550 constexpr const char* software_pipeline_stage = "software_pipeline_stage";
1551 
1553 constexpr const char* software_pipeline_order = "software_pipeline_order";
1554 
1559 constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";
1560 
1562 constexpr const char* layout_free_buffers = "layout_free_buffers";
1563 
1565 constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";
1566 
1568 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
1569 
1574 constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";
1575 
1578  "meta_schedule.thread_extent_low_inclusive";
1579 
1582  "meta_schedule.thread_extent_high_inclusive";
1583 
1586  "meta_schedule.random_compute_producer";
1587 
1589 constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
1590 
1592 constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
1593 
1595 constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
1596 
1598 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
1599 
1601 constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1602 
1604 constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
1608 constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
1609 
1614 constexpr const char* warp_execution = "warp_execution";
1615 
/*!
 * \brief Test whether an attribute key is a pragma key.
 * \param attr_key The attribute key to inspect.
 * \return true iff attr_key starts with the prefix "pragma_".
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // rfind anchored at position 0 can only match at the very start, i.e. a prefix test
  // that needs no hard-coded prefix length.
  return attr_key.rfind("pragma_", 0) == 0;
}
1624 
1625 } // namespace attr
1632 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1633 
1634 // overload printing of for type.
1635 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1636 
1637 // inline implementations
1638 inline const char* ForKind2String(ForKind t) {
1639  switch (t) {
1640  case ForKind::kSerial:
1641  return "serial";
1642  case ForKind::kParallel:
1643  return "parallel";
1644  case ForKind::kVectorized:
1645  return "vectorized";
1646  case ForKind::kUnrolled:
1647  return "unroll";
1649  return "thread_binding";
1650  }
1651  LOG(FATAL) << "Unknown ForKind" << t;
1652 }
1653 
1654 } // namespace tir
1655 } // namespace tvm
1656 #endif // TVM_TIR_STMT_H_
tvm::Span Span
Definition: base.h:65
bool SEqualReduce(const WhileNode *other, SEqualReducer equal) const
Definition: stmt.h:1045
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1417
Managed reference to StoreNode.
Definition: stmt.h:271
Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:835
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:429
Define certain auxiliary attribute for the body to be a symbolic value. This provide auxiliary inform...
Definition: stmt.h:118
PrimExpr index
The index locations to be stored.
Definition: stmt.h:239
String attr_key
the type key of the attribute
Definition: stmt.h:123
constexpr const char * pipeline_exec_scope
pipeline execution scope, implies the scope can be pipelined.
Definition: stmt.h:1487
Store value into multi-dimensional array that will be read by the consumer of the producer.
Definition: stmt.h:408
Buffer buffer
The buffer variable.
Definition: stmt.h:293
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:881
constexpr const char * device_type
The device type.
Definition: stmt.h:1405
Buffer buffer
The buffer to be prefetched.
Definition: stmt.h:1076
A prefetch hint for a buffer.
Definition: stmt.h:1073
Base node of all statements.
Definition: stmt.h:38
Declare a buffer that can be used in the body.
Definition: stmt.h:692
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1455
Managed reference to BlockNode.
Definition: stmt.h:1308
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1090
Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:1262
Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:297
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:772
PrimExpr value
The value to be stored.
Definition: stmt.h:237
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:129
A block realization node represents execution of the block at the binding values. ...
Definition: stmt.h:1325
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1479
Optional< Integer > irmod_storage_idx
If the PrimFunc containing the Stmt is added to IRModule, this is an optional index to indicate the i...
Definition: stmt.h:612
constexpr const char * realize_scope
Mark storage scope of realization.
Definition: stmt.h:1401
PrimExpr value
The value to be bound.
Definition: stmt.h:70
Stmt body
The body of realization.
Definition: stmt.h:472
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable binds to...
Definition: stmt.h:966
bool SEqualReduce(const AssertStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:187
Managed reference to PrefetchNode.
Definition: stmt.h:1107
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:256
constexpr const char * buffer_dim_align
Mark alignment of buffer dimension. stmt.node is Tensor; stmt.value is tvm_tuple(dim, align, offset). This gives a hint to require the stride of dim to be k * align + offset.
Definition: stmt.h:1462
Managed reference to AllocateConstNode.
Definition: stmt.h:677
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:417
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:797
constexpr const char * meta_schedule_layout_rewrite_preproc
Mark that a block is a preprocessor block for layout rewrite.
Definition: stmt.h:1604
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:102
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:50
constexpr const char * coproc_scope
Mark region is processed by a co-processor.
Definition: stmt.h:1379
Managed reference to ProducerRealizeNode.
Definition: stmt.h:507
Managed reference to IfThenElseNode.
Definition: stmt.h:863
constexpr const char * meta_schedule_auto_tensorize
Mark that a block should be further rewritten using tensorization.
Definition: stmt.h:1601
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:352
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:1464
constexpr const char * software_pipeline_async_stages
List stages in the software pipeline that should run asynchronously.
Definition: stmt.h:1559
StmtNode(Span span)
Definition: stmt.h:47
bool SEqualReduce(const BufferStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:306
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:730
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op. When this scope occurs, the stmt...
Definition: stmt.h:1474
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
bool SEqualReduce(const StoreNode *other, SEqualReducer equal) const
Definition: stmt.h:251
constexpr const char * meta_schedule_unroll_explicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1595
a named variable in TIR
Definition: var.h:88
Evaluate(int value, Span span=Span())
Definition: stmt.h:906
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:180
Var var
The variable.
Definition: stmt.h:68
IfThenElse statement.
Definition: stmt.h:828
Buffer buffer
The buffer being declared.
Definition: stmt.h:695
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:616
bool SEqualReduce(const IfThenElseNode *other, SEqualReducer equal) const
Definition: stmt.h:844
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1039
DataProducer producer
The producer that produces the data.
Definition: stmt.h:466
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1516
constexpr const char * compute_scope
Mark the scope as when computation starts to happen. This can hint some code generator to create a new ...
Definition: stmt.h:1397
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1477
constexpr const char * warp_execution
Mark that a block is executed by a warp. This implies the extent of threadIdx.x is warp size...
Definition: stmt.h:1614
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1134
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:1181
ForKind kind
The kind of the for loop.
Definition: stmt.h:959
constexpr const char * meta_schedule_thread_extent_high_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1581
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1419
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1621
void Prefetch(Buffer buffer, Array< Range > bounds)
The prefetch hint for a buffer.
bool SEqualReduce(const AllocateNode *other, SEqualReducer equal) const
Definition: stmt.h:549
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:977
bool SEqualReduce(const BufferRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:364
Managed reference to ForNode.
Definition: stmt.h:1012
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:476
Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:1244
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:97
Block block
The block to be realized.
Definition: stmt.h:1335
DataType dtype
The type of the buffer.
Definition: stmt.h:614
Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:1246
static constexpr const char * _type_key
Definition: stmt.h:49
Managed reference to BufferRegionNode.
Definition: stmt.h:1149
base class of all object containers.
Definition: object.h:167
AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array< PrimExpr > extents, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The allocate constant node.
constexpr const char * script_parsing_detect_access
Mark whether the script-completer needs to fill in missing access regions during script parsing...
Definition: stmt.h:1542
constexpr const char * meta_schedule_auto_tensorize_init
Mark that the init statement of a block should be further rewritten using tensorization.
Definition: stmt.h:1608
Stmt body
The body statement to be executed.
Definition: stmt.h:127
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:733
constexpr const char * extern_scope
Mark the scope as generated by extern primitive. Such a scope can contain an arbitrary IR program and we n...
Definition: stmt.h:1392
PrimExpr value
The value to be stored.
Definition: stmt.h:413
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:470
void operator()(size_t i, const Stmt &stmt) const
Definition: stmt.h:801
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:1121
Var buffer_var
The buffer variable.
Definition: stmt.h:522
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:1411
Managed reference to AssertStmtNode.
Definition: stmt.h:206
DataProducer producer
The producer to store the results into.
Definition: stmt.h:411
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1125
constexpr const char * meta_schedule_vectorize
Mark auto-vectorize setting on the block.
Definition: stmt.h:1592
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1403
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:699
bool SEqualReduce(const LetStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:81
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:1118
Map< String, ObjectRef > annotations
Additional annotations about the loop.
Definition: stmt.h:975
constexpr const char * rolling_buffer_scope
Mark realization for rolling buffer optimization.
Definition: stmt.h:1451
Array< PrimExpr > indices
The index arguments of the function.
Definition: stmt.h:415
Buffer buffer
The buffer variable.
Definition: stmt.h:348
bool SEqualReduce(const PrefetchNode *other, SEqualReducer equal) const
Definition: stmt.h:1086
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
size_t size() const
Definition: stmt.h:768
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:52
ObjectRef node
this is attribute about certain node
Definition: stmt.h:121
String storage_scope
The storage scope associated with this realization.
Definition: stmt.h:474
DataType dtype
The type of the buffer.
Definition: stmt.h:524
Annotate the bounds where the data produced by the producer need to be written and read in body...
Definition: stmt.h:463
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
constexpr const char * layout_free_buffers
Mark the buffers which are accessed as constants and whose layout can be transformed.
Definition: stmt.h:1562
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1517
default semantics – serial execution.
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:528
AllocateFrame Allocate(Array< PrimExpr > extents, DataType dtype, String storage_scope="", Optional< PrimExpr > condition=NullOpt, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The allocate node.
Definition: span.h:115
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1483
PrimExpr predicate
The predicate to mask which lanes would be stored.
Definition: stmt.h:241
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1492
Parallel execution on CPU.
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:790
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1348
BufferRegion source
The source buffer region.
Definition: stmt.h:1186
size_t size() const
Definition: array.h:420
Array< Range > region
The region array of the buffer region.
Definition: stmt.h:1123
BlockFrame Block(String name, bool no_realize=false)
The block declaration statement.
Stmt body
The body to be executed.
Definition: stmt.h:618
bool defined() const
Definition: object.h:544
Runtime primitive data type.
Definition: data_type.h:41
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:299
String name_hint
The name_hint of the block.
Definition: stmt.h:1250
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:955
constexpr const char * coproc_uop_scope
Mark region creates coprocessor micro ops, can be reused if corresponding variable is independent...
Definition: stmt.h:1384
bool SEqualReduce(const ForNode *other, SEqualReducer equal) const
Definition: stmt.h:988
bool SEqualReduce(const ProducerStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:424
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true...
Definition: stmt.h:1333
PrimExpr condition
The condition.
Definition: stmt.h:831
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:627
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
constexpr const char * fragment_layout
Mark the layout of a TensorCore fragment.
Definition: stmt.h:1528
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:1409
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:569
A While loop.
Definition: stmt.h:1032
Stmt body
The body of the for loop.
Definition: stmt.h:961
constexpr const char * manifest_shared_memory_local_stage
Mark the local stage for the shared memory access should be added.
Definition: stmt.h:1565
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:751
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:142
constexpr const char * meta_schedule_thread_extent_low_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1577
constexpr const char * prefetch_scope
Mark of prefetch scope, value=offset, run prefetch of Tensor on the current loop scope.
Definition: stmt.h:1426
Stmt body
The body of the while loop.
Definition: stmt.h:1037
Stmt body
The body of realization.
Definition: stmt.h:354
constexpr const char * layout_transforms
Marks the layout transforms to be used for a tensor.
Definition: stmt.h:1433
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:1386
Container of all statements.
Definition: stmt.h:57
Stmt body
Body which this assertion holds true. Will be executed after the assertion.
Definition: stmt.h:178
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1188
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:742
void BufferStore(Buffer buffer, PrimExpr value, Array< PrimExpr > indices)
Store data in a buffer.
WhileFrame While(PrimExpr condition)
Create a while loop.
constexpr const char * meta_schedule_random_compute_producer
Mark the block whose producer needs to be applied by rule Random-Compute-Location.
Definition: stmt.h:1585
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:51
The loop variable is bound to a thread in an environment. In the final stage of lowering, the loop is simply removed and the loop variable is mapped to the corresponding context thread.
Reference to string objects.
Definition: string.h:97
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1497
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1445
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1337
Managed reference to BufferRealizeNode.
Definition: stmt.h:389
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Allocate a buffer that can be used in body.
Definition: stmt.h:519
The loop is vectorized.
Definition: var.h:230
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:837
size_t size() const
Definition: stmt.h:736
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:878
The execution is unrolled.
Definition: var.h:226
PrimExpr condition
The termination condition.
Definition: stmt.h:1035
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:1413
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1197
Managed reference to AllocateNode.
Definition: stmt.h:588
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1049
A block is a basic schedule unit in TIR.
Definition: stmt.h:1241
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:243
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1375
Optional< runtime::NDArray > data
The optional data associated to the constant.
Definition: stmt.h:607
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:994
Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:1264
Managed reference to BlockRealizeNode.
Definition: stmt.h:1362
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:1212
Managed reference to DeclBufferNode.
Definition: stmt.h:719
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
ForKind
The kind of the loop.
Definition: stmt.h:919
Allocate a buffer that can be used in body.
Definition: stmt.h:601
bool SEqualReduce(const EvaluateNode *other, SEqualReducer equal) const
Definition: stmt.h:888
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:625
BufferRealizeNode(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
Definition: stmt.h:377
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
constexpr const char * software_pipeline_order
Mark the order of a statement in the software pipeline.
Definition: stmt.h:1553
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:892
Store value to the buffer.
Definition: stmt.h:232
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:709
bool SEqualReduce(const DeclBufferNode *other, SEqualReducer equal) const
Definition: stmt.h:705
Managed reference to DataProducerNode.
Definition: buffer.h:293
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1441
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:369
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:849
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types...
Definition: buffer.h:160
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
bool SEqualReduce(const MatchBufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1193
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1080
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:125
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1518
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1289
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:555
constexpr const char * software_pipeline_stage
Mark the stage of a statement in the software pipeline.
Definition: stmt.h:1550
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:44
Store value to the high dimension buffer.
Definition: stmt.h:290
Managed reference to BufferStoreNode.
Definition: stmt.h:325
Region bounds
Bounds to be realized.
Definition: stmt.h:468
Var buffer_var
The buffer variable.
Definition: stmt.h:604
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:833
Managed reference to WhileNode.
Definition: stmt.h:1062
void operator()(size_t i, const T &seq) const
Definition: stmt.h:811
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:168
bool SEqualReduce(const ProducerRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:485
Managed reference to ProducerStoreNode.
Definition: stmt.h:443
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:356
PrimExpr condition
Condition to be checked.
Definition: stmt.h:171
constexpr const char * meta_schedule_parallel
Mark auto-parallel setting on the block.
Definition: stmt.h:1589
bool SEqualReduce(const BlockRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:1343
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1449
Stmt body
The body to be executed.
Definition: stmt.h:530
DeclBufferFrame DeclBuffer(Array< PrimExpr > shape, DataType dtype, String buffer_name, Optional< Var > data, Optional< Array< PrimExpr >> strides, Optional< PrimExpr > elem_offset, String storage_scope, int align, int offset_factor, String buffer_type, Optional< Array< IntImm >> axis_separators)
The buffer declaration frame.
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1533
constexpr const char * meta_schedule_unroll_implicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1598
Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:1248
Stmt body
The body block.
Definition: stmt.h:72
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:1407
Stmt body
The body of the block.
Definition: stmt.h:1252
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1485
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
bool SEqualReduce(const AttrStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:137
A for loop, with possible type annotations.
Definition: stmt.h:950
Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:1328
Var loop_var
The loop variable.
Definition: stmt.h:953
void Evaluate(PrimExpr value)
Evaluate the input expression.
bool SEqualReduce(const AllocateConstNode *other, SEqualReducer equal) const
Definition: stmt.h:638
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:311
Flattener(Array< Stmt > *seq)
Definition: stmt.h:799
PrimExpr value
The value to be stored.
Definition: stmt.h:295
bool SEqualReduce(const BlockNode *other, SEqualReducer equal) const
Definition: stmt.h:1280
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:86
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object)
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1481
PrefetchNode(Buffer buffer, Array< Range > bounds, Span span=Span())
Definition: stmt.h:1096
Let binding, bind var to value, then run body.
Definition: stmt.h:65
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:173
Reference to PrimExprNode.
Definition: expr.h:112
Sequence statement.
Definition: stmt.h:758
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:539
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
constexpr const char * meta_schedule_cooperative_fetch
Mark that the loop should be further skip and bound to environment threads to enable cooperative fetc...
Definition: stmt.h:1574
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:644
bool SEqualReduce(const BufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1130
constexpr const char * meta_schedule_tiling_structure
Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.
Definition: stmt.h:1568
Managed reference to EvaluateNode.
Definition: stmt.h:902
constexpr const char * fragment_shape
Mark the shape of a TensorCore fragment.
Definition: stmt.h:1523
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:957
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
const char * ForKind2String(ForKind t)
Definition: stmt.h:1638
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1415
Array< Range > bounds
Bounds to be realized.
Definition: stmt.h:350
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1453
Buffer buffer
The target buffer.
Definition: stmt.h:1184
Stmt body
The body to be executed.
Definition: stmt.h:697
Annotate the region where the buffer need to be read and write in the body. We only need to allocate ...
Definition: stmt.h:345
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1268
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:491
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:740
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:192
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:74
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
Managed reference to AttrStmtNode.
Definition: stmt.h:157
Array< Range > bounds
Bounds to be prefetched.
Definition: stmt.h:1078
Map< String, ObjectRef > annotations
The annotation of the block.
Definition: stmt.h:1266
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1421
Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block...
Definition: stmt.h:1260
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:537
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:526
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1399
Managed reference to LetStmtNode.
Definition: stmt.h:100
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:883
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:179
bool SEqualReduce(const SeqStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:747
Var buffer_var
The buffer variable.
Definition: stmt.h:235
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:1377
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:658
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1547