tvm
stmt_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef TVM_TIRX_STMT_FUNCTOR_H_
27 #define TVM_TIRX_STMT_FUNCTOR_H_
28 
29 #include <tvm/ir/node_functor.h>
30 #include <tvm/tirx/expr.h>
31 #include <tvm/tirx/expr_functor.h>
32 #include <tvm/tirx/function.h>
33 #include <tvm/tirx/stmt.h>
34 #include <tvm/tirx/tirx_stmt.h>
35 
36 #include <unordered_map>
37 #include <utility>
38 
39 namespace tvm {
40 namespace tirx {
/*!
 * \brief Same as ExprFunctor except it is applied on statements.
 *
 * Primary template; it is only declared. FType is the functor's signature,
 * e.g. R(const Stmt& n, Args... args); the partial specialization for that
 * function-type form provides the actual definition.
 */
template <typename FType>
class StmtFunctor;

// Shared body of every default VisitStmt_ overload: hand the node and any extra
// arguments to VisitStmtDefault_. Relies on `op`, `args` and the pack `Args`
// being in scope where the macro is expanded.
#define STMT_FUNCTOR_DEFAULT \
  { return VisitStmtDefault_(op, std::forward<Args>(args)...); }
53 
// Register the vtable entry for node type OP: the stored lambda downcasts the
// incoming object to `const OP*` and forwards it, plus the extra arguments, to
// the matching VisitStmt_ overload. Relies on `vtable`, `TSelf` and the pack
// `Args` being in scope where the macro is expanded.
#define IR_STMT_FUNCTOR_DISPATCH(OP)                                         \
  vtable.template set_dispatch<OP>(                                          \
      [](const ffi::ObjectRef& node, TSelf* functor, Args... rest) {         \
        return functor->VisitStmt_(static_cast<const OP*>(node.get()),       \
                                   std::forward<Args>(rest)...);             \
      });
58 
/*!
 * \brief StmtFunctor specialized for the signature R(const Stmt& n, Args... args).
 *
 * VisitStmt dispatches on the runtime node type of the Stmt through a NodeFunctor
 * vtable. The vtable is built once, by InitVTable, into a function-local static.
 * Every VisitStmt_ overload that a subclass does not override expands
 * STMT_FUNCTOR_DEFAULT and therefore calls VisitStmtDefault_, which throws
 * InternalError.
 */
59 template <typename R, typename... Args>
60 class StmtFunctor<R(const Stmt& n, Args... args)> {
61  private:
62  using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
63  using FType = NodeFunctor<R(const ffi::ObjectRef& n, TSelf* self, Args... args)>;
64 
65  public:
 /*! \brief The result type of this functor. */
67  using result_type = R;
 /*! \brief Virtual destructor. */
69  virtual ~StmtFunctor() {}
 /*!
  * \brief Same as VisitStmt: forwards n and the extra arguments to it.
  * \param n The statement to visit.
  * \param args Additional arguments.
  * \return The result of the visit.
  */
76  R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward<Args>(args)...); }
 /*!
  * \brief The functor call: look up n's node type in the static vtable and
  *  invoke the matching VisitStmt_ overload.
  * \param n The statement to visit.
  * \param args Additional arguments.
  * \return The result of the visit.
  */
83  virtual R VisitStmt(const Stmt& n, Args... args) {
84  static FType vtable = InitVTable();
85  return vtable(n, this, std::forward<Args>(args)...);
86  }
87  // Functions that can be overridden by a subclass. Each default body
 // (STMT_FUNCTOR_DEFAULT) forwards to VisitStmtDefault_.
88  virtual R VisitStmt_(const BindNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
89  virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
90  virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
91  virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
92  virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
93  virtual R VisitStmt_(const BreakNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
94  virtual R VisitStmt_(const ContinueNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
95  virtual R VisitStmt_(const AllocBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
96  virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
97  virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
98  virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
99  virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
100  virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
101  virtual R VisitStmt_(const SBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
102  virtual R VisitStmt_(const SBlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
103  virtual R VisitStmt_(const ExecScopeStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
104  virtual R VisitStmt_(const tirx::TilePrimitiveCallNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
 /*!
  * \brief Fallback reached by any VisitStmt_ overload that was not overridden.
  *  Throws InternalError naming op's type key; TVM_FFI_UNREACHABLE marks the
  *  end of the function as unreachable after the throw.
  */
105  virtual R VisitStmtDefault_(const ffi::Object* op, Args...) {
106  TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey();
107  TVM_FFI_UNREACHABLE();
108  }
109 
110  private:
111  // initialize the vtable.
 // Called once, from the function-local static in VisitStmt.
 // NOTE(review): source lines 114-130 are not shown in this listing; presumably
 // they register each node type above via IR_STMT_FUNCTOR_DISPATCH. Confirm
 // against the header.
112  static FType InitVTable() {
113  FType vtable;
131  vtable.Finalize();
132  return vtable;
133  }
134 };
135 
136 #undef IR_STMT_FUNCTOR_DISPATCH
137 #undef STMT_FUNCTOR_DEFAULT
138 
/*!
 * \brief StmtVisitor: a statement visitor that returns nothing.
 *
 * StmtFunctor is inherited protectedly, so only operator() is public.
 * Subclasses customise behaviour through VisitExpr, VisitBufferDef,
 * VisitBufferUse and the VisitStmt_ overrides. Those overrides are only
 * declared here; they are defined out of line.
 * NOTE: this is an exported (TVM_DLL) polymorphic class. Keep the order of the
 * virtual declarations stable, because it affects the vtable layout.
 */
142 class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
143  public:
144  using StmtFunctor::operator();
145 
146  protected:
147  using StmtFunctor::VisitStmt;
 /*!
  * \brief Visitor to Exprs; can be overridden to do recursive changes to Exprs.
  *  The default body is empty, so expressions are not visited unless a
  *  subclass overrides this.
  * \param e The expression to visit.
  */
155  virtual void VisitExpr(const PrimExpr& e) {}
 /*!
  * \brief Visit a buffer at its definition site (AllocBuffer, DeclBuffer,
  *  SBlock alloc_buffers).
  * \param buffer The buffer being defined.
  * \param alloc_data NOTE(review): meaning not visible in this header; presumably
  *  whether this definition also allocates the buffer's data. Confirm in the .cc.
  */
163  virtual void VisitBufferDef(const Buffer& buffer, bool alloc_data);
 /*!
  * \brief Visit a buffer at a use site (BufferStore, BufferLoad, SBlock reads/writes).
  * \param buffer The buffer being used.
  */
169  virtual void VisitBufferUse(const Buffer& buffer);
170  // statement visitor
171  void VisitStmt_(const BindNode* op) override;
172  void VisitStmt_(const AttrStmtNode* op) override;
173  void VisitStmt_(const IfThenElseNode* op) override;
174  void VisitStmt_(const ForNode* op) override;
175  void VisitStmt_(const WhileNode* op) override;
176  void VisitStmt_(const BreakNode* op) override;
177  void VisitStmt_(const ContinueNode* op) override;
178  void VisitStmt_(const AllocBufferNode* op) override;
179  void VisitStmt_(const DeclBufferNode* op) override;
180  void VisitStmt_(const BufferStoreNode* op) override;
181  void VisitStmt_(const AssertStmtNode* op) override;
182  void VisitStmt_(const SeqStmtNode* op) override;
183  void VisitStmt_(const EvaluateNode* op) override;
184  void VisitStmt_(const SBlockNode* op) override;
185  void VisitStmt_(const SBlockRealizeNode* op) override;
186  void VisitStmt_(const ExecScopeStmtNode* op) override;
187  void VisitStmt_(const tirx::TilePrimitiveCallNode* op) override;
188 };
189 
/*!
 * \brief StmtMutator that mutates the statements.
 *
 * Each VisitStmt_ returns a (possibly new) Stmt. StmtFunctor is inherited
 * protectedly, so callers go through operator(). The VisitStmt_ overrides are
 * only declared here; they are defined out of line.
 */
193 class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
194  public:
 /*!
  * \brief Mutate stmt. Enables copy-on-write for the traversal and then
  *  delegates to VisitStmt.
  * NOTE(review): the line that opens this member (`Stmt operator()(Stmt stmt) {`,
  *  source line 203 according to this page's cross-reference index) is missing
  *  from this listing, so the lines below appear without a function head.
  */
204  allow_copy_on_write_ = true;
205  return VisitStmt(stmt);
206  }
207 
208  protected:
 /*! \brief Map from old buffer to new buffer, populated by VisitBufferDef. */
210  ffi::Map<Buffer, Buffer> buffer_remap_;
211  // We perform copy-on-write optimizations in the StmtMutator
212  // so that a unique copy of a parent can be mutated in place
213  // when some of its children change.
214  // We only do such optimization for Stmt nests (instead of Exprs) for now,
215  // as a Stmt's parent state is more likely to remain unchanged when one of
216  // its child blocks changes.
 /*!
  * \brief Whether CopyOnWrite may hand back the current node for in-place mutation.
  *  operator() sets it to true. While VisitStmt visits a Stmt that is not
  *  uniquely referenced, it is false; it is set back to true afterwards.
  */
221  bool allow_copy_on_write_{false};
 /*!
  * \brief Perform copy on write on node.
  * \tparam TNode A StmtNode subclass (enforced by the static_assert below).
  * \param node The node that is about to be mutated.
  * \return node itself when allow_copy_on_write_ is true; otherwise a fresh copy
  *  made with TNode's copy constructor.
  */
231  template <typename TNode>
232  ffi::ObjectPtr<TNode> CopyOnWrite(const TNode* node) {
233  static_assert(std::is_base_of<StmtNode, TNode>::value,
234  "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent "
235  "nodes during the recursion. Because the child classes do not necessarily "
236  "check the Array, Expr and other structures during the visit, it is only safe to "
237  "call this function with StmtNodes for now. "
238  "Please create a new node directly in other cases.");
239  if (allow_copy_on_write_) {
240  // return the old node.
241  return ffi::GetObjectPtr<TNode>(const_cast<TNode*>(node));
242  } else {
243  // Make a new copy of the node.
244  // need to rely on the default copy constructor
245  return ffi::make_object<TNode>(*node);
246  }
247  }
 /*!
  * \brief Internal mutator that everyone calls.
  *  If copy-on-write is currently allowed but stmt is shared (!stmt.unique()),
  *  it is switched off while stmt's subtree is visited. Shared nodes are then
  *  copied rather than modified in place. It is switched back on afterwards.
  * \param stmt The statement to mutate.
  * \return The mutated statement.
  */
254  Stmt VisitStmt(const Stmt& stmt) override {
255  if (allow_copy_on_write_ && !stmt.unique()) {
256  allow_copy_on_write_ = false;
257  Stmt ret = StmtFunctor::VisitStmt(stmt);
258  allow_copy_on_write_ = true;
259  return ret;
260  } else {
261  return StmtFunctor::VisitStmt(stmt);
262  }
263  }
 /*!
  * \brief Visitor to Exprs; can be overridden to do recursive changes to Exprs.
  *  The default returns e unchanged.
  */
271  virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; }
 /*!
  * \brief Visit a buffer at its definition site. Per the generated docs it visits
  *  shape/strides/elem_offset via VisitExpr; the implementation is out of line.
  * \param buffer The buffer being defined.
  * \param alloc_data NOTE(review): presumably whether this definition also
  *  allocates the data. Confirm in the .cc.
  * \return The (possibly replaced) buffer.
  */
280  virtual Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data);
 /*!
  * \brief Visit a buffer at a use site (BufferStore, BufferLoad, SBlock reads/writes).
  * \return The (possibly replaced) buffer.
  */
287  virtual Buffer VisitBufferUse(const Buffer& buffer);
288  // statement visitor
289  Stmt VisitStmt_(const BindNode* op) override;
290  Stmt VisitStmt_(const AttrStmtNode* op) override;
291  Stmt VisitStmt_(const IfThenElseNode* op) override;
292  Stmt VisitStmt_(const ForNode* op) override;
293  Stmt VisitStmt_(const WhileNode* op) override;
294  Stmt VisitStmt_(const BreakNode* op) override;
295  Stmt VisitStmt_(const ContinueNode* op) override;
296  Stmt VisitStmt_(const AllocBufferNode* op) override;
297  Stmt VisitStmt_(const DeclBufferNode* op) override;
298  Stmt VisitStmt_(const BufferStoreNode* op) override;
299  Stmt VisitStmt_(const AssertStmtNode* op) override;
300  Stmt VisitStmt_(const SeqStmtNode* op) override;
301  Stmt VisitStmt_(const EvaluateNode* op) override;
302  Stmt VisitStmt_(const SBlockNode* op) override;
303  Stmt VisitStmt_(const SBlockRealizeNode* op) override;
304  Stmt VisitStmt_(const ExecScopeStmtNode* op) override;
 // NOTE(review): this page's member index lists
 // `Stmt VisitStmt_(const tirx::TilePrimitiveCallNode *op) override` for
 // StmtMutator, but that line (presumably source line 305) is not shown in this
 // listing. Confirm against the header.
 /*!
  * \brief Alternative advance method for SeqStmtNode.
  * \param op The sequence node.
  * \param flatten_before_visit NOTE(review): presumably flattens nested SeqStmt
  *  before visiting the children. Confirm in the .cc.
  * \param fmutate Optional per-child mutation function; presumably VisitStmt is
  *  used when it is nullptr. Confirm.
  */
318  Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
319  std::function<Stmt(const Stmt&)> fmutate = nullptr);
320 
321  // internal helper.
322  class Internal;
323 };
324 
/*!
 * \brief Visitor that recursively visits stmts and the exprs on them.
 *
 * Combines ExprVisitor and StmtVisitor. The using-declarations bring operator()
 * from both bases, ExprVisitor::VisitExpr and StmtVisitor::VisitStmt into this
 * scope. Calls then resolve by overload instead of being ambiguous between the
 * two bases.
 */
328 class TVM_DLL StmtExprVisitor : public ExprVisitor, public StmtVisitor {
329  public:
330  using StmtVisitor::operator();
331  using ExprVisitor::operator();
332 
333  protected:
334  using ExprVisitor::VisitExpr;
 // NOTE(review): source line 335 is not shown in this listing. Confirm its
 // contents against the header.
336  using StmtVisitor::VisitStmt;
337 
 /*!
  * \brief Visitor to Exprs. Overrides StmtVisitor's VisitExpr hook (a no-op
  *  there) by forwarding to ExprVisitor::VisitExpr. Expressions found inside
  *  statements therefore go through ExprVisitor's dispatch.
  */
338  void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); }
 /*! \brief Visit a BufferLoad expression (defined out of line). */
339  void VisitExpr_(const BufferLoadNode* op) override;
340 };
341 
/*!
 * \brief Mutator that recursively mutates stmts and the exprs on them.
 *
 * Combines ExprMutator and StmtMutator. operator() from both bases is brought
 * into scope so calls resolve by overload.
 */
