tvm
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/tir/expr.h>
28 
29 #include <string>
30 #include <type_traits>
31 #include <utility>
32 #include <vector>
33 
34 namespace tvm {
35 namespace tir {
36 
38 class StmtNode : public Object {
39  public:
44  mutable Span span;
45 
46  StmtNode() = default;
47  explicit StmtNode(Span span) : span(span) {}
48 
49  static constexpr const char* _type_key = "tir.Stmt";
50  static constexpr const bool _type_has_method_sequal_reduce = true;
51  static constexpr const bool _type_has_method_shash_reduce = true;
52  static constexpr const uint32_t _type_child_slots = 15;
54 };
55 
57 class Stmt : public ObjectRef {
58  public:
60 };
61 
65 class LetStmtNode : public StmtNode {
66  public:
73 
75  v->Visit("var", &var);
76  v->Visit("value", &value);
77  v->Visit("body", &body);
78  v->Visit("span", &span);
79  }
80 
81  bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
82  return equal.DefEqual(var, other->var) && equal(value, other->value) &&
83  equal(body, other->body);
84  }
85 
86  void SHashReduce(SHashReducer hash_reduce) const {
87  hash_reduce.DefHash(var);
88  hash_reduce(value);
89  hash_reduce(body);
90  }
91 
92  static constexpr const char* _type_key = "tir.LetStmt";
94 };
95 
100 class LetStmt : public Stmt {
101  public:
102  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
103 
105 };
106 
117 class AttrStmtNode : public StmtNode {
118  public:
127 
129  v->Visit("node", &node);
130  v->Visit("attr_key", &attr_key);
131  v->Visit("value", &value);
132  v->Visit("body", &body);
133  v->Visit("span", &span);
134  }
135 
136  bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
137  return equal(node, other->node) && equal(attr_key, other->attr_key) &&
138  equal(value, other->value) && equal(body, other->body);
139  }
140 
141  void SHashReduce(SHashReducer hash_reduce) const {
142  hash_reduce(node);
143  hash_reduce(attr_key);
144  hash_reduce(value);
145  hash_reduce(body);
146  }
147 
148  static constexpr const char* _type_key = "tir.AttrStmt";
150 };
151 
156 class AttrStmt : public Stmt {
157  public:
158  TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
159 
161 };
162 
166 class AssertStmtNode : public StmtNode {
167  public:
177 
179  v->Visit("condition", &condition);
180  v->Visit("message", &message);
181  v->Visit("body", &body);
182  v->Visit("span", &span);
183  }
184 
185  bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
186  return equal(condition, other->condition) && equal(message, other->message) &&
187  equal(body, other->body);
188  }
189 
190  void SHashReduce(SHashReducer hash_reduce) const {
191  hash_reduce(condition);
192  hash_reduce(message);
193  hash_reduce(body);
194  }
195 
196  static constexpr const char* _type_key = "tir.AssertStmt";
198 };
199 
204 class AssertStmt : public Stmt {
205  public:
206  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
207 
209 };
210 
229 class StoreNode : public StmtNode {
230  public:
239 
241  v->Visit("buffer_var", &buffer_var);
242  v->Visit("value", &value);
243  v->Visit("index", &index);
244  v->Visit("predicate", &predicate);
245  v->Visit("span", &span);
246  }
247 
248  bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
249  return equal(buffer_var, other->buffer_var) && equal(value, other->value) &&
250  equal(index, other->index) && equal(predicate, other->predicate);
251  }
252 
253  void SHashReduce(SHashReducer hash_reduce) const {
254  hash_reduce(buffer_var);
255  hash_reduce(value);
256  hash_reduce(index);
257  hash_reduce(predicate);
258  }
259 
260  static constexpr const char* _type_key = "tir.Store";
262 };
263 
268 class Store : public Stmt {
269  public:
270  TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
271  Span span = Span());
272 
274 };
275 
286 class BufferStoreNode : public StmtNode {
287  public:
294 
296  v->Visit("buffer", &buffer);
297  v->Visit("value", &value);
298  v->Visit("indices", &indices);
299  v->Visit("span", &span);
300  }
301 
302  bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
303  return equal(buffer, other->buffer) && equal(value, other->value) &&
304  equal(indices, other->indices);
305  }
306 
307  void SHashReduce(SHashReducer hash_reduce) const {
308  hash_reduce(buffer);
309  hash_reduce(value);
310  hash_reduce(indices);
311  }
312 
313  static constexpr const char* _type_key = "tir.BufferStore";
315 };
316 
321 class BufferStore : public Stmt {
322  public:
323  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
324  Span span = Span());
325 
328 };
329 
341 class BufferRealizeNode : public StmtNode {
342  public:
351 
353  v->Visit("buffer", &buffer);
354  v->Visit("bounds", &bounds);
355  v->Visit("condition", &condition);
356  v->Visit("body", &body);
357  v->Visit("span", &span);
358  }
359 
360  bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
361  return equal(buffer, other->buffer) && equal(bounds, other->bounds) &&
362  equal(condition, other->condition) && equal(body, other->body);
363  }
364 
365  void SHashReduce(SHashReducer hash_reduce) const {
366  hash_reduce(buffer);
367  hash_reduce(bounds);
368  hash_reduce(condition);
369  hash_reduce(body);
370  }
371 
372  BufferRealizeNode() = default;
373  BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
374  Span span = Span())
375  : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition), body(body) {}
376 
377  static constexpr const char* _type_key = "tir.BufferRealize";
379 };
380 
385 class BufferRealize : public Stmt {
386  public:
387  TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
388  Span span = Span());
389 
392 };
393 
404 class ProducerStoreNode : public StmtNode {
405  public:
412 
414  v->Visit("producer", &producer);
415  v->Visit("value", &value);
416  v->Visit("indices", &indices);
417  v->Visit("span", &span);
418  }
419 
420  bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
421  return equal(producer, other->producer) && equal(value, other->value) &&
422  equal(indices, other->indices);
423  }
424 
425  void SHashReduce(SHashReducer hash_reduce) const {
426  hash_reduce(producer);
427  hash_reduce(value);
428  hash_reduce(indices);
429  }
430 
431  static constexpr const char* _type_key = "tir.ProducerStore";
433 };
434 
439 class ProducerStore : public Stmt {
440  public:
441  TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
442  Span span = Span());
443 
445 };
446 
459  public:
470 
472  v->Visit("producer", &producer);
473  v->Visit("bounds", &bounds);
474  v->Visit("condition", &condition);
475  v->Visit("body", &body);
476  v->Visit("storage_scope", &storage_scope);
477  v->Visit("span", &span);
478  }
479 
481  return equal(producer, other->producer) && equal(bounds, other->bounds) &&
482  equal(condition, other->condition) && equal(body, other->body) &&
483  equal(storage_scope, other->storage_scope);
484  }
485 
486  void SHashReduce(SHashReducer hash_reduce) const {
487  hash_reduce(producer);
488  hash_reduce(bounds);
489  hash_reduce(condition);
490  hash_reduce(body);
491  hash_reduce(storage_scope);
492  }
493 
494  static constexpr const char* _type_key = "tir.ProducerRealize";
496 };
497 
502 class ProducerRealize : public Stmt {
503  public:
504  TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
505  String storage_scope = "", Span span = Span());
506 
508 };
509 
513 class AllocateNode : public StmtNode {
514  public:
532 
534  v->Visit("buffer_var", &buffer_var);
535  v->Visit("dtype", &dtype);
536  v->Visit("extents", &extents);
537  v->Visit("condition", &condition);
538  v->Visit("body", &body);
539  v->Visit("annotations", &annotations);
540  v->Visit("span", &span);
541  }
542 
543  bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
544  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
545  equal(extents, other->extents) && equal(condition, other->condition) &&
546  equal(body, other->body) && equal(annotations, other->annotations);
547  }
548 
549  void SHashReduce(SHashReducer hash_reduce) const {
550  hash_reduce.DefHash(buffer_var);
551  hash_reduce(dtype);
552  hash_reduce(extents);
553  hash_reduce(condition);
554  hash_reduce(body);
555  hash_reduce(annotations);
556  }
557 
563  int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
570  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
571 
572  static constexpr const char* _type_key = "tir.Allocate";
573  static constexpr const bool _type_has_method_sequal_reduce = true;
574  static constexpr const bool _type_has_method_shash_reduce = true;
576 };
577 
582 class Allocate : public Stmt {
583  public:
584  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
585  Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
586  Span span = Span());
587 
590 };
591 
595 class AllocateConstNode : public StmtNode {
596  public:
620 
622  v->Visit("buffer_var", &buffer_var);
623  v->Visit("data", &data);
624  v->Visit("irmod_storage_idx", &irmod_storage_idx);
625  v->Visit("dtype", &dtype);
626  v->Visit("extents", &extents);
627  v->Visit("body", &body);
628  v->Visit("annotations", &annotations);
629  v->Visit("span", &span);
630  }
631 
632  bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
633  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
634  equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) &&
635  equal(annotations, other->annotations);
636  }
637 
638  void SHashReduce(SHashReducer hash_reduce) const {
639  hash_reduce.DefHash(buffer_var);
640  hash_reduce(dtype);
641  hash_reduce(extents);
642  hash_reduce(body);
643  hash_reduce(annotations);
644  hash_reduce(data);
645  }
646 
652  int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
659  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
660 
661  static constexpr const char* _type_key = "tir.AllocateConst";
662  static constexpr const bool _type_has_method_sequal_reduce = true;
663  static constexpr const bool _type_has_method_shash_reduce = true;
665 };
666 
671 class AllocateConst : public Stmt {
672  public:
673  /* The constructor to create a IRNode with constant data
674  * depending on the type of ObjectRef, it will either
675  * create AllocateConstNode with irmod_storage_idx or data
676  */
677  TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
678  ObjectRef data_or_idx, Stmt body,
680  Span span = Span());
682 };
683 
685 class DeclBufferNode : public StmtNode {
686  public:
691 
693  v->Visit("buffer", &buffer);
694  v->Visit("body", &body);
695  v->Visit("span", &span);
696  }
697 
698  bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const {
699  return equal(buffer, other->buffer) && equal(body, other->body);
700  }
701 
702  void SHashReduce(SHashReducer hash_reduce) const {
703  hash_reduce(buffer);
704  hash_reduce(body);
705  }
706 
707  static constexpr const char* _type_key = "tir.DeclBuffer";
709 };
710 
712 class DeclBuffer : public Stmt {
713  public:
714  TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
717 };
718 
723 class SeqStmtNode : public StmtNode {
724  public:
727 
729  size_t size() const { return seq.size(); }
733  Stmt operator[](size_t index) const { return seq[index]; }
734 
736  v->Visit("seq", &seq);
737  v->Visit("span", &span);
738  }
739 
740  bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
741  return equal(seq, other->seq);
742  }
743 
744  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
745 
746  static constexpr const char* _type_key = "tir.SeqStmt";
748 };
749 
751 class SeqStmt : public Stmt {
752  public:
758  TVM_DLL explicit SeqStmt(Array<Stmt> seq, Span span = Span());
759 
761  size_t size() const { return operator->()->size(); }
765  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
782  template <typename... Args>
783  static Stmt Flatten(Args&&... seq_args) {
784  Array<Stmt> seq;
785  runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
786  if (seq.size() == 1) return seq[0];
787  return SeqStmt(seq);
788  }
790  class Flattener {
791  public:
792  explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}
793 
794  void operator()(size_t i, const Stmt& stmt) const {
795  if (!stmt.defined()) return;
796  if (auto* op = stmt.as<SeqStmtNode>()) {
797  operator()(0, op->seq);
798  } else {
799  seq_->push_back(stmt);
800  }
801  }
802 
803  template <typename T>
804  void operator()(size_t i, const T& seq) const {
805  for (auto v : seq) {
806  this->operator()(0, v);
807  }
808  }
809 
810  private:
811  Array<Stmt>* seq_;
812  };
813 
815 };
816 
820 class IfThenElseNode : public StmtNode {
821  public:
828 
830  v->Visit("condition", &condition);
831  v->Visit("then_case", &then_case);
832  v->Visit("else_case", &else_case);
833  v->Visit("span", &span);
834  }
835 
836  bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
837  return equal(condition, other->condition) && equal(then_case, other->then_case) &&
838  equal(else_case, other->else_case);
839  }
840 
841  void SHashReduce(SHashReducer hash_reduce) const {
842  hash_reduce(condition);
843  hash_reduce(then_case);
844  hash_reduce(else_case);
845  }
846 
847  static constexpr const char* _type_key = "tir.IfThenElse";
849 };
850 
855 class IfThenElse : public Stmt {
856  public:
857  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt(),
858  Span span = Span());
859 
861 };
862 
869 class EvaluateNode : public StmtNode {
870  public:
873 
875  v->Visit("value", &value);
876  v->Visit("span", &span);
877  }
878 
879  bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
880  return equal(value, other->value);
881  }
882 
883  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
884 
885  static constexpr const char* _type_key = "tir.Evaluate";
887 };
888 
893 class Evaluate : public Stmt {
894  public:
895  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
896 
897  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
898 
900 };
901 
/*!
 * \brief The kind of a TIR loop, stored in ForNode::kind.
 *
 * The integer value of every enumerator is spelled out explicitly.
 * ForKind2String renders them as "serial", "parallel", "vectorized",
 * "unroll" and "thread_binding" respectively.
 */
