tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/tir/expr.h>
28 
29 #include <string>
30 #include <type_traits>
31 #include <utility>
32 #include <vector>
33 
34 namespace tvm {
35 namespace tir {
36 
38 class StmtNode : public Object {
39  public:
44  mutable Span span;
45 
46  StmtNode() = default;
47  explicit StmtNode(Span span) : span(span) {}
48 
50 
51  static constexpr const char* _type_key = "tir.Stmt";
52  static constexpr const bool _type_has_method_sequal_reduce = true;
53  static constexpr const bool _type_has_method_shash_reduce = true;
54  static constexpr const uint32_t _type_child_slots = 15;
56 };
57 
59 class Stmt : public ObjectRef {
60  public:
62 };
63 
67 class LetStmtNode : public StmtNode {
68  public:
75 
77  v->Visit("var", &var);
78  v->Visit("value", &value);
79  v->Visit("body", &body);
80  v->Visit("span", &span);
81  }
82 
83  bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
84  return equal.DefEqual(var, other->var) && equal(value, other->value) &&
85  equal(body, other->body);
86  }
87 
88  void SHashReduce(SHashReducer hash_reduce) const {
89  hash_reduce.DefHash(var);
90  hash_reduce(value);
91  hash_reduce(body);
92  }
93 
94  static constexpr const char* _type_key = "tir.LetStmt";
96 };
97 
102 class LetStmt : public Stmt {
103  public:
104  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
105 
108 };
109 
120 class AttrStmtNode : public StmtNode {
121  public:
130 
132  v->Visit("node", &node);
133  v->Visit("attr_key", &attr_key);
134  v->Visit("value", &value);
135  v->Visit("body", &body);
136  v->Visit("span", &span);
137  }
138 
139  bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
140  return equal(node, other->node) && equal(attr_key, other->attr_key) &&
141  equal(value, other->value) && equal(body, other->body);
142  }
143 
144  void SHashReduce(SHashReducer hash_reduce) const {
145  hash_reduce(node);
146  hash_reduce(attr_key);
147  hash_reduce(value);
148  hash_reduce(body);
149  }
150 
151  static constexpr const char* _type_key = "tir.AttrStmt";
153 };
154 
159 class AttrStmt : public Stmt {
160  public:
161  TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
162 
165 };
166 
170 class AssertStmtNode : public StmtNode {
171  public:
181 
183  v->Visit("condition", &condition);
184  v->Visit("message", &message);
185  v->Visit("body", &body);
186  v->Visit("span", &span);
187  }
188 
189  bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
190  return equal(condition, other->condition) && equal(message, other->message) &&
191  equal(body, other->body);
192  }
193 
194  void SHashReduce(SHashReducer hash_reduce) const {
195  hash_reduce(condition);
196  hash_reduce(message);
197  hash_reduce(body);
198  }
199 
200  static constexpr const char* _type_key = "tir.AssertStmt";
202 };
203 
208 class AssertStmt : public Stmt {
209  public:
210  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
211 
214 };
215 
226 class BufferStoreNode : public StmtNode {
227  public:
234 
236  v->Visit("buffer", &buffer);
237  v->Visit("value", &value);
238  v->Visit("indices", &indices);
239  v->Visit("span", &span);
240  }
241 
242  bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
243  return equal(buffer, other->buffer) && equal(value, other->value) &&
244  equal(indices, other->indices);
245  }
246 
247  void SHashReduce(SHashReducer hash_reduce) const {
248  hash_reduce(buffer);
249  hash_reduce(value);
250  hash_reduce(indices);
251  }
252 
253  static constexpr const char* _type_key = "tir.BufferStore";
255 };
256 
261 class BufferStore : public Stmt {
262  public:
263  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
264  Span span = Span());
265 
268 };
269 
281 class BufferRealizeNode : public StmtNode {
282  public:
291 
293  v->Visit("buffer", &buffer);
294  v->Visit("bounds", &bounds);
295  v->Visit("condition", &condition);
296  v->Visit("body", &body);
297  v->Visit("span", &span);
298  }
299 
300  bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
301  return equal(buffer, other->buffer) && equal(bounds, other->bounds) &&
302  equal(condition, other->condition) && equal(body, other->body);
303  }
304 
305  void SHashReduce(SHashReducer hash_reduce) const {
306  hash_reduce(buffer);
307  hash_reduce(bounds);
308  hash_reduce(condition);
309  hash_reduce(body);
310  }
311 
312  BufferRealizeNode() = default;
313  BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
314  Span span = Span())
315  : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition), body(body) {}
316 
317  static constexpr const char* _type_key = "tir.BufferRealize";
319 };
320 
325 class BufferRealize : public Stmt {
326  public:
327  TVM_DLL explicit BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
328  Span span = Span());
329 
332 };
333 
344 class ProducerStoreNode : public StmtNode {
345  public:
352 
354  v->Visit("producer", &producer);
355  v->Visit("value", &value);
356  v->Visit("indices", &indices);
357  v->Visit("span", &span);
358  }
359 
360  bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
361  return equal(producer, other->producer) && equal(value, other->value) &&
362  equal(indices, other->indices);
363  }
364 
365  void SHashReduce(SHashReducer hash_reduce) const {
366  hash_reduce(producer);
367  hash_reduce(value);
368  hash_reduce(indices);
369  }
370 
371  static constexpr const char* _type_key = "tir.ProducerStore";
373 };
374 
379 class ProducerStore : public Stmt {
380  public:
381  TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
382  Span span = Span());
383 
386 };
387 
400  public:
411 
413  v->Visit("producer", &producer);
414  v->Visit("bounds", &bounds);
415  v->Visit("condition", &condition);
416  v->Visit("body", &body);
417  v->Visit("storage_scope", &storage_scope);
418  v->Visit("span", &span);
419  }
420 
422  return equal(producer, other->producer) && equal(bounds, other->bounds) &&
423  equal(condition, other->condition) && equal(body, other->body) &&
424  equal(storage_scope, other->storage_scope);
425  }
426 
427  void SHashReduce(SHashReducer hash_reduce) const {
428  hash_reduce(producer);
429  hash_reduce(bounds);
430  hash_reduce(condition);
431  hash_reduce(body);
432  hash_reduce(storage_scope);
433  }
434 
435  static constexpr const char* _type_key = "tir.ProducerRealize";
437 };
438 
443 class ProducerRealize : public Stmt {
444  public:
445  TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
446  String storage_scope = "", Span span = Span());
447 
450 };
451 
455 class AllocateNode : public StmtNode {
456  public:
474 
476  v->Visit("buffer_var", &buffer_var);
477  v->Visit("dtype", &dtype);
478  v->Visit("extents", &extents);
479  v->Visit("condition", &condition);
480  v->Visit("body", &body);
481  v->Visit("annotations", &annotations);
482  v->Visit("span", &span);
483  }
484 
485  bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
486  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
487  equal(extents, other->extents) && equal(condition, other->condition) &&
488  equal(body, other->body) && equal(annotations, other->annotations);
489  }
490 
491  void SHashReduce(SHashReducer hash_reduce) const {
492  hash_reduce.DefHash(buffer_var);
493  hash_reduce(dtype);
494  hash_reduce(extents);
495  hash_reduce(condition);
496  hash_reduce(body);
497  hash_reduce(annotations);
498  }
499 
505  int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
512  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
513 
514  static constexpr const char* _type_key = "tir.Allocate";
515  static constexpr const bool _type_has_method_sequal_reduce = true;
516  static constexpr const bool _type_has_method_shash_reduce = true;
518 };
519 
524 class Allocate : public Stmt {
525  public:
526  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
527  Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
528  Span span = Span());
529 
532 };
533 
537 class AllocateConstNode : public StmtNode {
538  public:
562 
564  v->Visit("buffer_var", &buffer_var);
565  v->Visit("data", &data);
566  v->Visit("irmod_storage_idx", &irmod_storage_idx);
567  v->Visit("dtype", &dtype);
568  v->Visit("extents", &extents);
569  v->Visit("body", &body);
570  v->Visit("annotations", &annotations);
571  v->Visit("span", &span);
572  }
573 
574  bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
575  return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
576  equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) &&
577  equal(annotations, other->annotations);
578  }
579 
580  void SHashReduce(SHashReducer hash_reduce) const {
581  hash_reduce.DefHash(buffer_var);
582  hash_reduce(dtype);
583  hash_reduce(extents);
584  hash_reduce(body);
585  hash_reduce(annotations);
586  hash_reduce(data);
587  }
588 
594  int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
601  TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
602 
603  static constexpr const char* _type_key = "tir.AllocateConst";
604  static constexpr const bool _type_has_method_sequal_reduce = true;
605  static constexpr const bool _type_has_method_shash_reduce = true;
607 };
608 
613 class AllocateConst : public Stmt {
614  public:
615  /* The constructor to create a IRNode with constant data
616  * depending on the type of ObjectRef, it will either
617  * create AllocateConstNode with irmod_storage_idx or data
618  */
619  TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
620  ObjectRef data_or_idx, Stmt body,
622  Span span = Span());
625 };
626 
628 class DeclBufferNode : public StmtNode {
629  public:
634 
636  v->Visit("buffer", &buffer);
637  v->Visit("body", &body);
638  v->Visit("span", &span);
639  }
640 
641  bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const {
642  return equal(buffer, other->buffer) && equal(body, other->body);
643  }
644 
645  void SHashReduce(SHashReducer hash_reduce) const {
646  hash_reduce(buffer);
647  hash_reduce(body);
648  }
649 
650  static constexpr const char* _type_key = "tir.DeclBuffer";
652 };
653 
655 class DeclBuffer : public Stmt {
656  public:
657  TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
660 };
661 
666 class SeqStmtNode : public StmtNode {
667  public:
670 
672  size_t size() const { return seq.size(); }
676  Stmt operator[](size_t index) const { return seq[index]; }
677 
679  v->Visit("seq", &seq);
680  v->Visit("span", &span);
681  }
682 
683  bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
684  return equal(seq, other->seq);
685  }
686 
687  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
688 
689  static constexpr const char* _type_key = "tir.SeqStmt";
691 };
692 
694 class SeqStmt : public Stmt {
695  public:
701  TVM_DLL explicit SeqStmt(Array<Stmt> seq, Span span = Span());
702 
704  size_t size() const { return operator->()->size(); }
708  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
725  template <typename... Args>
726  static Stmt Flatten(Args&&... seq_args) {
727  Array<Stmt> seq;
728  runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
729  if (seq.size() == 1) return seq[0];
730  return SeqStmt(seq);
731  }
733  class Flattener {
734  public:
735  explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}
736 
737  template <typename T>
738  void operator()(size_t i, const T& stmt_or_seq) const {
739  if constexpr (std::is_base_of_v<ObjectRef, T>) {
740  // Early bail-out, applicable to any ObjectRef
741  if (!stmt_or_seq.defined()) return;
742  }
743 
744  if constexpr (std::is_same_v<T, SeqStmt>) {
745  // No need for dynamic type-checking if the static type is a
746  // SeqStmt.
747  (*this)(0, stmt_or_seq->seq);
748  } else if constexpr (std::is_base_of_v<T, SeqStmt>) {
749  // Dynamic type-checking for a SeqStmt that could be
750  // flattened.
751  if (auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
752  operator()(0, op->seq);
753  } else {
754  seq_->push_back(stmt_or_seq);
755  }
756  } else if constexpr (std::is_base_of_v<Stmt, T>) {
757  // Any other Stmt type just gets appended.
758  seq_->push_back(stmt_or_seq);
759  } else {
760  // Anything else is treated as an iterable of Stmt.
761  for (auto v : stmt_or_seq) {
762  this->operator()(0, v);
763  }
764  }
765  }
766 
767  private:
768  Array<Stmt>* seq_;
769  };
770 
773 };
774 
778 class IfThenElseNode : public StmtNode {
779  public:
786 
788  v->Visit("condition", &condition);
789  v->Visit("then_case", &then_case);
790  v->Visit("else_case", &else_case);
791  v->Visit("span", &span);
792  }
793 
794  bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
795  return equal(condition, other->condition) && equal(then_case, other->then_case) &&
796  equal(else_case, other->else_case);
797  }
798 
799  void SHashReduce(SHashReducer hash_reduce) const {
800  hash_reduce(condition);
801  hash_reduce(then_case);
802  hash_reduce(else_case);
803  }
804 
805  static constexpr const char* _type_key = "tir.IfThenElse";
807 };
808 
813 class IfThenElse : public Stmt {
814  public:
815  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case = NullOpt,
816  Span span = Span());
817 
820 };
821 
828 class EvaluateNode : public StmtNode {
829  public:
832 
834  v->Visit("value", &value);
835  v->Visit("span", &span);
836  }
837 
838  bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
839  return equal(value, other->value);
840  }
841 
842  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
843 
844  static constexpr const char* _type_key = "tir.Evaluate";
846 };
847 
852 class Evaluate : public Stmt {
853  public:
854  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
855 
856  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
857 
860 };
861 
/*!
 * \brief The kind of a For loop.
 *
 * The integer values are spelled out explicitly: ForNode exposes its `kind`
 * field through VisitAttrs, so keep them stable.
 */