345 class TVM_DLL StmtExprMutator : public ExprMutator, public StmtMutator {
346  public:
347  using StmtMutator::operator();
348  using ExprMutator::operator();
349 
350  protected:
351  using ExprMutator::VisitExpr;
 // NOTE(review): source lines 352-353 are not shown in this listing. Confirm
 // their contents against the header.
354 
 /*!
  * \brief Visitor to Exprs. Overrides StmtMutator's VisitExpr hook (which returns
  *  e unchanged there) by forwarding to ExprMutator::VisitExpr.
  */
355  PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); }
 /*! \brief Mutate a BufferLoad expression (defined out of line). */
356  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
357 };
358 
/*!
 * \brief Recursively visit the IR nodes in post-DFS order and transform them.
 * \param stmt The statement to transform.
 * \param preorder NOTE(review): presumably invoked on a node before its children
 *  are visited. Confirm in the implementation.
 * \param postorder NOTE(review): presumably invoked on a node after its children
 *  have been visited. Confirm.
 * \param only_enable NOTE(review): presumably restricts the callbacks to the listed
 *  node type keys; std::nullopt (the default) applies no restriction. Confirm.
 * \return The transformed statement.
 */
374 TVM_DLL Stmt IRTransform(Stmt stmt, const ffi::Function& preorder, const ffi::Function& postorder,
375  ffi::Optional<ffi::Array<ffi::String>> only_enable = std::nullopt);
376 
/*!
 * \brief Recursively visit the IR in post-DFS order and apply fvisit to each node.
 * NOTE(review): the generated summary also states a per-node visit guarantee
 *  (truncated in this listing), presumably "visited only once". Confirm.
 * \param node The root of the IR (statement or expression).
 * \param fvisit The callback applied to each visited node.
 */