enum class ForKind : int {
  kSerial = 0,        ///< Serial loop.
  kParallel = 1,      ///< Parallel loop.
  kVectorized = 2,    ///< Vectorized loop.
  kUnrolled = 3,      ///< Unrolled loop.
  kThreadBinding = 4  ///< Thread-bound loop; ForNode::thread_binding is only valid for this kind.
};
929 
940 class ForNode : public StmtNode {
941  public:
966 
968  v->Visit("loop_var", &loop_var);
969  v->Visit("min", &min);
970  v->Visit("extent", &extent);
971  v->Visit("kind", &kind);
972  v->Visit("body", &body);
973  v->Visit("thread_binding", &thread_binding);
974  v->Visit("annotations", &annotations);
975  v->Visit("span", &span);
976  }
977 
978  bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
979  return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
980  equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
981  equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations);
982  }
983 
984  void SHashReduce(SHashReducer hash_reduce) const {
985  hash_reduce.DefHash(loop_var);
986  hash_reduce(min);
987  hash_reduce(extent);
988  hash_reduce(kind);
989  hash_reduce(body);
990  hash_reduce(thread_binding);
991  hash_reduce(annotations);
992  }
993 
994  static constexpr const char* _type_key = "tir.For";
996 };
997 
1002 class For : public Stmt {
1003  public:
1004  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
1005  Optional<IterVar> thread_binding = NullOpt,
1007 
1010 };
1011 
1022 class WhileNode : public StmtNode {
1023  public:
1028 
1030  v->Visit("condition", &condition);
1031  v->Visit("body", &body);
1032  v->Visit("span", &span);
1033  }
1034 
1035  bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
1036  return equal(condition, other->condition) && equal(body, other->body);
1037  }
1038 
1039  void SHashReduce(SHashReducer hash_reduce) const {
1040  hash_reduce(condition);
1041  hash_reduce(body);
1042  }
1043 
1044  static constexpr const char* _type_key = "tir.While";
1046 };
1047 
1052 class While : public Stmt {
1053  public:
1054  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
1055 
1057 };
1058 
1062 class PrefetchNode : public StmtNode {
1063  public:
1068 
1070  v->Visit("buffer", &buffer);
1071  v->Visit("bounds", &bounds);
1072  v->Visit("span", &span);
1073  }
1074 
1075  bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
1076  return equal(buffer, other->buffer) && equal(bounds, other->bounds);
1077  }
1078 
1079  void SHashReduce(SHashReducer hash_reduce) const {
1080  hash_reduce(buffer);
1081  hash_reduce(bounds);
1082  }
1083 
1084  PrefetchNode() = default;
1086  : StmtNode(span), buffer(buffer), bounds(bounds) {}
1087 
1088  static constexpr const char* _type_key = "tir.Prefetch";
1090 };
1091 
1096 class Prefetch : public Stmt {
1097  public:
1098  TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
1099 
1101 };
1102 
1106 class BufferRegionNode : public Object {
1107  public:
1112 
1114  v->Visit("buffer", &buffer);
1115  v->Visit("region", &region);
1116  }
1117 
1118  bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
1119  return equal(buffer, other->buffer) && equal(region, other->region);
1120  }
1121 
1122  void SHashReduce(SHashReducer hash_reduce) const {
1123  hash_reduce(buffer);
1124  hash_reduce(region);
1125  }
1126 
1127  static constexpr const char* _type_key = "tir.BufferRegion";
1128  static constexpr const bool _type_has_method_sequal_reduce = true;
1129  static constexpr const bool _type_has_method_shash_reduce = true;
1131 };
1132 
1137 class BufferRegion : public ObjectRef {
1138  public:
1139  TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
1140 
1146  TVM_DLL static BufferRegion FullRegion(Buffer buffer);
1147 
1154  TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
1155 
1158 };
1159 
1170  public:
1175 
1177  v->Visit("buffer", &buffer);
1178  v->Visit("source", &source);
1179  }
1180 
1182  return equal(buffer, other->buffer) && equal(source, other->source);
1183  }
1184 
1185  void SHashReduce(SHashReducer hash_reduce) const {
1186  hash_reduce(buffer);
1187  hash_reduce(source);
1188  }
1189 
1190  static constexpr const char* _type_key = "tir.MatchBufferRegion";
1191  static constexpr const bool _type_has_method_sequal_reduce = true;
1192  static constexpr const bool _type_has_method_shash_reduce = true;
1194 };
1195 
1201  public:
1202  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
1203 
1205 };
1206 
1228 class BlockNode : public StmtNode {
1229  public:
1254 
1256  v->Visit("iter_vars", &iter_vars);
1257  v->Visit("reads", &reads);
1258  v->Visit("writes", &writes);
1259  v->Visit("name_hint", &name_hint);
1260  v->Visit("body", &body);
1261  v->Visit("init", &init);
1262  v->Visit("alloc_buffers", &alloc_buffers);
1263  v->Visit("match_buffers", &match_buffers);
1264  v->Visit("annotations", &annotations);
1265  }
1266 
1267  bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
1268  // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
1269  return equal.DefEqual(iter_vars, other->iter_vars) &&
1270  equal(alloc_buffers, other->alloc_buffers) &&
1271  equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
1272  equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
1273  equal(annotations, other->annotations);
1274  }
1275 
1276  void SHashReduce(SHashReducer hash_reduce) const {
1277  hash_reduce.DefHash(iter_vars);
1278  hash_reduce(alloc_buffers);
1279  hash_reduce(match_buffers);
1280  hash_reduce(reads);
1281  hash_reduce(writes);
1282  hash_reduce(body);
1283  hash_reduce(init);
1284  hash_reduce(annotations);
1285  }
1286 
1287  static constexpr const char* _type_key = "tir.Block";
1289 };
1290 
1295 class Block : public Stmt {
1296  public:
1297  TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
1298  Array<BufferRegion> writes, String name_hint, Stmt body,
1299  Optional<Stmt> init = NullOpt,
1300  Array<Buffer> alloc_buffers = Array<Buffer>(),
1303  Span span = Span());
1304 
1307 };
1308 
1312 class BlockRealizeNode : public StmtNode {
1313  public:
1323 
1325  v->Visit("iter_values", &iter_values);
1326  v->Visit("predicate", &predicate);
1327  v->Visit("block", &block);
1328  }
1329 
1330  bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
1331  return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
1332  equal(block, other->block);
1333  }
1334 
1335  void SHashReduce(SHashReducer hash_reduce) const {
1336  hash_reduce(iter_values);
1337  hash_reduce(predicate);
1338  hash_reduce(block);
1339  }
1340 
1341  static constexpr const char* _type_key = "tir.BlockRealize";
1343 };
1344 
1349 class BlockRealize : public Stmt {
1350  public:
1351  TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
1352  Span span = Span());
1353 
1356 };
1357 
1359 namespace attr {
1360 // The above attr does not pass to ir stage.
1362 constexpr const char* thread_extent = "thread_extent";
1364 constexpr const char* virtual_thread = "virtual_thread";
1366 constexpr const char* coproc_scope = "coproc_scope";
1371 constexpr const char* coproc_uop_scope = "coproc_uop_scope";
1373 constexpr const char* volatile_scope = "volatile_scope";
1379 constexpr const char* extern_scope = "extern_scope";
1384 constexpr const char* compute_scope = "compute_scope";
1386 constexpr const char* storage_alignment = "storage_alignment";
1388 constexpr const char* realize_scope = "realize_scope";
1390 constexpr const char* device_id = "device_id";
1392 constexpr const char* device_type = "device_type";
1394 constexpr const char* loop_scope = "loop_scope";
1396 constexpr const char* reduce_scope = "reduce_scope";
1398 constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
1400 constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
1402 constexpr const char* pragma_scope_prefix = "pragma_";
1404 constexpr const char* pragma_import_c = "pragma_import_c";
1406 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
1408 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
1413 constexpr const char* prefetch_scope = "prefetch_scope";
1420 constexpr const char* layout_transforms = "layout_transforms";
1428 constexpr const char* axis_separators = "axis_separators";
1432 constexpr const char* double_buffer_scope = "double_buffer_scope";
1436 constexpr const char* double_buffer_write = "double_buffer_write";
1438 constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
1440 constexpr const char* scan_update_scope = "scan_update_scope";
1442 constexpr const char* scan_init_scope = "scan_init_scope";
1449 constexpr const char* buffer_dim_align = "buffer_dim_align";
1451 constexpr const char* buffer_bound = "buffer_bound";
1461 constexpr const char* buffer_bind_scope = "buffer_bind_scope";
1462 // Pipeline related attributes
1464 constexpr const char* channel_read_scope = "channel_read_scope";
1466 constexpr const char* channel_read_advance = "channel_read_advance";
1468 constexpr const char* channel_write_scope = "channel_write_scope";
1470 constexpr const char* channel_write_advance = "channel_write_advance";
1472 constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
1474 constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
1475 
1479 constexpr const char* device_scope = "device_scope";
1480 
1484 constexpr const char* async_scope = "async_scope";
1485 
1503 constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
1504 constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
1505 constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
1506 
1510 constexpr const char* fragment_shape = "fragment_shape";
1511 
1515 constexpr const char* fragment_layout = "fragment_layout";
1516 
1520 constexpr const char* hand_threaded = "hand_threaded";
1521 
1529 constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
1530 
1534 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
1535 
1537 constexpr const char* software_pipeline_stage = "software_pipeline_stage";
1538 
1540 constexpr const char* software_pipeline_order = "software_pipeline_order";
1541 
1546 constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";
1547 
1549 constexpr const char* layout_free_buffers = "layout_free_buffers";
1550 
1552 constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";
1553 
1555 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
1556 
1561 constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";
1562 
1565  "meta_schedule.thread_extent_low_inclusive";
1566 
1569  "meta_schedule.thread_extent_high_inclusive";
1570 
1573  "meta_schedule.random_compute_producer";
1574 
1576 constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
1577 
1579 constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
1580 
1582 constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
1583 
1585 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
1586 
1588 constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1589 
1591 constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
1595 constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
1596 
1601 constexpr const char* warp_execution = "warp_execution";
1602 
/*!
 * \brief Check whether an attribute key names a pragma.
 * \param attr_key The attribute key to test.
 * \return true iff attr_key starts with the prefix "pragma_".
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  const char* const prefix = "pragma_";
  // A reverse search capped at position 0 can only match at the very start,
  // so this is a prefix test that needs no hard-coded prefix length.
  return attr_key.rfind(prefix, 0) == 0;
}
1611 
1612 } // namespace attr
/*!
 * \brief Exported (TVM_DLL) factory declared here and defined out of line.
 * \param dtype The data type the returned expression refers to.
 * \param span Optional source span; defaults to an empty Span.
 * \return A PrimExpr.
 * NOTE(review): by its name this presumably builds a placeholder expression
 * that only carries `dtype` as a type annotation; confirm against the definition.
 */
1619 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1620 
1621 // Stream-printing overload for ForKind (declared here, defined out of line).
1622 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1623 
1624 // inline implementations
1625 inline const char* ForKind2String(ForKind t) {
1626  switch (t) {
1627  case ForKind::kSerial:
1628  return "serial";
1629  case ForKind::kParallel:
1630  return "parallel";
1631  case ForKind::kVectorized:
1632  return "vectorized";
1633  case ForKind::kUnrolled:
1634  return "unroll";
1636  return "thread_binding";
1637  }
1638  LOG(FATAL) << "Unknown ForKind" << t;
1639  return "Unknown";
1640 }
1641 
1642 } // namespace tir
1643 } // namespace tvm
1644 #endif // TVM_TIR_STMT_H_
tvm::Span Span
Definition: base.h:65
bool SEqualReduce(const WhileNode *other, SEqualReducer equal) const
Definition: stmt.h:1035
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1404
Managed reference to StoreNode.
Definition: stmt.h:268
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:425
Define certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:117
PrimExpr index
The index locations to be stored.
Definition: stmt.h:236
String attr_key
the type key of the attribute
Definition: stmt.h:122
constexpr const char * pipeline_exec_scope
Pipeline execution scope; implies the scope can be pipelined.
Definition: stmt.h:1474
Store value into multi-dimensional array that will be read by the consumer of the producer.
Definition: stmt.h:404
Buffer buffer
The buffer variable.
Definition: stmt.h:289
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:872
constexpr const char * device_type
The device type.
Definition: stmt.h:1392
Buffer buffer
The buffer to be prefetched.
Definition: stmt.h:1065
A prefetch hint for a buffer.
Definition: stmt.h:1062
Base node of all statements.
Definition: stmt.h:38
Declare a buffer that can be used in the body.
Definition: stmt.h:685
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1442
Managed reference to BlockNode.
Definition: stmt.h:1295
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1079
Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:1249
Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:293
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:765
PrimExpr value
The value to be stored.
Definition: stmt.h:234
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:128
A block realization node represents execution of the block at the binding values. ...
Definition: stmt.h:1312
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1466
Optional< Integer > irmod_storage_idx
If the PrimFunc containing the Stmt is added to IRModule, this is an optional index to indicate the i...
Definition: stmt.h:606
constexpr const char * realize_scope
Mark storage scope of realization.
Definition: stmt.h:1388
PrimExpr value
The value to be bound.
Definition: stmt.h:70
Stmt body
The body of realization.
Definition: stmt.h:467
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to...
Definition: stmt.h:956
bool SEqualReduce(const AssertStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:185
Managed reference to PrefetchNode.
Definition: stmt.h:1096
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:253
constexpr const char * buffer_dim_align
Mark alignment of buffer dimension. stmt.node is Tensor; stmt.value is tvm_tuple(dim, align, offset). This gives a hint to require the stride of dim to be k * align + offset.
Definition: stmt.h:1449
Managed reference to AllocateConstNode.
Definition: stmt.h:671
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:413
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:790
constexpr const char * meta_schedule_layout_rewrite_preproc
Mark that a block is a preprocessor block for layout rewrite.
Definition: stmt.h:1591
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:102
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:50
constexpr const char * coproc_scope
Mark region is processed by a co-processor.
Definition: stmt.h:1366
Managed reference to ProducerRealizeNode.
Definition: stmt.h:502
Managed reference to IfThenElseNode.
Definition: stmt.h:855
constexpr const char * meta_schedule_auto_tensorize
Mark that a block should be further rewritten using tensorization.
Definition: stmt.h:1588
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:348
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:1451
constexpr const char * software_pipeline_async_stages
List stages in the software pipeline that should run asynchronously.
Definition: stmt.h:1546
StmtNode(Span span)
Definition: stmt.h:47
bool SEqualReduce(const BufferStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:302
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:723
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op. When this scope occurs, the stmt...
Definition: stmt.h:1461
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
bool SEqualReduce(const StoreNode *other, SEqualReducer equal) const
Definition: stmt.h:248
constexpr const char * meta_schedule_unroll_explicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1582
a named variable in TIR
Definition: var.h:88
Evaluate(int value, Span span=Span())
Definition: stmt.h:897
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:178
Var var
The variable.
Definition: stmt.h:68
IfThenElse statement.
Definition: stmt.h:820
Buffer buffer
The buffer being declared.
Definition: stmt.h:688
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:610
bool SEqualReduce(const IfThenElseNode *other, SEqualReducer equal) const
Definition: stmt.h:836
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1029
DataProducer producer
The producer that produces the data.
Definition: stmt.h:461
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1503
constexpr const char * compute_scope
Mark the scope as when computation starts to happen. This can hint some code generator to create a new ...
Definition: stmt.h:1384
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1464
constexpr const char * warp_execution
Mark that a block is executed by a warp. This implies the extent of threadIdx.x is warp size...
Definition: stmt.h:1601
AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array< PrimExpr > extents, Map< String, ObjectRef > annotations=NullValue< Map< String, ObjectRef >>())
The allocate constant node.
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1122
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:1169
ForKind kind
The kind of the for loop.
Definition: stmt.h:949
constexpr const char * meta_schedule_thread_extent_high_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1568
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1406
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1608
void Prefetch(Buffer buffer, Array< Range > bounds)
The prefetch hint for a buffer.
bool SEqualReduce(const AllocateNode *other, SEqualReducer equal) const
Definition: stmt.h:543
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:967
bool SEqualReduce(const BufferRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:360
Managed reference to ForNode.
Definition: stmt.h:1002
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:471
Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:1231
Block block
The block to be realized.
Definition: stmt.h:1322
DataType dtype
The type of the buffer.
Definition: stmt.h:608
Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:1233
static constexpr const char * _type_key
Definition: stmt.h:49
Managed reference to BufferRegionNode.
Definition: stmt.h:1137
base class of all object containers.
Definition: object.h:167
constexpr const char * script_parsing_detect_access
Mark whether the script-completer needs to fill in missing access regions during script parsing...
Definition: stmt.h:1529
constexpr const char * meta_schedule_auto_tensorize_init
Mark that the init statement of a block should be further rewritten using tensorization.
Definition: stmt.h:1595
Stmt body
The body statement to be executed.
Definition: stmt.h:126
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:726
constexpr const char * extern_scope
Mark the scope as generated by extern primitive. Such scope can contain arbitrary ir program and we n...
Definition: stmt.h:1379
PrimExpr value
The value to be stored.
Definition: stmt.h:409
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:465
void operator()(size_t i, const Stmt &stmt) const
Definition: stmt.h:794
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:1109
Var buffer_var
The buffer variable.
Definition: stmt.h:516
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:1398
Managed reference to AssertStmtNode.
Definition: stmt.h:204
DataProducer producer
The producer to store the results into.
Definition: stmt.h:407
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1113
constexpr const char * meta_schedule_vectorize
Mark auto-vectorize setting on the block.
Definition: stmt.h:1579
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1390
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:692
bool SEqualReduce(const LetStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:81
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:1106
Map< String, ObjectRef > annotations
Additional annotations about the loop.
Definition: stmt.h:965
constexpr const char * rolling_buffer_scope
Mark realization for rolling buffer optimization.
Definition: stmt.h:1438
Array< PrimExpr > indices
The index arguments of the function.
Definition: stmt.h:411
Buffer buffer
The buffer variable.
Definition: stmt.h:344
bool SEqualReduce(const PrefetchNode *other, SEqualReducer equal) const
Definition: stmt.h:1075
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
size_t size() const
Definition: stmt.h:761
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:52
ObjectRef node
This is an attribute about a certain node.
Definition: stmt.h:120
String storage_scope
The storage scope associated with this realization.
Definition: stmt.h:469
DataType dtype
The type of the buffer.
Definition: stmt.h:518
Annotate the bounds where the data produced by the producer need to be written and read in body...
Definition: stmt.h:458
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
constexpr const char * layout_free_buffers
Mark the buffers that are accessed as const and whose layout can be transformed.
Definition: stmt.h:1549
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1504
default semantics – serial execution.
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:522
AllocateFrame Allocate(Array< PrimExpr > extents, DataType dtype, String storage_scope="", Optional< PrimExpr > condition=NullOpt, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The allocate node.
Definition: span.h:115
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1470
PrimExpr predicate
The predicate to mask which lanes would be stored.
Definition: stmt.h:238
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1479
Parallel execution on CPU.
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:783
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1335
BufferRegion source
The source buffer region.
Definition: stmt.h:1174
size_t size() const
Definition: array.h:418
Array< Range > region
The region array of the buffer region.
Definition: stmt.h:1111
BlockFrame Block(String name, bool no_realize=false)
The block declaration statement.
Stmt body
The body to be executed.
Definition: stmt.h:612
bool defined() const
Definition: object.h:544
Runtime primitive data type.
Definition: data_type.h:41
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:295
String name_hint
The name_hint of the block.
Definition: stmt.h:1237
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:945
constexpr const char * coproc_uop_scope
Mark region creates coprocessor micro ops, can be reused if corresponding variable is independent...
Definition: stmt.h:1371
bool SEqualReduce(const ForNode *other, SEqualReducer equal) const
Definition: stmt.h:978
bool SEqualReduce(const ProducerStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:420
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true...
Definition: stmt.h:1320
PrimExpr condition
The condition.
Definition: stmt.h:823
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:621
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
constexpr const char * fragment_layout
Mark the layout of a TensorCore fragment.
Definition: stmt.h:1515
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:1396
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:563
A While loop.
Definition: stmt.h:1022
Stmt body
The body of the for loop.
Definition: stmt.h:951
constexpr const char * manifest_shared_memory_local_stage
Mark the local stage for the shared memory access should be added.
Definition: stmt.h:1552
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:744
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:141
constexpr const char * meta_schedule_thread_extent_low_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1564
constexpr const char * prefetch_scope
Mark of prefetch scope, value=offset, run prefetch of Tensor on the current loop scope.
Definition: stmt.h:1413
Stmt body
The body of the while loop.
Definition: stmt.h:1027
Stmt body
The body of realization.
Definition: stmt.h:350
constexpr const char * layout_transforms
Marks the layout transforms to be used for a tensor.
Definition: stmt.h:1420
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:1373
Container of all statements.
Definition: stmt.h:57
Stmt body
Body which this assertion holds true. Will be executed after the assertion.
Definition: stmt.h:176
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1176
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:735
void BufferStore(Buffer buffer, PrimExpr value, Array< PrimExpr > indices)
Store data in a buffer.
WhileFrame While(PrimExpr condition)
Create a while loop.
constexpr const char * meta_schedule_random_compute_producer
Mark the block whose producer needs to be applied by rule Random-Compute-Location.
Definition: stmt.h:1572
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:51
The loop variable is bound to a thread in an environment. In the final stage of lowering, the loop is simply removed and the loop variable is mapped to the corresponding context thread.
Reference to string objects.
Definition: string.h:97
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1484
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1432
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1324
Managed reference to BufferRealizeNode.
Definition: stmt.h:385
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Allocate a buffer that can be used in body.
Definition: stmt.h:513
The loop is vectorized.
Definition: var.h:230
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:829
size_t size() const
Definition: stmt.h:729
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:869
The execution is unrolled.
Definition: var.h:226
PrimExpr condition
The termination condition.
Definition: stmt.h:1025
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:1400
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1185
Managed reference to AllocateNode.
Definition: stmt.h:582
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1039
A block is a basic schedule unit in TIR.
Definition: stmt.h:1228
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:240
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1362
Optional< runtime::NDArray > data
The optional data associated to the constant.
Definition: stmt.h:601
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:984
Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:1251
Managed reference to BlockRealizeNode.
Definition: stmt.h:1349
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:1200
Managed reference to DeclBufferNode.
Definition: stmt.h:712
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
ForKind
The kind of the loop.
Definition: stmt.h:909
Allocate a buffer that can be used in body.
Definition: stmt.h:595
bool SEqualReduce(const EvaluateNode *other, SEqualReducer equal) const
Definition: stmt.h:879
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:619
BufferRealizeNode(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
Definition: stmt.h:373
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
constexpr const char * software_pipeline_order
Mark the order of a statement in the software pipeline.
Definition: stmt.h:1540
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:883
Store value to the buffer.
Definition: stmt.h:229
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:702
bool SEqualReduce(const DeclBufferNode *other, SEqualReducer equal) const
Definition: stmt.h:698
Managed reference to DataProducerNode.
Definition: buffer.h:293
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1428
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:365
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:841
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types...
Definition: buffer.h:160
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
bool SEqualReduce(const MatchBufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1181
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1069
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:124
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1505
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1276
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:549
constexpr const char * software_pipeline_stage
Mark the stage of a statement in the software pipeline.
Definition: stmt.h:1537
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:44
Store value to the high dimension buffer.
Definition: stmt.h:286
Managed reference to BufferStoreNode.
Definition: stmt.h:321
Region bounds
Bounds to be realized.
Definition: stmt.h:463
Var buffer_var
The buffer variable.
Definition: stmt.h:598
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:825
Managed reference to WhileNode.
Definition: stmt.h:1052
void operator()(size_t i, const T &seq) const
Definition: stmt.h:804
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:166
bool SEqualReduce(const ProducerRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:480
Managed reference to ProducerStoreNode.
Definition: stmt.h:439
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:352
PrimExpr condition
Condition to be checked.
Definition: stmt.h:169
constexpr const char * meta_schedule_parallel
Mark auto-parallel setting on the block.
Definition: stmt.h:1576
bool SEqualReduce(const BlockRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:1330
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1436
Stmt body
The body to be executed.
Definition: stmt.h:524
DeclBufferFrame DeclBuffer(Array< PrimExpr > shape, DataType dtype, String buffer_name, Optional< Var > data, Optional< Array< PrimExpr >> strides, Optional< PrimExpr > elem_offset, String storage_scope, int align, int offset_factor, String buffer_type, Optional< Array< IntImm >> axis_separators)
The buffer declaration frame.
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1520
constexpr const char * meta_schedule_unroll_implicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1585
Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:1235
Stmt body
The body block.
Definition: stmt.h:72
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:1394
Stmt body
The body of the block.
Definition: stmt.h:1239
Stmt else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:827
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1472
Optional container that represents a nullable variant of T.
Definition: optional.h:51
bool SEqualReduce(const AttrStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:136
A for loop, with possible type annotations.
Definition: stmt.h:940
Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:1315
Var loop_var
The loop variable.
Definition: stmt.h:943
void Evaluate(PrimExpr value)
Evaluate the input expression.
bool SEqualReduce(const AllocateConstNode *other, SEqualReducer equal) const
Definition: stmt.h:632
std::ostream & operator<<(std::ostream &os, ForKind kind)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:307
Flattener(Array< Stmt > *seq)
Definition: stmt.h:792
PrimExpr value
The value to be stored.
Definition: stmt.h:291
bool SEqualReduce(const BlockNode *other, SEqualReducer equal) const
Definition: stmt.h:1267
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:86
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object)
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1468
PrefetchNode(Buffer buffer, Array< Range > bounds, Span span=Span())
Definition: stmt.h:1085
Let binding, bind var to value, then run body.
Definition: stmt.h:65
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:171
Reference to PrimExprNode.
Definition: expr.h:112
Sequence statement.
Definition: stmt.h:751
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:533
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
constexpr const char * meta_schedule_cooperative_fetch
Mark that the loop should be further skipped and bound to environment threads to enable cooperative fetc...
Definition: stmt.h:1561
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:638
bool SEqualReduce(const BufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1118
constexpr const char * meta_schedule_tiling_structure
Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.
Definition: stmt.h:1555
Managed reference to EvaluateNode.
Definition: stmt.h:893
constexpr const char * fragment_shape
Mark the shape of a TensorCore fragment.
Definition: stmt.h:1510
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:947
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
const char * ForKind2String(ForKind t)
Definition: stmt.h:1625
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1402
Array< Range > bounds
Bounds to be realized.
Definition: stmt.h:346
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1440
Buffer buffer
The target buffer.
Definition: stmt.h:1172
Stmt body
The body to be executed.
Definition: stmt.h:690
Annotate the region where the buffer need to be read and write in the body. We only need to allocate ...
Definition: stmt.h:341
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1255
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:486
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:733
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:190
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:74
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
Managed reference to AttrStmtNode.
Definition: stmt.h:156
Array< Range > bounds
Bounds to be prefetched.
Definition: stmt.h:1067
Map< String, ObjectRef > annotations
The annotation of the block.
Definition: stmt.h:1253
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1408
Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block...
Definition: stmt.h:1247
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:531
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:520
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1386
Managed reference to LetStmtNode.
Definition: stmt.h:100
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:874
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:179
bool SEqualReduce(const SeqStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:740
Var buffer_var
The buffer variable.
Definition: stmt.h:232
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:1364
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:652
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1534