tvm
stmt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: Many low-level stmts originate from Halide.
24 #ifndef TVM_TIR_STMT_H_
25 #define TVM_TIR_STMT_H_
26 
27 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/tirx/expr.h>
30 
31 #include <optional>
32 #include <string>
33 #include <type_traits>
34 #include <utility>
35 
36 namespace tvm {
37 namespace tirx {
38 
40 class StmtNode : public Object {
41  public:
46  mutable Span span;
47 
48  StmtNode() = default;
49  explicit StmtNode(Span span) : span(span) {}
50 
51  static void RegisterReflection() {
52  namespace refl = tvm::ffi::reflection;
53  refl::ObjectDef<StmtNode>().def_ro("span", &StmtNode::span);
54  }
55 
57 
58  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
59 
60  static constexpr const uint32_t _type_child_slots = 15;
61  TVM_FFI_DECLARE_OBJECT_INFO("tirx.Stmt", StmtNode, Object);
62 };
63 
65 class Stmt : public ObjectRef {
66  public:
68 };
69 
77 class BindNode : public StmtNode {
78  public:
83 
84  static void RegisterReflection() {
85  namespace refl = tvm::ffi::reflection;
86  refl::ObjectDef<BindNode>()
87  .def_ro("var", &BindNode::var, refl::AttachFieldFlag::SEqHashDef())
88  .def_ro("value", &BindNode::value);
89  }
91 };
92 
97 class Bind : public Stmt {
98  public:
99  TVM_DLL Bind(Var var, PrimExpr value, Span span = Span());
100 
103 };
104 
115 class AttrStmtNode : public StmtNode {
116  public:
118  ffi::Any node;
120  ffi::String attr_key;
125 
126  static void RegisterReflection() {
127  namespace refl = tvm::ffi::reflection;
128  refl::ObjectDef<AttrStmtNode>()
129  .def_ro("node", &AttrStmtNode::node)
130  .def_ro("attr_key", &AttrStmtNode::attr_key)
131  .def_ro("value", &AttrStmtNode::value)
132  .def_ro("body", &AttrStmtNode::body);
133  }
135 };
136 
141 class AttrStmt : public Stmt {
142  public:
143  TVM_DLL AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body,
144  Span span = Span());
145 
148 };
149 
159 class AssertStmtNode : public StmtNode {
160  public:
166  ffi::Array<StringImm> message_parts;
167 
168  static void RegisterReflection() {
169  namespace refl = tvm::ffi::reflection;
170  refl::ObjectDef<AssertStmtNode>()
171  .def_ro("condition", &AssertStmtNode::condition)
172  .def_ro("error_kind", &AssertStmtNode::error_kind)
173  .def_ro("message_parts", &AssertStmtNode::message_parts);
174  }
176 };
177 
182 class AssertStmt : public Stmt {
183  public:
184  TVM_DLL AssertStmt(PrimExpr condition, StringImm error_kind, ffi::Array<StringImm> message_parts,
185  Span span = Span());
186 
189 };
190 
201 class BufferStoreNode : public StmtNode {
202  public:
208  ffi::Array<PrimExpr> indices;
210  ffi::Optional<PrimExpr> predicate;
211 
212  static void RegisterReflection() {
213  namespace refl = tvm::ffi::reflection;
214  refl::ObjectDef<BufferStoreNode>()
215  .def_ro("buffer", &BufferStoreNode::buffer)
216  .def_ro("value", &BufferStoreNode::value)
217  .def_ro("indices", &BufferStoreNode::indices)
218  .def_ro("predicate", &BufferStoreNode::predicate);
219  }
221 };
222 
227 class BufferStore : public Stmt {
228  public:
229  TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, ffi::Array<PrimExpr> indices,
230  ffi::Optional<PrimExpr> predicate = std::nullopt,
231  Span span = Span());
232 
235 };
236 
238 class DeclBufferNode : public StmtNode {
239  public:
242 
243  static void RegisterReflection() {
244  namespace refl = tvm::ffi::reflection;
245  refl::ObjectDef<DeclBufferNode>().def_ro("buffer", &DeclBufferNode::buffer);
246  }
248 };
249 
251 class DeclBuffer : public Stmt {
252  public:
253  TVM_DLL DeclBuffer(Buffer buffer, Span span = Span());
256 };
257 
259 class AllocBufferNode : public StmtNode {
260  public:
269  ffi::Map<ffi::String, ffi::Any> annotations;
270 
271  static void RegisterReflection() {
272  namespace refl = tvm::ffi::reflection;
273  refl::ObjectDef<AllocBufferNode>()
274  .def_ro("buffer", &AllocBufferNode::buffer, refl::AttachFieldFlag::SEqHashDef())
275  .def_ro("annotations", &AllocBufferNode::annotations);
276  }
278 };
279 
281 class AllocBuffer : public Stmt {
282  public:
283  TVM_DLL AllocBuffer(
284  Buffer buffer,
285  ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
286  Span span = Span());
291  std::optional<int64_t> ConstantAllocationSize() const {
292  int64_t result = 1;
293  for (const PrimExpr& extent : (*this)->buffer->shape) {
294  if (const auto* int_size = extent.as<IntImmNode>()) {
295  result *= int_size->value;
296  } else {
297  return std::nullopt;
298  }
299  }
300  return result;
301  }
302 
305 };
306 
311 class SeqStmtNode : public StmtNode {
312  public:
314  ffi::Array<Stmt> seq;
315 
317  size_t size() const { return seq.size(); }
321  Stmt operator[](size_t index) const { return seq[index]; }
322 
323  static void RegisterReflection() {
324  namespace refl = tvm::ffi::reflection;
325  refl::ObjectDef<SeqStmtNode>().def_ro("seq", &SeqStmtNode::seq);
326  }
328 };
329 
336 class EvaluateNode : public StmtNode {
337  public:
340 
341  static void RegisterReflection() {
342  namespace refl = tvm::ffi::reflection;
343  refl::ObjectDef<EvaluateNode>().def_ro("value", &EvaluateNode::value);
344  }
346 };
347 
352 class Evaluate : public Stmt {
353  public:
354  TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());
355 
356  explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
357 
360 };
361 
363 class SeqStmt : public Stmt {
364  public:
370  TVM_DLL explicit SeqStmt(ffi::Array<Stmt> seq, Span span = Span());
371 
373  size_t size() const { return operator->()->size(); }
377  Stmt operator[](size_t index) const { return (*(operator->()))[index]; }
398  template <typename... Args>
399  static Stmt Flatten(Args&&... seq_args) {
400  ffi::Array<Stmt> seq;
401 
402  ffi::details::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
403 
404  if (seq.empty()) {
405  return Evaluate(0);
406  } else if (seq.size() == 1) {
407  return seq[0];
408  }
409 
410  // If the argument is a single SeqStmt argument with no
411  // flattening or unwrapping required, then we may
412  // return the SeqStmt as-is.
413  if constexpr (sizeof...(seq_args) == 1) {
414  if (auto opt = Flattener::AsSeqStmt(std::forward<Args>(seq_args)...)) {
415  SeqStmt original = opt.value();
416  bool all_same = [&]() {
417  if (original->seq.size() != seq.size()) {
418  return false;
419  }
420  for (size_t i = 0; i < seq.size(); i++) {
421  if (!original->seq[i].same_as(seq[i])) {
422  return false;
423  }
424  }
425  return true;
426  }();
427  if (all_same) {
428  return original;
429  }
430  }
431  }
432 
433  return SeqStmt(seq);
434  }
436  class Flattener {
437  public:
438  explicit Flattener(ffi::Array<Stmt>* seq) : seq_(seq) {}
439 
440  template <typename T>
441  static ffi::Optional<SeqStmt> AsSeqStmt(const T& t) {
442  if constexpr (std::is_same_v<T, SeqStmt>) {
443  return t;
444  }
445  if constexpr (!std::is_base_of_v<T, SeqStmt>) {
446  return std::nullopt;
447  }
448  if constexpr (std::is_base_of_v<Stmt, T>) {
449  if (const SeqStmtNode* ptr = t.template as<SeqStmtNode>()) {
450  return ffi::GetRef<SeqStmt>(ptr);
451  } else {
452  return std::nullopt;
453  }
454  }
455  return std::nullopt;
456  }
457 
458  template <typename T>
459  void operator()(size_t i, const T& stmt_or_seq) const {
460  if constexpr (std::is_base_of_v<ObjectRef, T>) {
461  // Early bail-out, applicable to any ObjectRef
462  if (!stmt_or_seq.defined()) {
463  return;
464  }
465  }
466 
467  if constexpr (std::is_same_v<T, SeqStmt>) {
468  // Static type-checking for a SeqStmt that could be flattened.
469  (*this)(0, stmt_or_seq->seq);
470  return;
471  }
472 
473  if constexpr (std::is_base_of_v<T, SeqStmt>) {
474  // Dynamic type-checking for a SeqStmt that could be
475  // flattened.
476  if (auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
477  operator()(0, op->seq);
478  return;
479  }
480  }
481 
482  if constexpr (std::is_base_of_v<T, Evaluate>) {
483  // Evaluate(0) is used to represent a no-op, and may be
484  // generated by previous calls to SeqStmt::Flatten(). These
485  // should be removed to ensure that Flatten(a+b) is equivalent
486  // to Flatten(Flatten(a), Flatten(b)).
487  if (auto* op = stmt_or_seq.template as<EvaluateNode>()) {
488  if (auto* as_int = op->value.template as<IntImmNode>(); as_int && as_int->value == 0) {
489  return;
490  }
491  }
492  }
493 
494  if constexpr (std::is_base_of_v<Stmt, T>) {
495  // Any other Stmt type just gets appended.
496  seq_->push_back(stmt_or_seq);
497  } else {
498  // Anything else is treated as an iterable of Stmt.
499  for (auto v : stmt_or_seq) {
500  this->operator()(0, v);
501  }
502  }
503  }
504 
505  private:
506  ffi::Array<Stmt>* seq_;
507  };
508 
511 };
512 
516 class IfThenElseNode : public StmtNode {
517  public:
523  ffi::Optional<Stmt> else_case;
524 
525  static void RegisterReflection() {
526  namespace refl = tvm::ffi::reflection;
527  refl::ObjectDef<IfThenElseNode>()
528  .def_ro("condition", &IfThenElseNode::condition)
529  .def_ro("then_case", &IfThenElseNode::then_case)
530  .def_ro("else_case", &IfThenElseNode::else_case);
531  }
533 };
534 
539 class IfThenElse : public Stmt {
540  public:
541  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case,
542  ffi::Optional<Stmt> else_case = std::nullopt, Span span = Span());
543 
546 };
547 
555 enum class ForKind : int {
557  kSerial = 0,
559  kParallel = 1,
564  kVectorized = 2,
566  kUnrolled = 3,
573  kThreadBinding = 4
574 };
575 
586 class ForNode : public StmtNode {
587  public:
602  ffi::Optional<IterVar> thread_binding;
611  ffi::Map<ffi::String, ffi::Any> annotations;
615  ffi::Optional<PrimExpr> step;
616 
617  static void RegisterReflection() {
618  namespace refl = tvm::ffi::reflection;
619  refl::ObjectDef<ForNode>()
620  .def_ro("loop_var", &ForNode::loop_var, refl::AttachFieldFlag::SEqHashDef())
621  .def_ro("min", &ForNode::min)
622  .def_ro("extent", &ForNode::extent)
623  .def_ro("kind", &ForNode::kind)
624  .def_ro("body", &ForNode::body)
625  .def_ro("thread_binding", &ForNode::thread_binding)
626  .def_ro("annotations", &ForNode::annotations)
627  .def_ro("step", &ForNode::step);
628  }
629 
631  bool HasTrivialStep() const;
632 
634 };
635 
640 class For : public Stmt {
641  public:
642  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
643  ffi::Optional<IterVar> thread_binding = std::nullopt,
644  ffi::Map<ffi::String, ffi::Any> annotations = {},
645  ffi::Optional<PrimExpr> step = std::nullopt, Span span = Span());
646 
649 };
650 
661 class WhileNode : public StmtNode {
662  public:
667 
668  static void RegisterReflection() {
669  namespace refl = tvm::ffi::reflection;
670  refl::ObjectDef<WhileNode>()
671  .def_ro("condition", &WhileNode::condition)
672  .def_ro("body", &WhileNode::body);
673  }
675 };
676 
681 class While : public Stmt {
682  public:
683  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
684 
687 };
688 
693  public:
697  ffi::Array<Range> region;
698 
699  static void RegisterReflection() {
700  namespace refl = tvm::ffi::reflection;
701  refl::ObjectDef<BufferRegionNode>()
702  .def_ro("buffer", &BufferRegionNode::buffer)
703  .def_ro("region", &BufferRegionNode::region);
704  }
705 
706  TVM_DLL PrimExpr ToPrimExpr() const final;
707 
708  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
710 };
711 
717  public:
718  TVM_DLL explicit BufferRegion(Buffer buffer, ffi::Array<Range> region);
719 
726 
733  TVM_DLL static BufferRegion FromPoint(Buffer buffer, ffi::Array<PrimExpr> indices);
734 
737 };
738 
748 class MatchBufferRegionNode : public Object {
749  public:
754 
755  static void RegisterReflection() {
756  namespace refl = tvm::ffi::reflection;
757  refl::ObjectDef<MatchBufferRegionNode>()
758  .def_ro("buffer", &MatchBufferRegionNode::buffer)
759  .def_ro("source", &MatchBufferRegionNode::source);
760  }
761 
762  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
763  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.MatchBufferRegion", MatchBufferRegionNode, Object);
764 };
765 
770 class MatchBufferRegion : public ObjectRef {
771  public:
772  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
773 
776 };
777 
799 class SBlockNode : public StmtNode {
800  public:
802  ffi::Array<IterVar> iter_vars;
804  ffi::Array<BufferRegion> reads;
806  ffi::Array<BufferRegion> writes;
808  ffi::String name_hint;
810  ffi::Array<Buffer> alloc_buffers;
812  ffi::Array<MatchBufferRegion> match_buffers;
814  ffi::Map<ffi::String, ffi::Any> annotations;
822  ffi::Optional<Stmt> init;
825 
826  static void RegisterReflection() {
827  namespace refl = tvm::ffi::reflection;
828  refl::ObjectDef<SBlockNode>()
829  .def_ro("iter_vars", &SBlockNode::iter_vars, refl::AttachFieldFlag::SEqHashDef())
830  .def_ro("reads", &SBlockNode::reads)
831  .def_ro("writes", &SBlockNode::writes)
832  .def_ro("name_hint", &SBlockNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore())
833  .def_ro("alloc_buffers", &SBlockNode::alloc_buffers)
834  .def_ro("match_buffers", &SBlockNode::match_buffers)
835  .def_ro("annotations", &SBlockNode::annotations)
836  .def_ro("init", &SBlockNode::init)
837  .def_ro("body", &SBlockNode::body);
838  }
840 };
841 
846 class SBlock : public Stmt {
847  public:
848  TVM_DLL explicit SBlock(
849  ffi::Array<IterVar> iter_vars, ffi::Array<BufferRegion> reads,
850  ffi::Array<BufferRegion> writes, ffi::String name_hint, Stmt body,
851  ffi::Optional<Stmt> init = std::nullopt,
852  ffi::Array<Buffer> alloc_buffers = ffi::Array<Buffer>(),
853  ffi::Array<MatchBufferRegion> match_buffers = ffi::Array<MatchBufferRegion>(),
854  ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
855  Span span = Span());
856 
859 };
860 
864 class SBlockRealizeNode : public StmtNode {
865  public:
867  ffi::Array<PrimExpr> iter_values;
875 
876  static void RegisterReflection() {
877  namespace refl = tvm::ffi::reflection;
878  refl::ObjectDef<SBlockRealizeNode>()
879  .def_ro("iter_values", &SBlockRealizeNode::iter_values)
880  .def_ro("predicate", &SBlockRealizeNode::predicate)
881  .def_ro("block", &SBlockRealizeNode::block);
882  }
884 };
885 
890 class SBlockRealize : public Stmt {
891  public:
892  TVM_DLL explicit SBlockRealize(ffi::Array<PrimExpr> iter_values, PrimExpr predicate, SBlock block,
893  Span span = Span());
894 
897 };
898 
900 namespace attr {
902 constexpr const char* buffer_bound = "buffer_bound";
907 constexpr const char* compute_scope = "compute_scope";
909 constexpr const char* device_id = "device_id";
911 constexpr const char* device_scope = "device_scope";
913 constexpr const char* device_type = "device_type";
919 constexpr const char* extern_scope = "extern_scope";
921 constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
923 constexpr const char* pragma_import_c = "pragma_import_c";
925 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
927 constexpr const char* pragma_scope_prefix = "pragma_";
929 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
931 constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
933 constexpr const char* storage_alignment = "storage_alignment";
935 constexpr const char* thread_extent = "thread_extent";
937 constexpr const char* kVolatile = "tirx.volatile";
938 
/*!
 * \brief Check whether an attribute key uses the pragma extension prefix.
 * \param attr_key The attribute key to inspect.
 * \return true iff attr_key starts with "pragma_".
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // rfind anchored at position 0 can only match at the very start,
  // so this is a prefix test.
  return attr_key.rfind("pragma_", 0) == 0;
}
947 
948 } // namespace attr
955 TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
956 
957 // Overload stream printing of ForKind.
958 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
959 
960 // inline implementations
961 inline const char* ForKind2String(ForKind t) {
962  switch (t) {
963  case ForKind::kSerial:
964  return "serial";
965  case ForKind::kParallel:
966  return "parallel";
968  return "vectorized";
969  case ForKind::kUnrolled:
970  return "unroll";
972  return "thread_binding";
973  }
974  TVM_FFI_THROW(InternalError) << "Unknown ForKind" << t;
975  TVM_FFI_UNREACHABLE();
976 }
977 
978 } // namespace tirx
979 } // namespace tvm
980 #endif // TVM_TIR_STMT_H_
Constant integer literals in the program.
Definition: expr.h:494
Base class for other IR constructs that can be converted to PrimExpr. This is useful for the FFI to c...
Definition: expr.h:156
Managed reference to PrimExprConvertibleNode.
Definition: expr.h:167
Reference to PrimExprNode.
Definition: expr.h:126
Definition: source_map.h:111
Runtime primitive data type.
Definition: data_type.h:47
Allocate a buffer and declare it in scope.
Definition: stmt.h:259
ffi::Map< ffi::String, ffi::Any > annotations
Additional annotations about the allocation.
Definition: stmt.h:269
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AllocBuffer", AllocBufferNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:271
Buffer buffer
The buffer being allocated and declared.
Definition: stmt.h:262
Managed reference to AllocBufferNode.
Definition: stmt.h:281
AllocBuffer(Buffer buffer, ffi::Map< ffi::String, ffi::Any > annotations=ffi::Map< ffi::String, ffi::Any >(), Span span=Span())
std::optional< int64_t > ConstantAllocationSize() const
If the buffer's shape is constant, return the total number of elements.
Definition: stmt.h:291
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocBufferNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AllocBuffer, Stmt, AllocBufferNode)
Assert that a condition holds; if it does not, an error of the given error kind is raised with the concatenated message parts.
Definition: stmt.h:159
PrimExpr condition
Condition to be checked.
Definition: stmt.h:162
StringImm error_kind
The error kind, e.g. "RuntimeError", "TypeError", "ValueError".
Definition: stmt.h:164
static void RegisterReflection()
Definition: stmt.h:168
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AssertStmt", AssertStmtNode, StmtNode)
ffi::Array< StringImm > message_parts
Error message fragments, concatenated at runtime when assertion fails.
Definition: stmt.h:166
Managed reference to AssertStmtNode.
Definition: stmt.h:182
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AssertStmt, Stmt, AssertStmtNode)
AssertStmt(PrimExpr condition, StringImm error_kind, ffi::Array< StringImm > message_parts, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode)
Define an auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:115
static void RegisterReflection()
Definition: stmt.h:126
PrimExpr value
The attribute value, value is well defined at current scope.
Definition: stmt.h:122
ffi::Any node
The node that this attribute is about.
Definition: stmt.h:118
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AttrStmt", AttrStmtNode, StmtNode)
ffi::String attr_key
the type key of the attribute
Definition: stmt.h:120
Stmt body
The body statement to be executed.
Definition: stmt.h:124
Managed reference to AttrStmtNode.
Definition: stmt.h:141
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrStmt, Stmt, AttrStmtNode)
AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode)
Bind a variable to a value in the enclosing scope.
Definition: stmt.h:77
PrimExpr value
The value to bind to the variable.
Definition: stmt.h:82
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Bind", BindNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:84
Var var
The variable being bound.
Definition: stmt.h:80
Managed reference to BindNode.
Definition: stmt.h:97
Bind(Var var, PrimExpr value, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BindNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Bind, Stmt, BindNode)
Representing the region of multi-dimensional buffer access.
Definition: stmt.h:692
ffi::Array< Range > region
The region array of the buffer region.
Definition: stmt.h:697
PrimExpr ToPrimExpr() const final
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferRegion", BufferRegionNode, PrimExprConvertibleNode)
static void RegisterReflection()
Definition: stmt.h:699
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:708
Buffer buffer
The buffer of the buffer region.
Definition: stmt.h:695
Managed reference to BufferRegionNode.
Definition: stmt.h:716
static BufferRegion FullRegion(Buffer buffer)
Create a BufferRegion which is full region of the given buffer.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferRegion, PrimExprConvertible, BufferRegionNode)
BufferRegion(Buffer buffer, ffi::Array< Range > region)
static BufferRegion FromPoint(Buffer buffer, ffi::Array< PrimExpr > indices)
Create a BufferRegion which is a single point of the given buffer.
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode)
Store value to the high dimension buffer.
Definition: stmt.h:201
Buffer buffer
The buffer variable.
Definition: stmt.h:204
ffi::Array< PrimExpr > indices
The indices location to be stored.
Definition: stmt.h:208
static void RegisterReflection()
Definition: stmt.h:212
PrimExpr value
The value to be stored.
Definition: stmt.h:206
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferStore", BufferStoreNode, StmtNode)
ffi::Optional< PrimExpr > predicate
The predicate mask for storing values.
Definition: stmt.h:210
Managed reference to BufferStoreNode.
Definition: stmt.h:227
BufferStore(Buffer buffer, PrimExpr value, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferStore, Stmt, BufferStoreNode)
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
Declare a buffer that can be used in the body.
Definition: stmt.h:238
Buffer buffer
The buffer being declared.
Definition: stmt.h:241
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.DeclBuffer", DeclBufferNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:243
Managed reference to DeclBufferNode.
Definition: stmt.h:251
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DeclBuffer, Stmt, DeclBufferNode)
DeclBuffer(Buffer buffer, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode)
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:336
PrimExpr value
The expression to be evaluated.
Definition: stmt.h:339
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Evaluate", EvaluateNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:341
Managed reference to EvaluateNode.
Definition: stmt.h:352
Evaluate(PrimExpr value, Span span=Span())
Evaluate(int value, Span span=Span())
Definition: stmt.h:356
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Evaluate, Stmt, EvaluateNode)
A for loop, with possible type annotations.
Definition: stmt.h:586
PrimExpr min
The minimum value of iteration.
Definition: stmt.h:591
ffi::Optional< IterVar > thread_binding
Only valid when kind == ForKind::kThreadBinding. The context thread that this loop variable is bound to.
Definition: stmt.h:602
ffi::Optional< PrimExpr > step
The loop step. It is one if not specified.
Definition: stmt.h:615
Var loop_var
The loop variable.
Definition: stmt.h:589
bool HasTrivialStep() const
Check whether the loop has a trivial step, i.e. no nontrivial loop step.
Stmt body
The body of the for loop.
Definition: stmt.h:597
PrimExpr extent
The extent of the iteration.
Definition: stmt.h:593
ForKind kind
The kind of the for loop.
Definition: stmt.h:595
ffi::Map< ffi::String, ffi::Any > annotations
Additional annotations about the loop.
Definition: stmt.h:611
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.For", ForNode, StmtNode)
static void RegisterReflection()
Definition: stmt.h:617
Managed reference to ForNode.
Definition: stmt.h:640
For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ffi::Optional< IterVar > thread_binding=std::nullopt, ffi::Map< ffi::String, ffi::Any > annotations={}, ffi::Optional< PrimExpr > step=std::nullopt, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode)
IfThenElse statement.
Definition: stmt.h:516
static void RegisterReflection()
Definition: stmt.h:525
ffi::Optional< Stmt > else_case
The branch to be executed when condition is false, can be null.
Definition: stmt.h:523
PrimExpr condition
The condition.
Definition: stmt.h:519
Stmt then_case
The branch to be executed when condition is true.
Definition: stmt.h:521
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.IfThenElse", IfThenElseNode, StmtNode)
Managed reference to IfThenElseNode.
Definition: stmt.h:539
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IfThenElse, Stmt, IfThenElseNode)
IfThenElse(PrimExpr condition, Stmt then_case, ffi::Optional< Stmt > else_case=std::nullopt, Span span=Span())
Match introduces a constraint that the source buffer region can be remapped to the data layout specif...
Definition: stmt.h:748
Buffer buffer
The target buffer.
Definition: stmt.h:751
BufferRegion source
The source buffer region.
Definition: stmt.h:753
static void RegisterReflection()
Definition: stmt.h:755
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.MatchBufferRegion", MatchBufferRegionNode, Object)
Managed reference to MatchBufferRegionNode.
Definition: stmt.h:770
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchBufferRegion, ObjectRef, MatchBufferRegionNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode)
MatchBufferRegion(Buffer buffer, BufferRegion source)
A block is a basic schedule unit in TIR.
Definition: stmt.h:799
ffi::Array< Buffer > alloc_buffers
The buffer allocated in the block.
Definition: stmt.h:810
Stmt body
The body of the block.
Definition: stmt.h:824
ffi::Array< BufferRegion > reads
The read buffer regions of the block.
Definition: stmt.h:804
static void RegisterReflection()
Definition: stmt.h:826
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SBlock", SBlockNode, StmtNode)
ffi::Array< MatchBufferRegion > match_buffers
The match buffer regions.
Definition: stmt.h:812
ffi::String name_hint
The name_hint of the block.
Definition: stmt.h:808
ffi::Optional< Stmt > init
The init statement is executed during the first iteration of reduction loops in a reduction block....
Definition: stmt.h:822
ffi::Array< IterVar > iter_vars
The variables of the block.
Definition: stmt.h:802
ffi::Array< BufferRegion > writes
The write buffer regions of the block.
Definition: stmt.h:806
ffi::Map< ffi::String, ffi::Any > annotations
The annotation of the block.
Definition: stmt.h:814
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:864
static void RegisterReflection()
Definition: stmt.h:876
PrimExpr predicate
The predicate of the block realization, the block will only be executed when the predicate is true.
Definition: stmt.h:872
ffi::Array< PrimExpr > iter_values
The corresponding values of the iter vars.
Definition: stmt.h:867
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SBlockRealize", SBlockRealizeNode, StmtNode)
SBlock block
The block to be realized.
Definition: stmt.h:874
Managed reference to SBlockRealizeNode.
Definition: stmt.h:890
TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockRealizeNode)
SBlockRealize(ffi::Array< PrimExpr > iter_values, PrimExpr predicate, SBlock block, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SBlockRealize, Stmt, SBlockRealizeNode)
Managed reference to SBlockNode.
Definition: stmt.h:846
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SBlock, Stmt, SBlockNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockNode)
SBlock(ffi::Array< IterVar > iter_vars, ffi::Array< BufferRegion > reads, ffi::Array< BufferRegion > writes, ffi::String name_hint, Stmt body, ffi::Optional< Stmt > init=std::nullopt, ffi::Array< Buffer > alloc_buffers=ffi::Array< Buffer >(), ffi::Array< MatchBufferRegion > match_buffers=ffi::Array< MatchBufferRegion >(), ffi::Map< ffi::String, ffi::Any > annotations=ffi::Map< ffi::String, ffi::Any >(), Span span=Span())
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:311
size_t size() const
Definition: stmt.h:317
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:321
ffi::Array< Stmt > seq
internal sequence content.
Definition: stmt.h:314
static void RegisterReflection()
Definition: stmt.h:323
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SeqStmt", SeqStmtNode, StmtNode)
Helper class to flatten sequence of arguments into Array.
Definition: stmt.h:436
static ffi::Optional< SeqStmt > AsSeqStmt(const T &t)
Definition: stmt.h:441
void operator()(size_t i, const T &stmt_or_seq) const
Definition: stmt.h:459
Flattener(ffi::Array< Stmt > *seq)
Definition: stmt.h:438
Sequence statement.
Definition: stmt.h:363
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode)
Stmt operator[](size_t index) const
Get the index-th element in the sequence.
Definition: stmt.h:377
static Stmt Flatten(Args &&... seq_args)
Construct a sequence statement by flattening all the arrays and sequences in the arguments recursivel...
Definition: stmt.h:399
size_t size() const
Definition: stmt.h:373
SeqStmt(ffi::Array< Stmt > seq, Span span=Span())
Construct SeqStmt.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SeqStmt, Stmt, SeqStmtNode)
Base node of all statements.
Definition: stmt.h:40
static constexpr const uint32_t _type_child_slots
Definition: stmt.h:60
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: stmt.h:58
Span span
Span that points to the original source code. Reserved debug information.
Definition: stmt.h:46
StmtNode(Span span)
Definition: stmt.h:49
TVM_FFI_DECLARE_OBJECT_INFO("tirx.Stmt", StmtNode, Object)
static void RegisterReflection()
Definition: stmt.h:51
Container of all statements.
Definition: stmt.h:65
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Stmt, ObjectRef, StmtNode)
Managed reference to StringImmNode.
Definition: expr.h:68
a named variable in TIR
Definition: var.h:76
A While loop.
Definition: stmt.h:661
static void RegisterReflection()
Definition: stmt.h:668
PrimExpr condition
The termination condition.
Definition: stmt.h:664
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.While", WhileNode, StmtNode)
Stmt body
The body of the while loop.
Definition: stmt.h:666
Managed reference to WhileNode.
Definition: stmt.h:681
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode)
While(PrimExpr condition, Stmt body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(While, Stmt, WhileNode)
Definition: repr_printer.h:91
void Evaluate(PrimExpr value)
Evaluate the input expression.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * thread_extent
Mark launching extent of thread, used by device API.
Definition: stmt.h:935
bool IsPragmaKey(const std::string &attr_key)
Check if attr_key is a pragma key extension.
Definition: stmt.h:944
constexpr const char * pragma_unroll_explicit
Pragma: unroll explicit.
Definition: stmt.h:931
constexpr const char * storage_alignment
Mark storage alignment requirement of buffers.
Definition: stmt.h:933
constexpr const char * buffer_bound
Mark stores/loads with their bounds.
Definition: stmt.h:902
constexpr const char * device_id
The allocation device for global malloc in host.
Definition: stmt.h:909
constexpr const char * pragma_auto_unroll_max_step
Pragma: auto-unroll, max_step.
Definition: stmt.h:921
constexpr const char * pragma_import_llvm
Import llvm source or file into the final code gen module.
Definition: stmt.h:925
constexpr const char * pragma_import_c
Import C source or file into the final code gen module.
Definition: stmt.h:923
constexpr const char * compute_scope
Mark the scope where computation starts to happen. This can hint some code generators to create a new...
Definition: stmt.h:907
constexpr const char * pragma_tensor_core
Try to modify the AST to support Tensor Core.
Definition: stmt.h:929
constexpr const char * device_type
The device type.
Definition: stmt.h:913
constexpr const char * kVolatile
Annotation key on AllocBuffer marking the allocation as volatile.
Definition: stmt.h:937
constexpr const char * device_scope
Mark that it is in the device scope.
Definition: stmt.h:911
constexpr const char * extern_scope
Mark the scope as generated by extern primitive. Such scope can contain arbitrary ir program and we n...
Definition: stmt.h:919
constexpr const char * pragma_scope_prefix
Mark region is guarded by the pragma extension.
Definition: stmt.h:927
@ kUnrolled
The execution is unrolled.
Definition: var.h:232
@ kVectorized
The loop is vectorized.
Definition: var.h:236
PrimExpr TypeAnnotation(DataType dtype, Span span=Span())
Create a type annotation expression.
std::ostream & operator<<(std::ostream &os, CallEffectKind side_effect)
Definition: op_attr_types.h:123
ForKind
The kind of the loop.
Definition: stmt.h:555
@ kThreadBinding
The loop variable is bound to a thread in an environment. In the final stage of lowering,...
@ kParallel
Parallel execution on CPU.
@ kSerial
default semantics – serial execution.
const char * ForKind2String(ForKind t)
Definition: stmt.h:961
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
TIR expressions.