383 TVM_DLL void PostOrderVisit(const ffi::ObjectRef& node,
384  std::function<void(const ffi::ObjectRef&)> fvisit);
385 
/*!
 * \brief Substitute the variables in stmt using the callback vmap.
 * \param stmt The statement to rewrite.
 * \param vmap Returns the replacement for a Var, or std::nullopt. NOTE(review):
 *  presumably std::nullopt leaves the var unchanged. Confirm in the implementation.
 * \return The rewritten statement.
 */
392 TVM_DLL Stmt Substitute(Stmt stmt, std::function<ffi::Optional<PrimExpr>(const Var& var)> vmap);
393 
401  std::function<ffi::Optional<PrimExpr>(const Var& var)> vmap);
402 
409 template <typename T>
410 ffi::Array<T> Substitute(const ffi::Array<T>& arr,
411  std::function<ffi::Optional<PrimExpr>(const Var& var)> vmap) {
412  return arr.Map([&vmap](const auto& elem) { return Substitute(elem, vmap); });
413 }
414 
421 inline Range Substitute(const Range& range,
422  std::function<ffi::Optional<PrimExpr>(const Var& var)> vmap) {
423  return Range::FromMinExtent(Substitute(range->min, vmap), Substitute(range->extent, vmap));
424 }
425 
437 template <typename Obj>
438 auto Substitute(Obj&& obj, const ffi::Map<Var, PrimExpr>& vmap) {
439  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> { return vmap.Get(var); };
440  return Substitute(std::forward<Obj>(obj), func);
441 }
442 
452 template <typename Obj, typename Expr,
453  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
454 auto Substitute(Obj&& obj, const ffi::Map<Var, Expr>& vmap) {
455  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> {
456  if (auto opt = vmap.Get(var)) {
457  return opt.value();
458  } else {
459  return std::nullopt;
460  }
461  };
462  return Substitute(std::forward<Obj>(obj), func);
463 }
464 
474 template <typename Obj, typename Expr,
475  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
476 auto Substitute(Obj&& obj, const std::unordered_map<const VarNode*, Expr>& vmap) {
477  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> {
478  if (auto it = vmap.find(var.get()); it != vmap.end()) {
479  return it->second;
480  } else {
481  return std::nullopt;
482  }
483  };
484  return Substitute(std::forward<Obj>(obj), func);
485 }
486 
496 template <typename Obj, typename Expr, typename Hasher, typename EqualityChecker,
497  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
498 auto Substitute(Obj&& obj, const std::unordered_map<Var, Expr, Hasher, EqualityChecker>& vmap) {
499  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> {
500  if (auto it = vmap.find(var); it != vmap.end()) {
501  return it->second;
502  } else {
503  return std::nullopt;
504  }
505  };
506  return Substitute(std::forward<Obj>(obj), func);
507 }
508 
518 template <typename Obj, typename Expr,
519  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
520 auto Substitute(Obj&& obj, const std::unordered_map<IterVar, Expr>& iter_vmap) {
521  std::unordered_map<const VarNode*, PrimExpr> vmap;
522  for (const auto& [iter_var, expr] : iter_vmap) {
523  vmap[iter_var->var.get()] = expr;
524  }
525 
526  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> {
527  if (auto it = vmap.find(var.get()); it != vmap.end()) {
528  return it->second;
529  } else {
530  return std::nullopt;
531  }
532  };
533  return Substitute(std::forward<Obj>(obj), func);
534 }
535 
547  Stmt stmt, std::function<ffi::Optional<PrimExpr>(const Var&)> vmap);
548 
560  PrimExpr expr, std::function<ffi::Optional<PrimExpr>(const Var&)> vmap);
561 
/*!
 * \brief Recursively visit the IR in pre-DFS order and apply fvisit to each node.
 * \param stmt_or_expr The root of the IR (statement or expression).
 * \param fvisit Callback applied to each node. NOTE(review): per the generated
 *  summary, a false return changes the traversal (text truncated in this
 *  listing); presumably that node's children are skipped. Confirm.
 */
