tvm
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/tir/expr.h>
28 
29 #include <string>
30 #include <type_traits>
31 #include <utility>
32 #include <vector>
33 
34 namespace tvm {
35 namespace tir {
36 
38 class StmtNode : public Object {
39  public:
44  mutable Span span;
45 
46  StmtNode() = default;
47  explicit StmtNode(Span span) : span(span) {}
48 
49  static constexpr const char* _type_key = "tir.Stmt";
50  static constexpr const bool _type_has_method_sequal_reduce = true;
51  static constexpr const bool _type_has_method_shash_reduce = true;
52  static constexpr const uint32_t _type_child_slots = 15;
54 };
55 
57 class Stmt : public ObjectRef {
58  public:
60 };
61 
65 class LetStmtNode : public StmtNode {
66  public:
73 
75  v->Visit("var", &var);
76  v->Visit("value", &value);
77  v->Visit("body", &body);
78  v->Visit("span", &span);
79  }
80 
81  bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
82  return equal.DefEqual(var, other->var) && equal(value, other->value) &&
83  equal(body, other->body);
84  }
85 
86  void SHashReduce(SHashReducer hash_reduce) const {
87  hash_reduce.DefHash(var);
88  hash_reduce(value);
89  hash_reduce(body);
90  }
91 
92  static constexpr const char* _type_key = "tir.LetStmt";
94 };
95 
100 class LetStmt : public Stmt {
101  public:
102  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
103 
105 };
106 
117 class AttrStmtNode : public StmtNode {
118  public:
127 
129  v->Visit("node", &node);
130  v->Visit("attr_key", &attr_key);
131  v->Visit("value", &value);
132  v->Visit("body", &body);
133  v->Visit("span", &span);
134  }
135 
136  bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
137  return equal(node, other->node) && equal(attr_key, other->attr_key) &&
138  equal(value, other->value) && equal(body, other->body);
139  }
140 
141  void SHashReduce(SHashReducer hash_reduce) const {
142  hash_reduce(node);
143  hash_reduce(attr_key);
144  hash_reduce(value);
145  hash_reduce(body);
146  }
147 
148  static constexpr const char* _type_key = "tir.AttrStmt";
150 };
151 
156 class AttrStmt : public Stmt {
157  public:
158  TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
159 
161 };
162 
166 class AssertStmtNode : public StmtNode {
167  public:
177 
179  v->Visit("condition", &condition);
180  v->Visit("message", &message);
181  v->Visit("body", &body);
182  v->Visit("span", &span);
183  }
184 
185  bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
186  return equal(condition, other->condition) && equal(message, other->message) &&
187  equal(body, other->body);
188  }
189 
190  void SHashReduce(SHashReducer hash_reduce) const {
191  hash_reduce(condition);
192  hash_reduce(message);
193  hash_reduce(body);
194  }
195 
196  static constexpr const char* _type_key = "tir.AssertStmt";
198 };
199 
204 class AssertStmt : public Stmt {
205  public:
206  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
207 
209 };
210 
229 class StoreNode : public StmtNode {
230  public:
239 
241  v->Visit("buffer_var", &buffer_var);
242  v->Visit("value", &value);
243  v->Visit("index", &index);
244  v->Visit("predicate", &predicate);
245  v->Visit("span", &span);
246  }
247 
248  bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
249  return equal(buffer_var, other->buffer_var) && equal(value, other->value) &&
250  equal(index, other->index) && equal(predicate, other->predicate);
251  }
252 
253  void SHashReduce(SHashReducer hash_reduce) const {
254  hash_reduce(buffer_var);
255  hash_reduce(value);
256  hash_reduce(index);
257  hash_reduce(predicate);
258  }
259 
260  static constexpr const char* _type_key = "tir.Store";
262 };
263 
268 class Store : public Stmt {
269  public:
270  TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
271  Span span = Span());
272 
274 };
275 
286 class BufferStoreNode : public StmtNode {
287  public:
294 
296  v->Visit("buffer", &buffer);
297  v->Visit("value", &value);
298  v->Visit("indices", &indices);
299  v->Visit("span", &span);
300  }
301 
302  bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
303  return equal(buffer, other->buffer) && equal(value, other->value) &&
304  equal(indices, other->indices);
305  }
306 
307  void SHashReduce(SHashReducer hash_reduce) const {
308  hash_reduce(buffer);
309  hash_reduce(value);
310  hash_reduce(indices);
311  }
312 
313  static constexpr const char* _type_key = "tir.BufferStore";
315 };
316 
321 class BufferStore : public Stmt {
322  public:
323  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
324  Span span = Span());
325 
328 };
329 
341 class BufferRealizeNode : public StmtNode {
342  public:
351 
353  v->Visit("buffer", &buffer);
354  v->Visit("bounds", &bounds);
355  v->Visit("condition", &condition);
356  v->Visit("body", &body);
357  v->Visit("span", &span);
358  }
359 
360  bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
361  return equal(buffer, other->buffer) && equal(bounds, other->bounds) &&
362  equal(condition, other->condition) && equal(body, other->body);
363  }
364 
365  void SHashReduce(SHashReducer hash_reduce) const {
366  hash_reduce(buffer);
367  hash_reduce(bounds);
368  hash_reduce(condition);
369  hash_reduce(body);
370  }
371 
372  BufferRealizeNode() = default;
373  BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
374  Span span = Span())
375  : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition), body(body) {}
376 
377  static constexpr const char* _type_key = "tir.BufferRealize";
379 };
380 
385 class BufferRealize : public Stmt {
386  public:
387  TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
388  Span span = Span());
389 
391 };
392 
403 class ProducerStoreNode : public StmtNode {
404  public:
411 
413  v->Visit("producer", &producer);
414  v->Visit("value", &value);
415  v->Visit("indices", &indices);
416  v->Visit("span", &span);
417  }
418 
419  bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
420  return equal(producer, other->producer) && equal(value, other->value) &&
421  equal(indices, other->indices);
422  }
423 
424  void SHashReduce(SHashReducer hash_reduce) const {
425  hash_reduce(producer);
426  hash_reduce(value);
427  hash_reduce(indices);
428  }
429 
430  static constexpr const char* _type_key = "tir.ProducerStore";
432 };
433 
438 class ProducerStore : public Stmt {
439  public:
440  TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
441  Span span = Span());
442 
444 };
445 
458  public:
469 
471  v->Visit("producer", &producer);
472  v->Visit("bounds", &bounds);
473  v->Visit("condition", &condition);
474  v->Visit("body", &body);
475  v->Visit("storage_scope", &storage_scope);
476  v->Visit("span", &span);
477  }
478 
480  return equal(producer, other->producer) && equal(bounds, other->bounds) &&
481  equal(condition, other->condition) && equal(body, other->body) &&
482  equal(storage_scope, other->storage_scope);
483  }
484 
485  void SHashReduce(SHashReducer hash_reduce) const {
486  hash_reduce(producer);
487  hash_reduce(bounds);
488  hash_reduce(condition);
489  hash_reduce(body);
490  hash_reduce(storage_scope);
491  }
492 
493  static constexpr const char* _type_key = "tir.ProducerRealize";
495 };
496 
501 class ProducerRealize : public Stmt {
502  public:
503  TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
504  String storage_scope = "", Span span = Span());
505 
507 };
508 
512 class AllocateNode : public StmtNode {
513  public:
531 
533  v->Visit("buffer_var", &buffer_var);
534  v->Visit("dtype", &dtype);
535  v->Visit("extents", &extents);
536  v->Visit("condition", &condition);
537  v->Visit("body", &body);
538  v->Visit("annotations", &annotations);
539  v->Visit("span", &span);
540  }
541 
542  bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
543  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
544  equal(extents, other->extents) && equal(condition, other->condition) &&
545  equal(body, other->body) && equal(annotations, other->annotations);
546  }
547 
548  void SHashReduce(SHashReducer hash_reduce) const {
549  hash_reduce.DefHash(buffer_var);
550  hash_reduce(dtype);
551  hash_reduce(extents);
552  hash_reduce(condition);
553  hash_reduce(body);
554  hash_reduce(annotations);
555  }
556 
562  int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
569  TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
570 
571  static constexpr const char* _type_key = "tir.Allocate";
573 };
574 
579 class Allocate : public Stmt {
580  public:
581  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
582  Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
583  Span span = Span());
584 
586 };
587 
592 class SeqStmtNode : public StmtNode {
593  public:
596 
598  size_t size() const { return seq.size(); }
602  Stmt operator[](size_t index) const { return seq[index]; }
603 
605  v->Visit("seq", &seq);
606  v->Visit("span", &span);
607  }
608 
609  bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
610  return equal(seq, other->seq);
611  }
612 
613  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
614 
615  static constexpr const char* _type_key = "tir.SeqStmt";
617 };
618 
620 class SeqStmt : public Stmt {
621  public:
627  TVM_DLL explicit SeqStmt(Array<Stmt> seq, Span span = Span());
628 
630  size_t size() const { return operator->()->size(); }
634  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
651  template <typename... Args>
652  static Stmt Flatten(Args&&... seq_args) {
653  Array<Stmt> seq;
654  runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
655  if (seq.size() == 1) return seq[0];
656  return SeqStmt(seq);
657  }
659  class Flattener {
660  public:
661  explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}
662 
663  void operator()(size_t i, const Stmt& stmt) const {
664  if (!stmt.defined()) return;
665  if (auto* op = stmt.as<SeqStmtNode>()) {
666  operator()(0, op->seq);
667  } else {
668  seq_->push_back(stmt);
669  }
670  }
671 
672  template <typename T>
673  void operator()(size_t i, const T& seq) const {
674  for (auto v : seq) {
675  this->operator()(0, v);
676  }
677  }
678 
679  private:
680  Array<Stmt>* seq_;
681  };
682 
684 };
685 
689 class IfThenElseNode : public StmtNode {
690  public:
697 
699  v->Visit("condition", &condition);
700  v->Visit("then_case", &then_case);
701  v->Visit("else_case", &else_case);
702  v->Visit("span", &span);
703  }
704 
705  bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
706  return equal(condition, other->condition) && equal(then_case, other->then_case) &&
707  equal(else_case, other->else_case);
708  }
709 
710  void SHashReduce(SHashReducer hash_reduce) const {
711  hash_reduce(condition);
712  hash_reduce(then_case);
713  hash_reduce(else_case);
714  }
715 
716  static constexpr const char* _type_key = "tir.IfThenElse";
718 };
719 
724 class IfThenElse : public Stmt {
725  public:
726  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt(),
727  Span span = Span());
728 
730 };
731 
738 class EvaluateNode : public StmtNode {
739  public:
742 
744  v->Visit("value", &value);
745  v->Visit("span", &span);
746  }
747 
748  bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
749  return equal(value, other->value);
750  }
751 
752  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
753 
754  static constexpr const char* _type_key = "tir.Evaluate";
756 };
757 
762 class Evaluate : public Stmt {
763  public:
764  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
765 
766  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
767 
769 };
770 
778 enum class ForKind : int {
780  kSerial = 0,
782  kParallel = 1,
787  kVectorized = 2,
789  kUnrolled = 3,
796  kThreadBinding = 4
797 };
798 
809 class ForNode : public StmtNode {
810  public:
835 
837  v->Visit("loop_var", &loop_var);
838  v->Visit("min", &min);
839  v->Visit("extent", &extent);
840  v->Visit("kind", &kind);
841  v->Visit("body", &body);
842  v->Visit("thread_binding", &thread_binding);
843  v->Visit("annotations", &annotations);
844  v->Visit("span", &span);
845  }
846 
847  bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
848  return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
849  equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
850  equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations);
851  }
852 
853  void SHashReduce(SHashReducer hash_reduce) const {
854  hash_reduce.DefHash(loop_var);
855  hash_reduce(min);
856  hash_reduce(extent);
857  hash_reduce(kind);
858  hash_reduce(body);
859  hash_reduce(thread_binding);
860  hash_reduce(annotations);
861  }
862 
863  static constexpr const char* _type_key = "tir.For";
865 };
866 
871 class For : public Stmt {
872  public:
873  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
874  Optional<IterVar> thread_binding = NullOpt,
876 
879 };
880 
891 class WhileNode : public StmtNode {
892  public:
897 
899  v->Visit("condition", &condition);
900  v->Visit("body", &body);
901  v->Visit("span", &span);
902  }
903 
904  bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
905  return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body);
906  }
907 
908  void SHashReduce(SHashReducer hash_reduce) const {
909  hash_reduce.DefHash(condition);
910  hash_reduce.DefHash(body);
911  }
912 
913  static constexpr const char* _type_key = "tir.While";
915 };
916 
921 class While : public Stmt {
922  public:
923  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
924 
926 };
927 
931 class PrefetchNode : public StmtNode {
932  public:
937 
939  v->Visit("buffer", &buffer);
940  v->Visit("bounds", &bounds);
941  v->Visit("span", &span);
942  }
943 
944  bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
945  return equal(buffer, other->buffer) && equal(bounds, other->bounds);
946  }
947 
948  void SHashReduce(SHashReducer hash_reduce) const {
949  hash_reduce(buffer);
950  hash_reduce(bounds);
951  }
952 
953  PrefetchNode() = default;
955  : StmtNode(span), buffer(buffer), bounds(bounds) {}
956 
957  static constexpr const char* _type_key = "tir.Prefetch";
959 };
960 
965 class Prefetch : public Stmt {
966  public:
967  TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
968 
970 };
971 
975 class BufferRegionNode : public Object {
976  public:
981 
983  v->Visit("buffer", &buffer);
984  v->Visit("region", &region);
985  }
986 
987  bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
988  return equal(buffer, other->buffer) && equal(region, other->region);
989  }
990 
991  void SHashReduce(SHashReducer hash_reduce) const {
992  hash_reduce(buffer);
993  hash_reduce(region);
994  }
995 
996  static constexpr const char* _type_key = "tir.BufferRegion";
997  static constexpr const bool _type_has_method_sequal_reduce = true;
998  static constexpr const bool _type_has_method_shash_reduce = true;
1000 };
1001 
1006 class BufferRegion : public ObjectRef {
1007  public:
1008  TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
1009 
1015  TVM_DLL static BufferRegion FullRegion(Buffer buffer);
1016 
1023  TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
1024 
1027 };
1028 
1039  public:
1044 
1046  v->Visit("buffer", &buffer);
1047  v->Visit("source", &source);
1048  }
1049 
1051  return equal(buffer, other->buffer) && equal(source, other->source);
1052  }
1053 
1054  void SHashReduce(SHashReducer hash_reduce) const {
1055  hash_reduce(buffer);
1056  hash_reduce(source);
1057  }
1058 
1059  static constexpr const char* _type_key = "tir.MatchBufferRegion";
1060  static constexpr const bool _type_has_method_sequal_reduce = true;
1061  static constexpr const bool _type_has_method_shash_reduce = true;
1063 };
1064 
1070  public:
1071  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
1072 
1074 };
1075 
1097 class BlockNode : public StmtNode {
1098  public:
1123 
1125  v->Visit("iter_vars", &iter_vars);
1126  v->Visit("reads", &reads);
1127  v->Visit("writes", &writes);
1128  v->Visit("name_hint", &name_hint);
1129  v->Visit("body", &body);
1130  v->Visit("init", &init);
1131  v->Visit("alloc_buffers", &alloc_buffers);
1132  v->Visit("match_buffers", &match_buffers);
1133  v->Visit("annotations", &annotations);
1134  }
1135 
1136  bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
1137  // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
1138  return equal.DefEqual(iter_vars, other->iter_vars) &&
1139  equal(alloc_buffers, other->alloc_buffers) &&
1140  equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
1141  equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
1142  equal(annotations, other->annotations);
1143  }
1144 
1145  void SHashReduce(SHashReducer hash_reduce) const {
1146  hash_reduce.DefHash(iter_vars);
1147  hash_reduce(alloc_buffers);
1148  hash_reduce(match_buffers);
1149  hash_reduce(reads);
1150  hash_reduce(writes);
1151  hash_reduce(body);
1152  hash_reduce(init);
1153  hash_reduce(annotations);
1154  }
1155 
1156  static constexpr const char* _type_key = "tir.Block";
1158 };
1159 
1164 class Block : public Stmt {
1165  public:
1166  TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
1167  Array<BufferRegion> writes, String name_hint, Stmt body,
1168  Optional<Stmt> init = NullOpt,
1169  Array<Buffer> alloc_buffers = Array<Buffer>(),
1172  Span span = Span());
1173 
1176 };
1177 
1181 class BlockRealizeNode : public StmtNode {
1182  public:
1192 
1194  v->Visit("iter_values", &iter_values);
1195  v->Visit("predicate", &predicate);
1196  v->Visit("block", &block);
1197  }
1198 
1199  bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
1200  return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
1201  equal(block, other->block);
1202  }
1203 
1204  void SHashReduce(SHashReducer hash_reduce) const {
1205  hash_reduce(iter_values);
1206  hash_reduce(predicate);
1207  hash_reduce(block);
1208  }
1209 
1210  static constexpr const char* _type_key = "tir.BlockRealize";
1212 };
1213 
1218 class BlockRealize : public Stmt {
1219  public:
1220  TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
1221  Span span = Span());
1222 
1225 };
1226 
1228 namespace attr {
1229 // The above attr does not pass to ir stage.
1231 constexpr const char* thread_extent = "thread_extent";
1233 constexpr const char* virtual_thread = "virtual_thread";
1235 constexpr const char* coproc_scope = "coproc_scope";
1240 constexpr const char* coproc_uop_scope = "coproc_uop_scope";
1242 constexpr const char* volatile_scope = "volatile_scope";
1248 constexpr const char* extern_scope = "extern_scope";
1253 constexpr const char* compute_scope = "compute_scope";
1255 constexpr const char* storage_alignment = "storage_alignment";
1257 constexpr const char* realize_scope = "realize_scope";
1259 constexpr const char* device_id = "device_id";
1261 constexpr const char* device_type = "device_type";
1263 constexpr const char* loop_scope = "loop_scope";
1265 constexpr const char* reduce_scope = "reduce_scope";
1267 constexpr const char* pragma_scope_prefix = "pragma_";
1269 constexpr const char* pragma_import_c = "pragma_import_c";
1271 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
1273 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
1278 constexpr const char* prefetch_scope = "prefetch_scope";
1282 constexpr const char* double_buffer_scope = "double_buffer_scope";
1286 constexpr const char* double_buffer_write = "double_buffer_write";
1288 constexpr const char* scan_update_scope = "scan_update_scope";
1290 constexpr const char* scan_init_scope = "scan_init_scope";
1297 constexpr const char* buffer_dim_align = "buffer_dim_align";
1299 constexpr const char* buffer_bound = "buffer_bound";
1309 constexpr const char* buffer_bind_scope = "buffer_bind_scope";
1310 // Pipeline related attributes
1312 constexpr const char* channel_read_scope = "channel_read_scope";
1314 constexpr const char* channel_read_advance = "channel_read_advance";
1316 constexpr const char* channel_write_scope = "channel_write_scope";
1318 constexpr const char* channel_write_advance = "channel_write_advance";
1320 constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
1322 constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
1323 
1327 constexpr const char* device_scope = "device_scope";
1328 
1332 constexpr const char* fragment_shape = "fragment_shape";
1333 
1337 constexpr const char* fragment_layout = "fragment_layout";
1338 
1342 constexpr const char* hand_threaded = "hand_threaded";
1343 
1351 constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
1352 
1356 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
1357 
/*!
 * \brief Tell whether an attribute key names a pragma extension.
 * \param attr_key The attribute key to inspect.
 * \return true iff attr_key begins with the "pragma_" prefix.
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // rfind with a start position of 0 can only report a match at index 0,
  // which makes this an allocation-free "starts with" check.
  return attr_key.rfind("pragma_", 0) == 0;
}
1366 
1367 } // namespace attr
1374 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1375 
1376 // overload printing of for type.
1377 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1378 
1379 // inline implementations
1380 inline const char* ForKind2String(ForKind t) {
1381  switch (t) {
1382  case ForKind::kSerial:
1383  return "serial";
1384  case ForKind::kParallel:
1385  return "parallel";
1386  case ForKind::kVectorized:
1387  return "vectorized";
1388  case ForKind::kUnrolled:
1389  return "unroll";
1391  return "thread_binding";
1392  }
1393  LOG(FATAL) << "Unknown ForKind" << t;
1394  return "Unknown";
1395 }
1396 
1397 } // namespace tir
1398 } // namespace tvm
1399 #endif // TVM_TIR_STMT_H_
tvm::Span Span
Definition: base.h:65
bool SEqualReduce(const WhileNode *other, SEqualReducer equal) const
Definition: stmt.h:904
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1269
Managed reference to StoreNode.
Definition: stmt.h:268
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:424
Define certain auxiliary attribute for the body to be a symbolic value. This provide auxiliary inform...
Definition: stmt.h:117
PrimExpr index
The index locations to be stored.
Definition: stmt.h:236
String attr_key
the type key of the attribute
Definition: stmt.h:122
constexpr const char * pipeline_exec_scope
pipeline execution scope, implies the scope can be pipelined.
Definition: stmt.h:1322
Store value into multi-dimensional array that will be read by the consumer of the producer.
Definition: stmt.h:403
Buffer buffer
The buffer variable.
Definition: stmt.h:289
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
Definition: structural_equal.h:165
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:741
constexpr const char * device_type
The device type.
Definition: stmt.h:1261
Buffer buffer
The buffer to be prefetched.
Definition: stmt.h:934
A prefetch hint for a buffer.
Definition: stmt.h:931
Base node of all statements.
Definition: stmt.h:38
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1290
Managed reference to BlockNode.
Definition: stmt.h:1164
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:948
Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:1118
Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:293
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:634
PrimExpr value
The value to be stored.
Definition: stmt.h:234
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:128
A block realization node represents execution of the block at the binding values. ...
Definition: stmt.h:1181
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1314
constexpr const char * realize_scope
Mark storage scope of realization.
Definition: stmt.h:1257
PrimExpr value
The value to be bound.
Definition: stmt.h:70
Stmt body
The body of realization.
Definition: stmt.h:466
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to...
Definition: stmt.h:825
bool SEqualReduce(const AssertStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:185
Managed reference to PrefetchNode.
Definition: stmt.h:965
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:253
constexpr const char * buffer_dim_align
Mark alignment of buffer dimension. stmt.node is Tensor, stmt.value is tvm_tuple(dim, align, offset). This gives a hint to require the stride of dim to be k * align + offset.
Definition: stmt.h:1297
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:412
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:659
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:101
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:50
constexpr const char * coproc_scope
Mark that the region is processed by a co-processor.
Definition: stmt.h:1235
Managed reference to ProducerRealizeNode.
Definition: stmt.h:501
Managed reference to IfThenElseNode.
Definition: stmt.h:724
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:348
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:1299
StmtNode(Span span)
Definition: stmt.h:47
bool SEqualReduce(const BufferStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:302
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:592
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op. When this scope occurs, the stmt...
Definition: stmt.h:1309
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
bool SEqualReduce(const StoreNode *other, SEqualReducer equal) const
Definition: stmt.h:248
a named variable in TIR
Definition: var.h:88
Evaluate(int value, Span span=Span())
Definition: stmt.h:766
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:178
Var var
The variable.
Definition: stmt.h:68
IfThenElse statement.
Definition: stmt.h:689
bool SEqualReduce(const IfThenElseNode *other, SEqualReducer equal) const
Definition: stmt.h:705
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:898
DataProducer producer
The producer that produces the data.
Definition: stmt.h:460
constexpr const char * compute_scope
Mark the scope as when computation starts to happen. This can hint some code generator to create a new ...
Definition: stmt.h:1253
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1312
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:991
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:1038
ForKind kind
The kind of the for loop.
Definition: stmt.h:818
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1271
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1363
bool SEqualReduce(const AllocateNode *other, SEqualReducer equal) const
Definition: stmt.h:542
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:836
bool SEqualReduce(const BufferRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:360
Managed reference to ForNode.
Definition: stmt.h:871
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:470
Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:1100
Block block
The block to be realized.
Definition: stmt.h:1191
Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:1102
static constexpr const char * _type_key
Definition: stmt.h:49
Managed reference to BufferRegionNode.
Definition: stmt.h:1006
base class of all object containers.
Definition: object.h:165
constexpr const char * script_parsing_detect_access
Mark whether the script-completer need to fill in missing access region during script parsing...
Definition: stmt.h:1351
int32_t constant_allocation_size() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:562
Stmt body
The body statement to be executed.
Definition: stmt.h:126
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:595
constexpr const char * extern_scope
Mark the scope as generated by an extern primitive. Such a scope can contain an arbitrary IR program and we n...
Definition: stmt.h:1248
PrimExpr value
The value to be stored.
Definition: stmt.h:408
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:464
void operator()(size_t i, const Stmt &stmt) const
Definition: stmt.h:663
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:978
Var buffer_var
The buffer variable.
Definition: stmt.h:515
Managed reference to AssertStmtNode.
Definition: stmt.h:204
DataProducer producer
The producer to store the results into.
Definition: stmt.h:406
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:982
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1259
bool SEqualReduce(const LetStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:81
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:975
Map< String, ObjectRef > annotations
Additional annotations about the loop.
Definition: stmt.h:834
Array< PrimExpr > indices
The index arguments of the function.
Definition: stmt.h:410
Buffer buffer
The buffer variable.
Definition: stmt.h:344
bool SEqualReduce(const PrefetchNode *other, SEqualReducer equal) const
Definition: stmt.h:944
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
size_t size() const
Definition: stmt.h:630
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:52
ObjectRef node
this is attribute about certain node
Definition: stmt.h:120
String storage_scope
The storage scope associated with this realization.
Definition: stmt.h:468
DataType dtype
The type of the buffer.
Definition: stmt.h:517
Annotate the bounds where the data produced by the producer need to be written and read in body...
Definition: stmt.h:457
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
default semantics – serial execution.
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:521
Definition: span.h:115
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1318
PrimExpr predicate
The predicate to mask which lanes would be stored.
Definition: stmt.h:238
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1327
Parallel execution on CPU.
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:652
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1204
BufferRegion source
The source buffer region.
Definition: stmt.h:1043
size_t size() const
Definition: array.h:399
Array< Range > region
The region array of the buffer region.
Definition: stmt.h:980
bool defined() const
Definition: object.h:537
Runtime primitive data type.
Definition: data_type.h:41
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:295
String name_hint
The name_hint of the block.
Definition: stmt.h:1106
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:814
constexpr const char * coproc_uop_scope
Mark region creates coprocessor micro ops, can be reused if corresponding variable is independent...
Definition: stmt.h:1240
bool SEqualReduce(const ForNode *other, SEqualReducer equal) const
Definition: stmt.h:847
bool SEqualReduce(const ProducerStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:419
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true...
Definition: stmt.h:1189
PrimExpr condition
The condition.
Definition: stmt.h:692
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
constexpr const char * fragment_layout
Mark the layout of a TensorCore fragment.
Definition: stmt.h:1337
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:1265
A While loop.
Definition: stmt.h:891
Stmt body
The body of the for loop.
Definition: stmt.h:820
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:613
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:141
constexpr const char * prefetch_scope
Mark of prefetch scope, value=offset, run prefetch of Tensor on the current loop scope.
Definition: stmt.h:1278
Stmt body
The body of the while loop.
Definition: stmt.h:896
Stmt body
The body of realization.
Definition: stmt.h:350
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:1242
Container of all statements.
Definition: stmt.h:57
Stmt body
Body which this assertion holds true. Will be executed after the assertion.
Definition: stmt.h:176
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1045
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:604
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:51
The loop variable is bound to a thread in an environment. In the final stage of lowering, the loop is simply removed and the loop variable is mapped to the corresponding context thread.
Reference to string objects.
Definition: string.h:129
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1282
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1193
Managed reference to BufferRealizeNode.
Definition: stmt.h:385
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:706
Allocate a buffer that can be used in body.
Definition: stmt.h:512
The loop is vectorized.
Definition: var.h:230
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:698
size_t size() const
Definition: stmt.h:598
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:738
The execution is unrolled.
Definition: var.h:226
PrimExpr condition
The termination condition.
Definition: stmt.h:894
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1054
Managed reference to AllocateNode.
Definition: stmt.h:579
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:908
A block is a basic schedule unit in TIR.
Definition: stmt.h:1097
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:240
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1231
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:853
Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:1120
Managed reference to BlockRealizeNode.
Definition: stmt.h:1218
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:1069
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
ForKind
The kind of the loop.
Definition: stmt.h:778
bool SEqualReduce(const EvaluateNode *other, SEqualReducer equal) const
Definition: stmt.h:748
BufferRealizeNode(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
Definition: stmt.h:373
Base class of all object reference.
Definition: object.h:504
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:778
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:752
Store value to the buffer.
Definition: stmt.h:229
Managed reference to DataProducerNode.
Definition: buffer.h:260
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:365
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:710
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types...
Definition: buffer.h:143
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:664
bool SEqualReduce(const MatchBufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1050
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:938
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:124
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1145
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:548
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:44
Store value to the high dimension buffer.
Definition: stmt.h:286
Managed reference to BufferStoreNode.
Definition: stmt.h:321
Region bounds
Bounds to be realized.
Definition: stmt.h:462
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:694
Managed reference to WhileNode.
Definition: stmt.h:921
void operator()(size_t i, const T &seq) const
Definition: stmt.h:673
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:166
bool SEqualReduce(const ProducerRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:479
Managed reference to ProducerStoreNode.
Definition: stmt.h:438
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:352
PrimExpr condition
Condition to be checked.
Definition: stmt.h:169
bool SEqualReduce(const BlockRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:1199
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1235
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1286
Stmt body
The body to be executed.
Definition: stmt.h:523
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1342
Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:1104
Stmt body
The body block.
Definition: stmt.h:72
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:1263
Stmt body
The body of the block.
Definition: stmt.h:1108
Stmt else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:696
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1320
Optional container that represents a nullable variant of T.
Definition: optional.h:51
bool SEqualReduce(const AttrStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:136
A for loop, with possible type annotations.
Definition: stmt.h:809
Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:1184
Var loop_var
The loop variable.
Definition: stmt.h:812
std::ostream & operator<<(std::ostream &os, ForKind kind)
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:307
Flattener(Array< Stmt > *seq)
Definition: stmt.h:661
PrimExpr value
The value to be stored.
Definition: stmt.h:291
bool SEqualReduce(const BlockNode *other, SEqualReducer equal) const
Definition: stmt.h:1136
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:86
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object)
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1316
PrefetchNode(Buffer buffer, Array< Range > bounds, Span span=Span())
Definition: stmt.h:954
Let binding, bind var to value, then run body.
Definition: stmt.h:65
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:171
Reference to PrimExprNode.
Definition: expr.h:109
Sequence statement.
Definition: stmt.h:620
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:532
constexpr runtime::NullOptType NullOpt
Definition: optional.h:155
bool SEqualReduce(const BufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:987
Managed reference to EvaluateNode.
Definition: stmt.h:762
constexpr const char * fragment_shape
Mark the shape of a TensorCore fragment.
Definition: stmt.h:1332
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:816
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:858
const char * ForKind2String(ForKind t)
Definition: stmt.h:1380
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1267
Array< Range > bounds
Bounds to be realized.
Definition: stmt.h:346
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1288
Buffer buffer
The target buffer.
Definition: stmt.h:1041
Annotate the region where the buffer needs to be read and written in the body. We only need to allocate ...
Definition: stmt.h:341
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1124
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:485
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:602
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:190
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:74
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:721
Managed reference to AttrStmtNode.
Definition: stmt.h:156
Array< Range > bounds
Bounds to be prefetched.
Definition: stmt.h:936
Map< String, ObjectRef > annotations
The annotation of the block.
Definition: stmt.h:1122
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1273
Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block...
Definition: stmt.h:1116
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:530
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:519
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1255
Managed reference to LetStmtNode.
Definition: stmt.h:100
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:743
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:178
bool SEqualReduce(const SeqStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:609
Var buffer_var
The buffer variable.
Definition: stmt.h:232
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:1233
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1356