tvm
stmt_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef TVM_TIR_STMT_FUNCTOR_H_
27 #define TVM_TIR_STMT_FUNCTOR_H_
28 
29 #include <tvm/node/functor.h>
30 #include <tvm/tir/expr.h>
31 #include <tvm/tir/expr_functor.h>
32 #include <tvm/tir/function.h>
33 #include <tvm/tir/stmt.h>
34 
35 #include <unordered_map>
36 #include <utility>
37 
38 namespace tvm {
39 namespace tir {
45 template <typename FType>
47 
48 #define STMT_FUNCTOR_DEFAULT \
49  { return VisitStmtDefault_(op, std::forward<Args>(args)...); }
50 
51 #define IR_STMT_FUNCTOR_DISPATCH(OP) \
52  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
53  return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
54  });
55 
56 template <typename R, typename... Args>
57 class StmtFunctor<R(const Stmt& n, Args... args)> {
58  private:
59  using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
60  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;
61 
62  public:
64  using result_type = R;
66  virtual ~StmtFunctor() {}
73  R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward<Args>(args)...); }
80  virtual R VisitStmt(const Stmt& n, Args... args) {
81  static FType vtable = InitVTable();
82  return vtable(n, this, std::forward<Args>(args)...);
83  }
84  // Functions that can be overriden by subclass
85  virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
86  virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
87  virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
88  virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
89  virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
90  virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
91  virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
92  virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
93  virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
94  virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
95  virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
96  virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
97  virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
98  virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
99  virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
100  virtual R VisitStmtDefault_(const Object* op, Args...) {
101  LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
102  TVM_FFI_UNREACHABLE();
103  }
104 
105  private:
106  // initialize the vtable.
107  static FType InitVTable() {
108  FType vtable;
124  vtable.Finalize();
125  return vtable;
126  }
127 };
128 
129 #undef IR_STMT_FUNCTOR_DISPATCH
130 #undef STMT_FUNCTOR_DEFAULT
131 
135 class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
136  public:
137  using StmtFunctor::operator();
138 
139  protected:
140  using StmtFunctor::VisitStmt;
148  virtual void VisitExpr(const PrimExpr& e) {}
149  // statement visitor
150  void VisitStmt_(const AttrStmtNode* op) override;
151  void VisitStmt_(const IfThenElseNode* op) override;
152  void VisitStmt_(const LetStmtNode* op) override;
153  void VisitStmt_(const ForNode* op) override;
154  void VisitStmt_(const WhileNode* op) override;
155  void VisitStmt_(const AllocateNode* op) override;
156  void VisitStmt_(const AllocateConstNode* op) override;
157  void VisitStmt_(const DeclBufferNode* op) override;
158  void VisitStmt_(const BufferStoreNode* op) override;
159  void VisitStmt_(const BufferRealizeNode* op) override;
160  void VisitStmt_(const AssertStmtNode* op) override;
161  void VisitStmt_(const SeqStmtNode* op) override;
162  void VisitStmt_(const EvaluateNode* op) override;
163  void VisitStmt_(const BlockNode* op) override;
164  void VisitStmt_(const BlockRealizeNode* op) override;
165 };
166 
170 class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
171  public:
181  allow_copy_on_write_ = true;
182  return VisitStmt(stmt);
183  }
184 
185  protected:
186  // We perform copy on write optimizations on the StmtMutator
187  // so that an unique copy of parent can be mutated inplace
188  // when some of its children changed.
189  // We only do such optimization for Stmt nests(instead of Exprs) for now
190  // as Stmt's parent state is more likely remain unchanged when one of
191  // its child block changes.
196  bool allow_copy_on_write_{false};
206  template <typename TNode>
207  ObjectPtr<TNode> CopyOnWrite(const TNode* node) {
208  static_assert(std::is_base_of<StmtNode, TNode>::value,
209  "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent "
210  "nodes during the recursion. Because the child classes do not necessarily "
211  "check the Array, Expr and other structures during the visit, it is only safe to "
212  "call this function with StmtNodes for now. "
213  "Please create a new node directly in other cases.");
214  if (allow_copy_on_write_) {
215  // return the old node.
216  return runtime::GetObjectPtr<TNode>(const_cast<TNode*>(node));
217  } else {
218  // Make a new copy of the node.
219  // need to rely on the default copy constructor
220  return ffi::make_object<TNode>(*node);
221  }
222  }
229  Stmt VisitStmt(const Stmt& stmt) override {
230  if (allow_copy_on_write_ && !stmt.unique()) {
231  allow_copy_on_write_ = false;
232  Stmt ret = StmtFunctor::VisitStmt(stmt);
233  allow_copy_on_write_ = true;
234  return ret;
235  } else {
236  return StmtFunctor::VisitStmt(stmt);
237  }
238  }
246  virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; }
247  // statement visitor
248  Stmt VisitStmt_(const AttrStmtNode* op) override;
249  Stmt VisitStmt_(const IfThenElseNode* op) override;
250  Stmt VisitStmt_(const LetStmtNode* op) override;
251  Stmt VisitStmt_(const ForNode* op) override;
252  Stmt VisitStmt_(const WhileNode* op) override;
253  Stmt VisitStmt_(const AllocateNode* op) override;
254  Stmt VisitStmt_(const AllocateConstNode* op) override;
255  Stmt VisitStmt_(const DeclBufferNode* op) override;
256  Stmt VisitStmt_(const BufferStoreNode* op) override;
257  Stmt VisitStmt_(const BufferRealizeNode* op) override;
258  Stmt VisitStmt_(const AssertStmtNode* op) override;
259  Stmt VisitStmt_(const SeqStmtNode* op) override;
260  Stmt VisitStmt_(const EvaluateNode* op) override;
261  Stmt VisitStmt_(const BlockNode* op) override;
262  Stmt VisitStmt_(const BlockRealizeNode* op) override;
275  Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
276  std::function<Stmt(const Stmt&)> fmutate = nullptr);
277 
278  // internal helper.
279  class Internal;
280 };
281 
285 class StmtExprVisitor : public StmtVisitor, public ExprVisitor {
286  public:
287  using StmtVisitor::operator();
288  using ExprVisitor::operator();
289 
290  protected:
291  using ExprVisitor::VisitExpr;
292  using StmtVisitor::VisitStmt;
293 
294  void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); }
295 };
296 
300 class StmtExprMutator : public StmtMutator, public ExprMutator {
301  public:
302  using StmtMutator::operator();
303  using ExprMutator::operator();
304 
305  protected:
306  using ExprMutator::VisitExpr;
308 
309  PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); }
310 };
311 
327 TVM_DLL Stmt IRTransform(Stmt stmt, const ffi::Function& preorder, const ffi::Function& postorder,
328  Optional<Array<String>> only_enable = std::nullopt);
329 
336 TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
337 
344 TVM_DLL Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var& var)> vmap);
345 
352 TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var& var)> vmap);
353 
360 template <typename T>
361 Array<T> Substitute(const Array<T>& arr, std::function<Optional<PrimExpr>(const Var& var)> vmap) {
362  return arr.Map([&vmap](const auto& elem) { return Substitute(elem, vmap); });
363 }
364 
371 inline Range Substitute(const Range& range,
372  std::function<Optional<PrimExpr>(const Var& var)> vmap) {
373  return Range::FromMinExtent(Substitute(range->min, vmap), Substitute(range->extent, vmap));
374 }
375 
387 template <typename Obj>
388 auto Substitute(Obj&& obj, const Map<Var, PrimExpr>& vmap) {
389  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> { return vmap.Get(var); };
390  return Substitute(std::forward<Obj>(obj), func);
391 }
392 
402 template <typename Obj, typename Expr,
403  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
404 auto Substitute(Obj&& obj, const Map<Var, Expr>& vmap) {
405  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
406  if (auto opt = vmap.Get(var)) {
407  return opt.value();
408  } else {
409  return std::nullopt;
410  }
411  };
412  return Substitute(std::forward<Obj>(obj), func);
413 }
414 
424 template <typename Obj, typename Expr,
425  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
426 auto Substitute(Obj&& obj, const std::unordered_map<const VarNode*, Expr>& vmap) {
427  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
428  if (auto it = vmap.find(var.get()); it != vmap.end()) {
429  return it->second;
430  } else {
431  return std::nullopt;
432  }
433  };
434  return Substitute(std::forward<Obj>(obj), func);
435 }
436 
446 template <typename Obj, typename Expr, typename Hasher, typename EqualityChecker,
447  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
448 auto Substitute(Obj&& obj, const std::unordered_map<Var, Expr, Hasher, EqualityChecker>& vmap) {
449  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
450  if (auto it = vmap.find(var); it != vmap.end()) {
451  return it->second;
452  } else {
453  return std::nullopt;
454  }
455  };
456  return Substitute(std::forward<Obj>(obj), func);
457 }
458 
468 template <typename Obj, typename Expr,
469  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
470 auto Substitute(Obj&& obj, const std::unordered_map<IterVar, Expr>& iter_vmap) {
471  std::unordered_map<const VarNode*, PrimExpr> vmap;
472  for (const auto& [iter_var, expr] : iter_vmap) {
473  vmap[iter_var->var.get()] = expr;
474  }
475 
476  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
477  if (auto it = vmap.find(var.get()); it != vmap.end()) {
478  return it->second;
479  } else {
480  return std::nullopt;
481  }
482  };
483  return Substitute(std::forward<Obj>(obj), func);
484 }
485 
497  std::function<Optional<PrimExpr>(const Var&)> vmap);
498 
510  PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap);
511 
519 TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
520  const std::function<bool(const ObjectRef&)>& fvisit);
521 
529 TVM_DLL PrimFunc RenewDefs(const PrimFunc& func);
530 
541 template <typename Node, typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
542 bool ContainsNode(const Stmt& stmt) {
543  struct Visitor : StmtVisitor {
544  // Early bail-out, if we already found the node.
545  void VisitStmt(const Stmt& stmt) final {
546  if (contains_node) {
547  return;
548  }
549  StmtVisitor::VisitStmt(stmt);
550  }
551 
552  void VisitStmt_(const Node* block) override { contains_node = true; }
553 
554  bool contains_node{false};
555  };
556 
557  Visitor visitor;
558  visitor(stmt);
559  return visitor.contains_node;
560 }
561 
562 } // namespace tir
563 } // namespace tvm
564 
565 #endif // TVM_TIR_STMT_FUNCTOR_H_
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:65
Reference to PrimExprNode.
Definition: expr.h:129
Range container
Definition: expr.h:698
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
construct a new range with min and extent The corresponding constructor is removed,...
Allocate a buffer that can be used in body.
Definition: stmt.h:360
Allocate a buffer that can be used in body.
Definition: stmt.h:293
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:154
Define certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:115
A block is a basic schedule unit in TIR.
Definition: stmt.h:955
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:1021
Annotate the region where the buffer needs to be read and written in the body. We only need to allocate ...
Definition: stmt.h:248
Store value to the high dimension buffer.
Definition: stmt.h:200
Declare a buffer that can be used in the body.
Definition: stmt.h:435
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:492
ExprMutator that mutates expressions.
Definition: expr_functor.h:251
ExprVisitor.
Definition: expr_functor.h:206
A for loop, with possible type annotations.
Definition: stmt.h:746
IfThenElse statement.
Definition: stmt.h:674
Let binding, bind var to value, then run body.
Definition: stmt.h:72
Managed reference to PrimFuncNode.
Definition: function.h:131
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:465
Mutator that recursively mutates stmts and exprs on them.
Definition: stmt_functor.h:300
PrimExpr VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:309
Visitor that recursively visit stmts and exprs on them.
Definition: stmt_functor.h:285
void VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:294
virtual R VisitStmt_(const BufferStoreNode *op, Args... args)
Definition: stmt_functor.h:93
virtual ~StmtFunctor()
virtual destructor
Definition: stmt_functor.h:66
virtual R VisitStmt_(const AttrStmtNode *op, Args... args)
Definition: stmt_functor.h:86
virtual R VisitStmt_(const ForNode *op, Args... args)
Definition: stmt_functor.h:88
virtual R VisitStmt_(const AllocateNode *op, Args... args)
Definition: stmt_functor.h:90
R operator()(const Stmt &n, Args... args)
Same as call.
Definition: stmt_functor.h:73
virtual R VisitStmt_(const WhileNode *op, Args... args)
Definition: stmt_functor.h:89
virtual R VisitStmt_(const SeqStmtNode *op, Args... args)
Definition: stmt_functor.h:96
virtual R VisitStmt_(const EvaluateNode *op, Args... args)
Definition: stmt_functor.h:97
virtual R VisitStmt_(const IfThenElseNode *op, Args... args)
Definition: stmt_functor.h:87
R result_type
the result type of this functor
Definition: stmt_functor.h:64
virtual R VisitStmt_(const DeclBufferNode *op, Args... args)
Definition: stmt_functor.h:92
virtual R VisitStmt_(const LetStmtNode *op, Args... args)
Definition: stmt_functor.h:85
virtual R VisitStmt_(const BufferRealizeNode *op, Args... args)
Definition: stmt_functor.h:94
virtual R VisitStmt_(const AllocateConstNode *op, Args... args)
Definition: stmt_functor.h:91
virtual R VisitStmt_(const BlockNode *op, Args... args)
Definition: stmt_functor.h:98
virtual R VisitStmt_(const AssertStmtNode *op, Args... args)
Definition: stmt_functor.h:95
virtual R VisitStmtDefault_(const Object *op, Args...)
Definition: stmt_functor.h:100
virtual R VisitStmt_(const BlockRealizeNode *op, Args... args)
Definition: stmt_functor.h:99
virtual R VisitStmt(const Stmt &n, Args... args)
The functor call.
Definition: stmt_functor.h:80
Same as ExprFunctor except it is applied on statements.
Definition: stmt_functor.h:46
StmtMutator that mutates the statements.
Definition: stmt_functor.h:170
Stmt operator()(Stmt stmt)
Mutate stmt.
Definition: stmt_functor.h:180
Stmt VisitSeqStmt_(const SeqStmtNode *op, bool flatten_before_visit, std::function< Stmt(const Stmt &)> fmutate=nullptr)
Alternative advance method for SeqStmtNode.
Stmt VisitStmt_(const EvaluateNode *op) override
Stmt VisitStmt(const Stmt &stmt) override
Internal mutator that everyone calls.
Definition: stmt_functor.h:229
Stmt VisitStmt_(const LetStmtNode *op) override
Stmt VisitStmt_(const SeqStmtNode *op) override
Stmt VisitStmt_(const DeclBufferNode *op) override
Stmt VisitStmt_(const BlockRealizeNode *op) override
Stmt VisitStmt_(const IfThenElseNode *op) override
Stmt VisitStmt_(const BlockNode *op) override
ObjectPtr< TNode > CopyOnWrite(const TNode *node)
Perform copy on write on node.
Definition: stmt_functor.h:207
Stmt VisitStmt_(const WhileNode *op) override
Stmt VisitStmt_(const AssertStmtNode *op) override
Stmt VisitStmt_(const BufferRealizeNode *op) override
Stmt VisitStmt_(const AllocateNode *op) override
Stmt VisitStmt_(const BufferStoreNode *op) override
Stmt VisitStmt_(const AttrStmtNode *op) override
virtual PrimExpr VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:246
Stmt VisitStmt_(const AllocateConstNode *op) override
Stmt VisitStmt_(const ForNode *op) override
StmtVisitor.
Definition: stmt_functor.h:135
void VisitStmt_(const AttrStmtNode *op) override
void VisitStmt_(const IfThenElseNode *op) override
void VisitStmt_(const WhileNode *op) override
void VisitStmt_(const AllocateConstNode *op) override
void VisitStmt_(const AssertStmtNode *op) override
virtual void VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:148
void VisitStmt_(const AllocateNode *op) override
void VisitStmt_(const ForNode *op) override
void VisitStmt_(const SeqStmtNode *op) override
void VisitStmt_(const EvaluateNode *op) override
void VisitStmt_(const BlockNode *op) override
void VisitStmt_(const LetStmtNode *op) override
void VisitStmt_(const DeclBufferNode *op) override
void VisitStmt_(const BufferRealizeNode *op) override
void VisitStmt_(const BlockRealizeNode *op) override
void VisitStmt_(const BufferStoreNode *op) override
Container of all statements.
Definition: stmt.h:64
a named variable in TIR
Definition: var.h:78
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:124
A While loop.
Definition: stmt.h:813
Defines the Functor data structures.
tvm::relax::Function Function
Definition: transform.h:42
RelaxExpr Expr
Definition: expr.h:39
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
PrimFunc RenewDefs(const PrimFunc &func)
Renew the definition nodes for a TIR, including Var, Buffer and IterVar. This pass works as a simple ...
bool ContainsNode(const Stmt &stmt)
Check if the statement contains the specified node type.
Definition: stmt_functor.h:542
void PostOrderVisit(const ObjectRef &node, std::function< void(const ObjectRef &)> fvisit)
Recursively visit the IR in post-DFS order and apply fvisit to each node. Each node is guaranteed to be visited o...
Stmt SubstituteWithDataTypeLegalization(Stmt stmt, std::function< Optional< PrimExpr >(const Var &)> vmap)
Substitute the var specified by vmap and legalize data types after substitution.
Stmt IRTransform(Stmt stmt, const ffi::Function &preorder, const ffi::Function &postorder, Optional< Array< String >> only_enable=std::nullopt)
recursively visit the ir nodes in post DFS order, and transform it
void PreOrderVisit(const ObjectRef &stmt_or_expr, const std::function< bool(const ObjectRef &)> &fvisit)
Recursively visit the IR in pre DFS order node, apply fvisit. If fvisit returns false,...
IndexMap Substitute(const IndexMap &index_map, std::function< Optional< PrimExpr >(const Var &var)> f_subst)
Substitute variables in an index map.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
TIR statements.
#define IR_STMT_FUNCTOR_DISPATCH(OP)
Definition: stmt_functor.h:51
#define STMT_FUNCTOR_DEFAULT
Definition: stmt_functor.h:48
TIR expressions.
Functors for tir expressions.
TIR Function.