569 TVM_DLL void PreOrderVisit(const ffi::ObjectRef& stmt_or_expr,
570  const std::function<bool(const ffi::ObjectRef&)>& fvisit);
571 
582 template <typename Node, typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
583 bool ContainsNode(const Stmt& stmt) {
584  struct Visitor : StmtVisitor {
585  // Early bail-out, if we already found the node.
586  void VisitStmt(const Stmt& stmt) final {
587  if (contains_node) {
588  return;
589  }
590  StmtVisitor::VisitStmt(stmt);
591  }
592 
593  void VisitStmt_(const Node* block) override { contains_node = true; }
594 
595  bool contains_node{false};
596  };
597 
598  Visitor visitor;
599  visitor(stmt);
600  return visitor.contains_node;
601 }
602 
603 } // namespace tirx
604 } // namespace tvm
605 
606 #endif // TVM_TIRX_STMT_FUNCTOR_H_
A dynamically dispatched functor on the type of the first argument.
Definition: node_functor.h:62
Reference to PrimExprNode.
Definition: expr.h:126
Range container
Definition: expr.h:690
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
Construct a new range with min and extent. The corresponding constructor is removed,...
Allocate a buffer and declare it in scope.
Definition: stmt.h:261
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:161
Define certain auxiliary attributes for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:117
Bind a variable to a value in the enclosing scope.
Definition: stmt.h:79
A Break in control flow.
Definition: stmt.h:694
Load value from the high dimension buffer.
Definition: expr.h:533
Store value to the high dimension buffer.
Definition: stmt.h:203
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:172
A Continue in control flow.
Definition: stmt.h:719
Declare a buffer that can be used in the body.
Definition: stmt.h:240
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:338
A statement that annotates the execution scope for its body.
Definition: stmt.h:969
ExprMutator that mutates expressions.
Definition: expr_functor.h:253
PrimExpr VisitExpr_(const VarNode *op) override
ExprVisitor.
Definition: expr_functor.h:208
void VisitExpr_(const VarNode *op) override
A for loop, with possible type annotations.
Definition: stmt.h:588
IfThenElse statement.
Definition: stmt.h:518
A block is a basic schedule unit in TIR.
Definition: stmt.h:852
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:921
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:313
Mutator that recursively mutates stmts and exprs on them.
Definition: stmt_functor.h:345
PrimExpr VisitExpr(const PrimExpr &e) override
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:355
PrimExpr VisitExpr_(const BufferLoadNode *op) override
Visitor that recursively visits stmts and the exprs on them.
Definition: stmt_functor.h:328
void VisitExpr_(const BufferLoadNode *op) override
void VisitExpr(const PrimExpr &e) override
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:338
virtual R VisitStmt_(const ContinueNode *op, Args... args)
Definition: stmt_functor.h:94
virtual R VisitStmt_(const BufferStoreNode *op, Args... args)
Definition: stmt_functor.h:97
virtual ~StmtFunctor()
virtual destructor
Definition: stmt_functor.h:69
virtual R VisitStmtDefault_(const ffi::Object *op, Args...)
Definition: stmt_functor.h:105
R operator()(const Stmt &n, Args... args)
Same as call.
Definition: stmt_functor.h:76
virtual R VisitStmt_(const IfThenElseNode *op, Args... args)
Definition: stmt_functor.h:90
virtual R VisitStmt_(const ExecScopeStmtNode *op, Args... args)
Definition: stmt_functor.h:103
virtual R VisitStmt_(const SBlockRealizeNode *op, Args... args)
Definition: stmt_functor.h:102
virtual R VisitStmt_(const ForNode *op, Args... args)
Definition: stmt_functor.h:91
virtual R VisitStmt_(const SeqStmtNode *op, Args... args)
Definition: stmt_functor.h:99
virtual R VisitStmt_(const SBlockNode *op, Args... args)
Definition: stmt_functor.h:101
virtual R VisitStmt_(const tirx::TilePrimitiveCallNode *op, Args... args)
Definition: stmt_functor.h:104
virtual R VisitStmt(const Stmt &n, Args... args)
The functor call.
Definition: stmt_functor.h:83
virtual R VisitStmt_(const AttrStmtNode *op, Args... args)
Definition: stmt_functor.h:89
virtual R VisitStmt_(const EvaluateNode *op, Args... args)
Definition: stmt_functor.h:100
virtual R VisitStmt_(const BreakNode *op, Args... args)
Definition: stmt_functor.h:93
virtual R VisitStmt_(const DeclBufferNode *op, Args... args)
Definition: stmt_functor.h:96
virtual R VisitStmt_(const AssertStmtNode *op, Args... args)
Definition: stmt_functor.h:98
R result_type
the result type of this functor
Definition: stmt_functor.h:67
virtual R VisitStmt_(const AllocBufferNode *op, Args... args)
Definition: stmt_functor.h:95
virtual R VisitStmt_(const WhileNode *op, Args... args)
Definition: stmt_functor.h:92
virtual R VisitStmt_(const BindNode *op, Args... args)
Definition: stmt_functor.h:88
Same as ExprFunctor except it is applied on statements.
Definition: stmt_functor.h:47
StmtMutator that mutates the statements.
Definition: stmt_functor.h:193
Stmt operator()(Stmt stmt)
Mutate stmt.
Definition: stmt_functor.h:203
Stmt VisitStmt_(const tirx::TilePrimitiveCallNode *op) override
Stmt VisitStmt_(const EvaluateNode *op) override
Stmt VisitStmt_(const SBlockNode *op) override
Stmt VisitStmt_(const WhileNode *op) override
Stmt VisitStmt_(const SeqStmtNode *op) override
Stmt VisitStmt_(const ContinueNode *op) override
Stmt VisitStmt_(const SBlockRealizeNode *op) override
Stmt VisitStmt_(const IfThenElseNode *op) override
Stmt VisitStmt_(const AttrStmtNode *op) override
virtual Buffer VisitBufferDef(const Buffer &buffer, bool alloc_data)
Visit buffer at definition site. Visits shape/strides/elem_offset via VisitExpr. If any field changes...
ffi::Map< Buffer, Buffer > buffer_remap_
Map from old buffer to new buffer, populated by VisitBufferDef.
Definition: stmt_functor.h:210
Stmt VisitStmt(const Stmt &stmt) override
Internal mutator that everyone calls.
Definition: stmt_functor.h:254
Stmt VisitStmt_(const BindNode *op) override
Stmt VisitStmt_(const ForNode *op) override
ffi::ObjectPtr< TNode > CopyOnWrite(const TNode *node)
Perform copy on write on node.
Definition: stmt_functor.h:232
Stmt VisitSeqStmt_(const SeqStmtNode *op, bool flatten_before_visit, std::function< Stmt(const Stmt &)> fmutate=nullptr)
Alternative advance method for SeqStmtNode.
virtual Buffer VisitBufferUse(const Buffer &buffer)
Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes). By default,...
Stmt VisitStmt_(const BreakNode *op) override
Stmt VisitStmt_(const AllocBufferNode *op) override
Stmt VisitStmt_(const AssertStmtNode *op) override
virtual PrimExpr VisitExpr(const PrimExpr &e)
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:271
Stmt VisitStmt_(const ExecScopeStmtNode *op) override
Stmt VisitStmt_(const BufferStoreNode *op) override
Stmt VisitStmt_(const DeclBufferNode *op) override
StmtVisitor.
Definition: stmt_functor.h:142
void VisitStmt_(const ForNode *op) override
void VisitStmt_(const BindNode *op) override
void VisitStmt_(const SBlockNode *op) override
void VisitStmt_(const AttrStmtNode *op) override
void VisitStmt_(const BufferStoreNode *op) override
void VisitStmt_(const SBlockRealizeNode *op) override
void VisitStmt_(const SeqStmtNode *op) override
void VisitStmt_(const AssertStmtNode *op) override
void VisitStmt_(const tirx::TilePrimitiveCallNode *op) override
void VisitStmt_(const ContinueNode *op) override
virtual void VisitBufferUse(const Buffer &buffer)
Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes). By default,...
void VisitStmt_(const WhileNode *op) override
void VisitStmt_(const BreakNode *op) override
void VisitStmt_(const EvaluateNode *op) override
virtual void VisitExpr(const PrimExpr &e)
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:155
void VisitStmt_(const ExecScopeStmtNode *op) override
void VisitStmt_(const IfThenElseNode *op) override
virtual void VisitBufferDef(const Buffer &buffer, bool alloc_data)
Visit buffer at definition site (AllocBuffer, DeclBuffer, SBlock alloc_buffers). Visits buffer shape,...
void VisitStmt_(const AllocBufferNode *op) override
void VisitStmt_(const DeclBufferNode *op) override
Container of all statements.
Definition: stmt.h:67
TIRX TilePrimitiveCall stmt.
Definition: tirx_stmt.h:35
A named variable in TIR.
Definition: var.h:77
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:124
A While loop.
Definition: stmt.h:663
tvm::relax::Function Function
Definition: transform.h:38
RelaxExpr Expr
Definition: expr.h:39
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
void PostOrderVisit(const ffi::ObjectRef &node, std::function< void(const ffi::ObjectRef &)> fvisit)
Recursively visit the IR in post-DFS order and apply fvisit to each node. Each node is guaranteed to be visited o...
IndexMap Substitute(const IndexMap &index_map, std::function< ffi::Optional< PrimExpr >(const Var &var)> f_subst)
Substitute variables in an index map.
Stmt IRTransform(Stmt stmt, const ffi::Function &preorder, const ffi::Function &postorder, ffi::Optional< ffi::Array< ffi::String >> only_enable=std::nullopt)
Recursively visit the IR nodes in post-DFS order and transform them.
bool ContainsNode(const Stmt &stmt)
Check if the statement contains the specified node type.
Definition: stmt_functor.h:583
void PreOrderVisit(const ffi::ObjectRef &stmt_or_expr, const std::function< bool(const ffi::ObjectRef &)> &fvisit)
Recursively visit the IR in pre-DFS order and apply fvisit to each node. If fvisit returns false,...
Stmt SubstituteWithDataTypeLegalization(Stmt stmt, std::function< ffi::Optional< PrimExpr >(const Var &)> vmap)
Substitute the var specified by vmap and legalize data types after substitution.
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Defines the Functor data structures.
#define IR_STMT_FUNCTOR_DISPATCH(OP)
Definition: stmt_functor.h:54
#define STMT_FUNCTOR_DEFAULT
Definition: stmt_functor.h:49
TIR expressions.
Functors for tirx expressions.
TIR Function.
TIR statements.