enum class ForKind : int {
  kSerial = 0,         ///< Printed as "serial" by ForKind2String.
  kParallel = 1,       ///< Printed as "parallel" by ForKind2String.
  kVectorized = 2,     ///< Printed as "vectorized" by ForKind2String.
  kUnrolled = 3,       ///< Printed as "unroll" (not "unrolled") by ForKind2String.
  kThreadBinding = 4,  ///< Loop variable is bound to the thread in ForNode::thread_binding.
};
889 
900 class ForNode : public StmtNode {
901  public:
926 
928  v->Visit("loop_var", &loop_var);
929  v->Visit("min", &min);
930  v->Visit("extent", &extent);
931  v->Visit("kind", &kind);
932  v->Visit("body", &body);
933  v->Visit("thread_binding", &thread_binding);
934  v->Visit("annotations", &annotations);
935  v->Visit("span", &span);
936  }
937 
938  bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
939  return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
940  equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
941  equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations);
942  }
943 
944  void SHashReduce(SHashReducer hash_reduce) const {
945  hash_reduce.DefHash(loop_var);
946  hash_reduce(min);
947  hash_reduce(extent);
948  hash_reduce(kind);
949  hash_reduce(body);
950  hash_reduce(thread_binding);
951  hash_reduce(annotations);
952  }
953 
954  static constexpr const char* _type_key = "tir.For";
956 };
957 
962 class For : public Stmt {
963  public:
964  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
965  Optional<IterVar> thread_binding = NullOpt,
967 
970 };
971 
982 class WhileNode : public StmtNode {
983  public:
988 
990  v->Visit("condition", &condition);
991  v->Visit("body", &body);
992  v->Visit("span", &span);
993  }
994 
995  bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
996  return equal(condition, other->condition) && equal(body, other->body);
997  }
998 
999  void SHashReduce(SHashReducer hash_reduce) const {
1000  hash_reduce(condition);
1001  hash_reduce(body);
1002  }
1003 
1004  static constexpr const char* _type_key = "tir.While";
1006 };
1007 
1012 class While : public Stmt {
1013  public:
1014  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
1015 
1018 };
1019 
1023 class PrefetchNode : public StmtNode {
1024  public:
1029 
1031  v->Visit("buffer", &buffer);
1032  v->Visit("bounds", &bounds);
1033  v->Visit("span", &span);
1034  }
1035 
1036  bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
1037  return equal(buffer, other->buffer) && equal(bounds, other->bounds);
1038  }
1039 
1040  void SHashReduce(SHashReducer hash_reduce) const {
1041  hash_reduce(buffer);
1042  hash_reduce(bounds);
1043  }
1044 
1045  PrefetchNode() = default;
1047  : StmtNode(span), buffer(buffer), bounds(bounds) {}
1048 
1049  static constexpr const char* _type_key = "tir.Prefetch";
1051 };
1052 
1057 class Prefetch : public Stmt {
1058  public:
1059  TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
1060 
1063 };
1064 
1068 class BufferRegionNode : public Object {
1069  public:
1074 
1076  v->Visit("buffer", &buffer);
1077  v->Visit("region", &region);
1078  }
1079 
1080  bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
1081  return equal(buffer, other->buffer) && equal(region, other->region);
1082  }
1083 
1084  void SHashReduce(SHashReducer hash_reduce) const {
1085  hash_reduce(buffer);
1086  hash_reduce(region);
1087  }
1088 
1089  static constexpr const char* _type_key = "tir.BufferRegion";
1090  static constexpr const bool _type_has_method_sequal_reduce = true;
1091  static constexpr const bool _type_has_method_shash_reduce = true;
1093 };
1094 
1099 class BufferRegion : public ObjectRef {
1100  public:
1101  TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
1102 
1108  TVM_DLL static BufferRegion FullRegion(Buffer buffer);
1109 
1116  TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
1117 
1120 };
1121 
1132  public:
1137 
1139  v->Visit("buffer", &buffer);
1140  v->Visit("source", &source);
1141  }
1142 
1144  return equal(buffer, other->buffer) && equal(source, other->source);
1145  }
1146 
1147  void SHashReduce(SHashReducer hash_reduce) const {
1148  hash_reduce(buffer);
1149  hash_reduce(source);
1150  }
1151 
1152  static constexpr const char* _type_key = "tir.MatchBufferRegion";
1153  static constexpr const bool _type_has_method_sequal_reduce = true;
1154  static constexpr const bool _type_has_method_shash_reduce = true;
1156 };
1157 
1163  public:
1164  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
1165 
1168 };
1169 
1191 class BlockNode : public StmtNode {
1192  public:
1217 
1219  v->Visit("iter_vars", &iter_vars);
1220  v->Visit("reads", &reads);
1221  v->Visit("writes", &writes);
1222  v->Visit("name_hint", &name_hint);
1223  v->Visit("body", &body);
1224  v->Visit("init", &init);
1225  v->Visit("alloc_buffers", &alloc_buffers);
1226  v->Visit("match_buffers", &match_buffers);
1227  v->Visit("annotations", &annotations);
1228  }
1229 
1230  bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
1231  // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
1232  return equal.DefEqual(iter_vars, other->iter_vars) &&
1233  equal(alloc_buffers, other->alloc_buffers) &&
1234  equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
1235  equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
1236  equal(annotations, other->annotations);
1237  }
1238 
1239  void SHashReduce(SHashReducer hash_reduce) const {
1240  hash_reduce.DefHash(iter_vars);
1241  hash_reduce(alloc_buffers);
1242  hash_reduce(match_buffers);
1243  hash_reduce(reads);
1244  hash_reduce(writes);
1245  hash_reduce(body);
1246  hash_reduce(init);
1247  hash_reduce(annotations);
1248  }
1249 
1250  static constexpr const char* _type_key = "tir.Block";
1252 };
1253 
1258 class Block : public Stmt {
1259  public:
1260  TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
1261  Array<BufferRegion> writes, String name_hint, Stmt body,
1262  Optional<Stmt> init = NullOpt,
1263  Array<Buffer> alloc_buffers = Array<Buffer>(),
1266  Span span = Span());
1267 
1270 };
1271 
1275 class BlockRealizeNode : public StmtNode {
1276  public:
1286 
1288  v->Visit("iter_values", &iter_values);
1289  v->Visit("predicate", &predicate);
1290  v->Visit("block", &block);
1291  }
1292 
1293  bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
1294  return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
1295  equal(block, other->block);
1296  }
1297 
1298  void SHashReduce(SHashReducer hash_reduce) const {
1299  hash_reduce(iter_values);
1300  hash_reduce(predicate);
1301  hash_reduce(block);
1302  }
1303 
1304  static constexpr const char* _type_key = "tir.BlockRealize";
1306 };
1307 
1312 class BlockRealize : public Stmt {
1313  public:
1314  TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
1315  Span span = Span());
1316 
1319 };
1320 
1322 namespace attr {
1323 // The above attr does not pass to ir stage.
1325 constexpr const char* thread_extent = "thread_extent";
1327 constexpr const char* virtual_thread = "virtual_thread";
/*! \brief Mark that the region is processed by a co-processor. */
1329 constexpr const char* coproc_scope = "coproc_scope";
1334 constexpr const char* coproc_uop_scope = "coproc_uop_scope";
1336 constexpr const char* volatile_scope = "volatile_scope";
1342 constexpr const char* extern_scope = "extern_scope";
1347 constexpr const char* compute_scope = "compute_scope";
1349 constexpr const char* storage_alignment = "storage_alignment";
/*! \brief Mark storage scope of realization. */
1351 constexpr const char* realize_scope = "realize_scope";
1353 constexpr const char* device_id = "device_id";
/*! \brief The device type. */
1355 constexpr const char* device_type = "device_type";
1357 constexpr const char* loop_scope = "loop_scope";
1359 constexpr const char* reduce_scope = "reduce_scope";
1361 constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
1363 constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
/*! \brief Common prefix of pragma keys; the same "pragma_" text IsPragmaKey checks for. */
1365 constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import C source or file into the final code gen module. */
1367 constexpr const char* pragma_import_c = "pragma_import_c";
1369 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
1371 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
1376 constexpr const char* prefetch_scope = "prefetch_scope";
1383 constexpr const char* layout_transforms = "layout_transforms";
1391 constexpr const char* axis_separators = "axis_separators";
1395 constexpr const char* double_buffer_scope = "double_buffer_scope";
1399 constexpr const char* double_buffer_write = "double_buffer_write";
1401 constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
1403 constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope. */
1405 constexpr const char* scan_init_scope = "scan_init_scope";
/*!
 * \brief Mark alignment of a buffer dimension.
 * stmt.node is the Tensor; stmt.value is tvm_tuple(dim, align, offset). This
 * hints that the stride of dim should be k * align + offset.
 */
1412 constexpr const char* buffer_dim_align = "buffer_dim_align";
/*! \brief Mark stores/loads with their bounds. */
1414 constexpr const char* buffer_bound = "buffer_bound";
1424 constexpr const char* buffer_bind_scope = "buffer_bind_scope";
1425 // Pipeline related attributes
1427 constexpr const char* channel_read_scope = "channel_read_scope";
/*! \brief Advance step of channel after end of scope. */
1429 constexpr const char* channel_read_advance = "channel_read_advance";
1431 constexpr const char* channel_write_scope = "channel_write_scope";
1433 constexpr const char* channel_write_advance = "channel_write_advance";
1435 constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
/*! \brief Pipeline execution scope; implies the scope can be pipelined. */
1437 constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
1438 
1442 constexpr const char* device_scope = "device_scope";
1443 
1447 constexpr const char* async_scope = "async_scope";
1448 
1466 constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
1467 constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
1468 constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
1469 
1473 constexpr const char* fragment_shape = "fragment_shape";
1474 
1478 constexpr const char* fragment_layout = "fragment_layout";
1479 
1483 constexpr const char* hand_threaded = "hand_threaded";
1484 
/*! \note Unlike most keys in this namespace, this key string carries a "tir." prefix. */
1492 constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
1493 
1497 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
1498 
1500 constexpr const char* software_pipeline_stage = "software_pipeline_stage";
1501 
1503 constexpr const char* software_pipeline_order = "software_pipeline_order";
1504 
/*! \brief List stages in the software pipeline that should run asynchronously. */
1509 constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";
1510 
1512 constexpr const char* layout_free_buffers = "layout_free_buffers";
1513 
/*! \note This key string also carries a "tir." prefix. */
1515 constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";
1516 
1518 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
1519 
1524 constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";
1525 
1528  "meta_schedule.thread_extent_low_inclusive";
1529 
1532  "meta_schedule.thread_extent_high_inclusive";
1533 
1536  "meta_schedule.random_compute_producer";
1537 
1539 constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
1540 
1542 constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
1543 
1545 constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
1546 
1548 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
1549 
/*! \brief Mark that a block should be further rewritten using tensorization. */
1551 constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1552 
/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
1554 constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
1558 constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
1559 
/*! \note The C++ name differs from the key string: the key is "require_bound_predicate". */
1563 constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";
1564 
1566 constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";
1567 
1574 constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
1575 
/*! \brief Presumably the meta_schedule_cache_type value for a read cache (name-based; confirm). */
1577 constexpr const int meta_schedule_cache_type_read = 0;
1578 
/*! \brief Presumably the meta_schedule_cache_type value for a write cache (name-based; confirm). */
1580 constexpr const int meta_schedule_cache_type_write = 1;
1581 
/*! \brief Mark auto copy for memhammer. */
1583 constexpr const char* auto_copy = "auto_copy";
1584 
1586 constexpr const char* local_stage = "local_stage";
1587 
/*! \brief Mark vectorization length constraint on block. */
1589 constexpr const char* vector_bytes = "vector_bytes";
1590 
1595 constexpr const char* warp_execution = "warp_execution";
1596 
1598 constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
1599 
/*!
 * \brief Check whether an attribute key is a pragma key.
 * \param attr_key The attribute key to test.
 * \return true iff attr_key starts with "pragma_" (the same text as
 *         attr::pragma_scope_prefix); false for shorter or empty keys.
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // Derive the prefix length from the literal itself instead of a hard-coded
  // 7, so the length and the text can never drift apart.
  static constexpr char kPragmaPrefix[] = "pragma_";
  return attr_key.compare(0, sizeof(kPragmaPrefix) - 1, kPragmaPrefix) == 0;
}
1608 
1609 } // namespace attr
// Declaration only; the definition lives outside this header. Presumably
// returns a PrimExpr that carries `dtype` as a type annotation (inferred from
// the name -- confirm against the definition).
1616 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
1617 
1618 // Stream-output overload for ForKind; declared here, defined out of line.
1619 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
1620 
1621 // inline implementations
1622 inline const char* ForKind2String(ForKind t) {
1623  switch (t) {
1624  case ForKind::kSerial:
1625  return "serial";
1626  case ForKind::kParallel:
1627  return "parallel";
1628  case ForKind::kVectorized:
1629  return "vectorized";
1630  case ForKind::kUnrolled:
1631  return "unroll";
1633  return "thread_binding";
1634  }
1635  LOG(FATAL) << "Unknown ForKind" << t;
1636 }
1637 
1638 } // namespace tir
1639 } // namespace tvm
1640 #endif // TVM_TIR_STMT_H_
tvm::Span Span
Definition: base.h:65
bool SEqualReduce(const WhileNode *other, SEqualReducer equal) const
Definition: stmt.h:995
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:1367
Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:785
void operator()(size_t i, const T &stmt_or_seq) const
Definition: stmt.h:738
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:365
Define certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:120
String attr_key
the type key of the attribute
Definition: stmt.h:125
constexpr const char * pipeline_exec_scope
Pipeline execution scope; implies the scope can be pipelined.
Definition: stmt.h:1437
Store value into a multi-dimensional array that will be read by the consumer of the producer.
Definition: stmt.h:344
Buffer buffer
The buffer variable.
Definition: stmt.h:229
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:831
constexpr const char * device_type
The device type.
Definition: stmt.h:1355
Buffer buffer
The buffer to be prefetched.
Definition: stmt.h:1026
A prefetch hint for a buffer.
Definition: stmt.h:1023
Base node of all statements.
Definition: stmt.h:38
Declare a buffer that can be used in the body.
Definition: stmt.h:628
constexpr const char * scan_init_scope
Mark of scan init scope.
Definition: stmt.h:1405
Managed reference to BlockNode.
Definition: stmt.h:1258
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1040
Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:1212
Array< PrimExpr > indices
The indices of the location to be stored.
Definition: stmt.h:233
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:708
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:131
A block realization node represents execution of the block at the binding values. ...
Definition: stmt.h:1275
constexpr const char * channel_read_advance
Advance step of channel after end of scope.
Definition: stmt.h:1429
Optional< Integer > irmod_storage_idx
If the PrimFunc containing the Stmt is added to IRModule, this is an optional index to indicate the i...
Definition: stmt.h:548
constexpr const char * realize_scope
Mark storage scope of realization.
Definition: stmt.h:1351
PrimExpr value
The value to be bound.
Definition: stmt.h:72
Stmt body
The body of realization.
Definition: stmt.h:408
constexpr const char * auto_copy
Mark auto copy for memhammer.
Definition: stmt.h:1583
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to...
Definition: stmt.h:916
bool SEqualReduce(const AssertStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:189
Managed reference to PrefetchNode.
Definition: stmt.h:1057
constexpr const char * buffer_dim_align
Mark alignment of a buffer dimension. stmt.node is a Tensor; stmt.value is tvm_tuple(dim, align, offset). This gives a hint to require the stride of dim to be k * align + offset.
Definition: stmt.h:1412
Managed reference to AllocateConstNode.
Definition: stmt.h:613
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:353
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:733
constexpr const char * meta_schedule_layout_rewrite_preproc
Mark that a block is a preprocessor block for layout rewrite.
Definition: stmt.h:1554
constexpr const char * vector_bytes
Mark vectorization length constraint on block.
Definition: stmt.h:1589
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
static constexpr const bool _type_has_method_sequal_reduce
Definition: stmt.h:52
constexpr const char * coproc_scope
Mark that the region is processed by a co-processor.
Definition: stmt.h:1329
Managed reference to ProducerRealizeNode.
Definition: stmt.h:443
Managed reference to IfThenElseNode.
Definition: stmt.h:813
constexpr const char * meta_schedule_auto_tensorize
Mark that a block should be further rewritten using tensorization.
Definition: stmt.h:1551
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:288
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:1414
constexpr const char * software_pipeline_async_stages
List stages in the software pipeline that should run asynchronously.
Definition: stmt.h:1509
StmtNode(Span span)
Definition: stmt.h:47
bool SEqualReduce(const BufferStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:242
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:666
constexpr const char * buffer_bind_scope
Bind the buffer specification to the region of the op. When this scope occurs, the stmt...
Definition: stmt.h:1424
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
constexpr const char * meta_schedule_unroll_explicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1545
a named variable in TIR
Definition: var.h:88
Evaluate(int value, Span span=Span())
Definition: stmt.h:856
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:182
Var var
The variable.
Definition: stmt.h:70
IfThenElse statement.
Definition: stmt.h:778
Buffer buffer
The buffer being declared.
Definition: stmt.h:631
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:552
bool SEqualReduce(const IfThenElseNode *other, SEqualReducer equal) const
Definition: stmt.h:794
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:989
DataProducer producer
The producer that produces the data.
Definition: stmt.h:402
constexpr const char * async_commit_queue_scope
Annotations for invoking and synchronizing asynchronous operations.
Definition: stmt.h:1466
constexpr const char * compute_scope
Mark the scope as when computation starts to happen. This can hint some code generator to create a new ...
Definition: stmt.h:1347
constexpr const char * channel_read_scope
channel read scope
Definition: stmt.h:1427
constexpr const char * warp_execution
Mark that a block is executed by a warp. This implies the extent of threadIdx.x is warp size...
Definition: stmt.h:1595
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1084
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:1131
ForKind kind
The kind of the for loop.
Definition: stmt.h:909
constexpr const char * meta_schedule_thread_extent_high_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1531
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:1369
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:1605
constexpr const char * meta_schedule_tensor_core_enabled
Mark that tensor core is enabled in the PrimExpr.
Definition: stmt.h:1566
void Prefetch(Buffer buffer, Array< Range > bounds)
The prefetch hint for a buffer.
bool SEqualReduce(const AllocateNode *other, SEqualReducer equal) const
Definition: stmt.h:485
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:927
bool SEqualReduce(const BufferRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:300
Managed reference to ForNode.
Definition: stmt.h:962
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:412
Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:1194
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:123
Block block
The block to be realized.
Definition: stmt.h:1285
DataType dtype
The type of the buffer.
Definition: stmt.h:550
Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:1196
static constexpr const char * _type_key
Definition: stmt.h:51
Managed reference to BufferRegionNode.
Definition: stmt.h:1099
base class of all object containers.
Definition: object.h:167
AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array< PrimExpr > extents, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The allocate constant node.
constexpr const char * script_parsing_detect_access
Mark whether the script-completer needs to fill in missing access regions during script parsing...
Definition: stmt.h:1492
constexpr const char * meta_schedule_auto_tensorize_init
Mark that the init statement of a block should be further rewritten using tensorization.
Definition: stmt.h:1558
Stmt body
The body statement to be executed.
Definition: stmt.h:129
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:669
constexpr const char * extern_scope
Mark the scope as generated by an extern primitive. Such a scope can contain an arbitrary IR program and we n...
Definition: stmt.h:1342
PrimExpr value
The value to be stored.
Definition: stmt.h:349
PrimExpr condition
Only realize if condition holds.
Definition: stmt.h:406
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:1071
Var buffer_var
The buffer variable.
Definition: stmt.h:458
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:1361
Managed reference to AssertStmtNode.
Definition: stmt.h:208
DataProducer producer
The producer to store the results into.
Definition: stmt.h:347
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1075
constexpr const char * meta_schedule_vectorize
Mark auto-vectorize setting on the block.
Definition: stmt.h:1542
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:1353
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:635
bool SEqualReduce(const LetStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:83
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:1068
Map< String, ObjectRef > annotations
Additional annotations about the loop.
Definition: stmt.h:925
constexpr const char * rolling_buffer_scope
Mark realization for rolling buffer optimization.
Definition: stmt.h:1401
Array< PrimExpr > indices
The index arguments of the function.
Definition: stmt.h:351
Buffer buffer
The buffer variable.
Definition: stmt.h:284
bool SEqualReduce(const PrefetchNode *other, SEqualReducer equal) const
Definition: stmt.h:1036
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
size_t size() const
Definition: stmt.h:704
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:54
ObjectRef node
this is attribute about certain node
Definition: stmt.h:123
String storage_scope
The storage scope associated with this realization.
Definition: stmt.h:410
DataType dtype
The type of the buffer.
Definition: stmt.h:460
Annotate the bounds where the data produced by the producer need to be written and read in body...
Definition: stmt.h:399
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
constexpr const char * layout_free_buffers
Mark the buffers which are accessed as const and whose layout can be transformed.
Definition: stmt.h:1512
constexpr const char * async_wait_queue_scope
Definition: stmt.h:1467
default semantics – serial execution.
PrimExpr condition
Only allocate buffer when condition is satisfied.
Definition: stmt.h:464
AllocateFrame Allocate(Array< PrimExpr > extents, DataType dtype, String storage_scope="", Optional< PrimExpr > condition=NullOpt, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The allocate node.
Definition: source_map.h:120
constexpr const char * channel_write_advance
Advance step of channel after end of scope.
Definition: stmt.h:1433
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:1442
Parallel execution on CPU.
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:726
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1298
BufferRegion source
The source buffer region.
Definition: stmt.h:1136
size_t size() const
Definition: array.h:420
Array< Range > region
The region array of the buffer region.
Definition: stmt.h:1073
BlockFrame Block(String name, bool no_realize=false)
The block declaration statement.
Stmt body
The body to be executed.
Definition: stmt.h:554
Runtime primitive data type.
Definition: data_type.h:41
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:235
String name_hint
The name_hint of the block.
Definition: stmt.h:1200
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:905
constexpr const char * coproc_uop_scope
Mark that the region creates coprocessor micro ops, which can be reused if the corresponding variable is independent...
Definition: stmt.h:1334
bool SEqualReduce(const ForNode *other, SEqualReducer equal) const
Definition: stmt.h:938
bool SEqualReduce(const ProducerStoreNode *other, SEqualReducer equal) const
Definition: stmt.h:360
constexpr const char * require_block_var_bound_predicate
Mark that the block needs to add a predicate for block var bounds during lowering.
Definition: stmt.h:1563
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true...
Definition: stmt.h:1283
PrimExpr condition
The condition.
Definition: stmt.h:781
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:563
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
constexpr const char * fragment_layout
Mark the layout of the TensorCore fragment.
Definition: stmt.h:1478
constexpr const char * reduce_scope
Mark of reduce scope.
Definition: stmt.h:1359
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:505
A While loop.
Definition: stmt.h:982
Stmt body
The body of the for loop.
Definition: stmt.h:911
constexpr const char * manifest_shared_memory_local_stage
Mark the local stage for the shared memory access should be added.
Definition: stmt.h:1515
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:687
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:144
constexpr const char * meta_schedule_thread_extent_low_inclusive
The allowed range of thread extent in thread bindings.
Definition: stmt.h:1527
constexpr const char * prefetch_scope
Mark of prefetch scope, value=offset, run prefetch of Tensor on the current loop scope.
Definition: stmt.h:1376
Stmt body
The body of the while loop.
Definition: stmt.h:987
Stmt body
The body of realization.
Definition: stmt.h:290
constexpr const char * layout_transforms
Marks the layout transforms to be used for a tensor.
Definition: stmt.h:1383
constexpr const char * volatile_scope
Mark the scope as volatile access for certain handle.
Definition: stmt.h:1336
Container of all statements.
Definition: stmt.h:59
Stmt body
Body for which this assertion holds true. It is executed after the assertion.
Definition: stmt.h:180
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1138
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:678
void BufferStore(Buffer buffer, PrimExpr value, Array< PrimExpr > indices)
Store data in a buffer.
WhileFrame While(PrimExpr condition)
Create a while loop.
constexpr const char * meta_schedule_random_compute_producer
Mark the block whose producer needs to be applied by rule Random-Compute-Location.
Definition: stmt.h:1535
static constexpr const bool _type_has_method_shash_reduce
Definition: stmt.h:53
The loop variable is bound to a thread in an environment. In the final stage of lowering, the loop is simply removed and the loop variable is mapped to the corresponding context thread.
Reference to string objects.
Definition: string.h:98
constexpr const int meta_schedule_cache_type_read
Definition: stmt.h:1577
constexpr const char * async_scope
Mark that the attached statement runs asynchronously.
Definition: stmt.h:1447
constexpr const char * double_buffer_scope
Marks production of double buffer data.
Definition: stmt.h:1395
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1287
Managed reference to BufferRealizeNode.
Definition: stmt.h:325
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Allocate a buffer that can be used in body.
Definition: stmt.h:455
The loop is vectorized.
Definition: var.h:237
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:787
size_t size() const
Definition: stmt.h:672
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:828
The execution is unrolled.
Definition: var.h:233
PrimExpr condition
The termination condition.
Definition: stmt.h:985
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:1363
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1147
Managed reference to AllocateNode.
Definition: stmt.h:524
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:999
A block is a basic schedule unit in TIR.
Definition: stmt.h:1191
constexpr const char * meta_schedule_cache_type
Mark a block as generated by cache_read or cache_write block. 0 means cache_read; 1 means cache_write...
Definition: stmt.h:1574
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:1325
Optional< runtime::NDArray > data
The optional data associated to the constant.
Definition: stmt.h:543
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:944
Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:1214
Managed reference to BlockRealizeNode.
Definition: stmt.h:1312
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:1162
Managed reference to DeclBufferNode.
Definition: stmt.h:655
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
ForKind
The kind of the loop.
Definition: stmt.h:869
Allocate a buffer that can be used in body.
Definition: stmt.h:537
bool SEqualReduce(const EvaluateNode *other, SEqualReducer equal) const
Definition: stmt.h:838
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:561
LetFrame LetStmt(PrimExpr value, Optional< Type > type_annotation=NullOpt, Optional< Var > var=NullOpt)
The let binding.
BufferRealizeNode(Buffer buffer, Array< Range > bounds, PrimExpr condition, Stmt body, Span span=Span())
Definition: stmt.h:313
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
constexpr const char * software_pipeline_order
Mark the order of a statement in the software pipeline.
Definition: stmt.h:1503
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:842
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:645
bool SEqualReduce(const DeclBufferNode *other, SEqualReducer equal) const
Definition: stmt.h:641
constexpr const int meta_schedule_cache_type_write
Definition: stmt.h:1580
Managed reference to DataProducerNode.
Definition: buffer.h:293
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1391
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:305
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:799
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types...
Definition: buffer.h:160
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
bool SEqualReduce(const MatchBufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1143
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1030
PrimExpr value
The attribute value; the value is well defined in the current scope.
Definition: stmt.h:127
constexpr const char * async_wait_inflight_count
Definition: stmt.h:1468
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:1239
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:491
constexpr const char * software_pipeline_stage
Mark the stage of a statement in the software pipeline.
Definition: stmt.h:1500
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:44
Store value to the high dimension buffer.
Definition: stmt.h:226
Managed reference to BufferStoreNode.
Definition: stmt.h:261
Region bounds
Bounds to be realized.
Definition: stmt.h:404
Var buffer_var
The buffer variable.
Definition: stmt.h:540
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:783
Managed reference to WhileNode.
Definition: stmt.h:1012
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:170
bool SEqualReduce(const ProducerRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:421
Managed reference to ProducerStoreNode.
Definition: stmt.h:379
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:292
PrimExpr condition
Condition to be checked.
Definition: stmt.h:173
constexpr const char * meta_schedule_parallel
Mark auto-parallel setting on the block.
Definition: stmt.h:1539
bool SEqualReduce(const BlockRealizeNode *other, SEqualReducer equal) const
Definition: stmt.h:1293
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
constexpr const char * double_buffer_write
Marks region used by double buffer write.
Definition: stmt.h:1399
Stmt body
The body to be executed.
Definition: stmt.h:466
DeclBufferFrame DeclBuffer(Array< PrimExpr > shape, DataType dtype, String buffer_name, Optional< Var > data, Optional< Array< PrimExpr >> strides, Optional< PrimExpr > elem_offset, String storage_scope, int align, int offset_factor, String buffer_type, Optional< Array< IntImm >> axis_separators)
The buffer declaration frame.
constexpr const char * hand_threaded
Mark that the kernel is hand threaded and doesn't need syncs inserted.
Definition: stmt.h:1483
constexpr const char * meta_schedule_unroll_implicit
Mark auto-unroll setting on the block.
Definition: stmt.h:1548
Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:1198
Stmt body
The body block.
Definition: stmt.h:74
constexpr const char * loop_scope
Mark of loop scope.
Definition: stmt.h:1357
Stmt body
The body of the block.
Definition: stmt.h:1202
constexpr const char * pipeline_stage_scope
pipeline stage scope, implies always execution
Definition: stmt.h:1435
Optional container that represents a nullable variant of T.
Definition: optional.h:51
bool SEqualReduce(const AttrStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:139
A for loop, with possible type annotations.
Definition: stmt.h:900
Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:1278
Var loop_var
The loop variable.
Definition: stmt.h:903
void Evaluate(PrimExpr value)
Evaluate the input expression.
bool SEqualReduce(const AllocateConstNode *other, SEqualReducer equal) const
Definition: stmt.h:574
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:247
Flattener(Array< Stmt > *seq)
Definition: stmt.h:735
PrimExpr value
The value to be stored.
Definition: stmt.h:231
bool SEqualReduce(const BlockNode *other, SEqualReducer equal) const
Definition: stmt.h:1230
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:88
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object)
constexpr const char * channel_write_scope
channel write scope
Definition: stmt.h:1431
PrefetchNode(Buffer buffer, Array< Range > bounds, Span span=Span())
Definition: stmt.h:1046
Let binding, bind var to value, then run body.
Definition: stmt.h:67
PrimExpr message
Error message when assertion failed.
Definition: stmt.h:175
Reference to PrimExprNode.
Definition: expr.h:114
Sequence statement.
Definition: stmt.h:694
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:475
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
constexpr const char * meta_schedule_cooperative_fetch
Mark that the loop should be further skipped and bound to environment threads to enable cooperative fetc...
Definition: stmt.h:1524
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:580
bool SEqualReduce(const BufferRegionNode *other, SEqualReducer equal) const
Definition: stmt.h:1080
constexpr const char * meta_schedule_tiling_structure
Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.
Definition: stmt.h:1518
Managed reference to EvaluateNode.
Definition: stmt.h:852
constexpr const char * fragment_shape
Mark the shape of the TensorCore fragment.
Definition: stmt.h:1473
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:907
const char * ForKind2String(ForKind t)
Definition: stmt.h:1622
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:1365
constexpr const char * local_stage
Mark local stage constraint on data copy.
Definition: stmt.h:1586
constexpr const char * meta_schedule_inline_rule
Mark that a block is disallowed in auto inline.
Definition: stmt.h:1598
Array< Range > bounds
Bounds to be realized.
Definition: stmt.h:286
constexpr const char * scan_update_scope
Mark of scan update scope.
Definition: stmt.h:1403
Buffer buffer
The target buffer.
Definition: stmt.h:1134
Stmt body
The body to be executed.
Definition: stmt.h:633
Annotate the region where the buffer need to be read and write in the body. We only need to allocate ...
Definition: stmt.h:281
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:1218
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:427
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:676
void SHashReduce(SHashReducer hash_reduce) const
Definition: stmt.h:194
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:76
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
Managed reference to AttrStmtNode.
Definition: stmt.h:159
Array< Range > bounds
Bounds to be prefetched.
Definition: stmt.h:1028
Map< String, ObjectRef > annotations
The annotation of the block.
Definition: stmt.h:1216
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:1371
Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block...
Definition: stmt.h:1210
Map< String, ObjectRef > annotations
Additional annotations about the allocation.
Definition: stmt.h:473
Array< PrimExpr > extents
The extents of the buffer.
Definition: stmt.h:462
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:1349
Managed reference to LetStmtNode.
Definition: stmt.h:102
void VisitAttrs(AttrVisitor *v)
Definition: stmt.h:833
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:187
bool SEqualReduce(const SeqStmtNode *other, SEqualReducer equal) const
Definition: stmt.h:683
constexpr const char * virtual_thread
Mark launching of a virtual thread.
Definition: stmt.h:1327
int64_t ConstantAllocationSize() const
If the buffer size is constant, return the size. Otherwise return 0.
Definition: stmt.h:594
constexpr const char * pragma_loop_partition_hint
Mark that the loop should be partitioned.
Definition: stmt.h:1497