tvm
stmt_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef TVM_TIR_STMT_FUNCTOR_H_
27 #define TVM_TIR_STMT_FUNCTOR_H_
28 
29 #include <tvm/node/functor.h>
30 #include <tvm/tirx/expr.h>
31 #include <tvm/tirx/expr_functor.h>
32 #include <tvm/tirx/function.h>
33 #include <tvm/tirx/stmt.h>
34 
35 #include <unordered_map>
36 #include <utility>
37 
38 namespace tvm {
39 namespace tirx {
45 template <typename FType>
47 
// STMT_FUNCTOR_DEFAULT expands to a complete function body: it forwards the
// visited node `op` and any extra arguments to VisitStmtDefault_. It is used as
// the body of every default VisitStmt_ overload in StmtFunctor and is #undef'd
// right after the StmtFunctor definition.
48 #define STMT_FUNCTOR_DEFAULT \
49  { \
50  return VisitStmtDefault_(op, std::forward<Args>(args)...); \
51  }
52 
// IR_STMT_FUNCTOR_DISPATCH(OP) registers, via `vtable.set_dispatch<OP>`, a
// handler for node type OP. The handler static_casts the generic ObjectRef to
// `const OP*` and calls the matching `self->VisitStmt_` overload, forwarding the
// extra arguments. It refers to the names `vtable`, `TSelf` and `Args`, so it is
// only usable where those are in scope. Presumably that is StmtFunctor::InitVTable;
// its expansion sites are not shown in this listing. It is #undef'd right after
// the StmtFunctor definition.
53 #define IR_STMT_FUNCTOR_DISPATCH(OP) \
54  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
55  return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
56  });
57 
58 template <typename R, typename... Args>
59 class StmtFunctor<R(const Stmt& n, Args... args)> {
60  private:
61  using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
62  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;
63 
64  public:
66  using result_type = R;
/*! \brief Virtual destructor so that subclasses can be destroyed through a base pointer. */
68  virtual ~StmtFunctor() {}
/*!
 * \brief Same as VisitStmt.
 * \param n The statement to visit.
 * \param args Additional arguments forwarded to the visit function.
 * \return The result of VisitStmt(n, args...).
 */
75  R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward<Args>(args)...); }
/*!
 * \brief The functor call: dispatch n through the vtable on its runtime node type.
 *
 * The dispatch table is a function-local static built once by InitVTable().
 * There is one table per StmtFunctor instantiation, shared by all instances.
 * Its entries (registered with IR_STMT_FUNCTOR_DISPATCH) call the VisitStmt_
 * overload matching the node type.
 * \param n The statement to visit.
 * \param args Additional arguments forwarded to the visit function.
 * \return The result of the selected VisitStmt_ overload.
 */
82  virtual R VisitStmt(const Stmt& n, Args... args) {
83  static FType vtable = InitVTable();
84  return vtable(n, this, std::forward<Args>(args)...);
85  }
86  // Functions that can be overridden by subclasses.
// Each default implementation (STMT_FUNCTOR_DEFAULT) forwards to VisitStmtDefault_.
87  virtual R VisitStmt_(const BindNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
88  virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
89  virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
90  virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
91  virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
92  virtual R VisitStmt_(const AllocBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
93  virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
94  virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
95  virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
96  virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
97  virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
98  virtual R VisitStmt_(const SBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
99  virtual R VisitStmt_(const SBlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
/*!
 * \brief Fallback reached by every VisitStmt_ overload that is not overridden.
 * \param op The visited node.
 * \return Never returns normally: throws InternalError naming op's type key.
 */
100  virtual R VisitStmtDefault_(const Object* op, Args...) {
101  TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey();
102  TVM_FFI_UNREACHABLE();
103  }
104 
105  private:
106  // initialize the vtable.
107  static FType InitVTable() {
108  FType vtable;
122  vtable.Finalize();
123  return vtable;
124  }
125 };
126 
127 #undef IR_STMT_FUNCTOR_DISPATCH
128 #undef STMT_FUNCTOR_DEFAULT
129 
/*!
 * \brief Visitor over statements that produces no result.
 *
 * Derives protectedly from StmtFunctor<void(const Stmt&)> and re-exports only
 * operator(), so callers start a traversal with `visitor(stmt)`. The
 * VisitStmt_ overrides are only declared here; their definitions are not in
 * this header.
 */
133 class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
134  public:
135  using StmtFunctor::operator();
136 
137  protected:
138  using StmtFunctor::VisitStmt;
/*!
 * \brief Visitor to Exprs; override it to traverse expressions reached from statements.
 * \note The default implementation does nothing. StmtExprVisitor overrides it
 *       to forward to ExprVisitor::VisitExpr.
 * \param e The expression.
 */
146  virtual void VisitExpr(const PrimExpr& e) {}
/*!
 * \brief Visit a buffer at its definition site (AllocBuffer, DeclBuffer, SBlock alloc_buffers).
 * \param buffer The buffer being defined.
 * \param alloc_data NOTE(review): presumably whether this definition site also
 *        allocates the buffer's data. TODO confirm against the definition.
 */
154  virtual void VisitBufferDef(const Buffer& buffer, bool alloc_data);
/*!
 * \brief Visit a buffer at a use site (BufferStore, BufferLoad, SBlock reads/writes).
 * \param buffer The buffer being used.
 */
160  virtual void VisitBufferUse(const Buffer& buffer);
161  // statement visitor
162  void VisitStmt_(const BindNode* op) override;
163  void VisitStmt_(const AttrStmtNode* op) override;
164  void VisitStmt_(const IfThenElseNode* op) override;
165  void VisitStmt_(const ForNode* op) override;
166  void VisitStmt_(const WhileNode* op) override;
167  void VisitStmt_(const AllocBufferNode* op) override;
168  void VisitStmt_(const DeclBufferNode* op) override;
169  void VisitStmt_(const BufferStoreNode* op) override;
170  void VisitStmt_(const AssertStmtNode* op) override;
171  void VisitStmt_(const SeqStmtNode* op) override;
172  void VisitStmt_(const EvaluateNode* op) override;
173  void VisitStmt_(const SBlockNode* op) override;
174  void VisitStmt_(const SBlockRealizeNode* op) override;
175 };
176 
180 class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
181  public:
191  allow_copy_on_write_ = true;
192  return VisitStmt(stmt);
193  }
194 
195  protected:
/*! \brief Map from old buffer to new buffer, populated by VisitBufferDef. */
197  ffi::Map<Buffer, Buffer> buffer_remap_;
198  // We perform copy-on-write optimizations in the StmtMutator
199  // so that a unique copy of the parent can be mutated in place
200  // when some of its children change.
201  // We only do such optimization for Stmt nests (instead of Exprs) for now
202  // as a Stmt's parent state is more likely to remain unchanged when one of
203  // its child blocks changes.
/*!
 * \brief Whether the node currently being mutated may be modified in place.
 *
 * operator() sets it to true before visiting the root. VisitStmt clears it
 * while a statement that is not uniquely referenced is visited. CopyOnWrite
 * reads it to choose between reusing the node and copying it.
 */
208  bool allow_copy_on_write_{false};
218  template <typename TNode>
219  ObjectPtr<TNode> CopyOnWrite(const TNode* node) {
220  static_assert(std::is_base_of<StmtNode, TNode>::value,
221  "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent "
222  "nodes during the recursion. Because the child classes do not necessarily "
223  "check the Array, Expr and other structures during the visit, it is only safe to "
224  "call this function with StmtNodes for now. "
225  "Please create a new node directly in other cases.");
226  if (allow_copy_on_write_) {
227  // return the old node.
228  return runtime::GetObjectPtr<TNode>(const_cast<TNode*>(node));
229  } else {
230  // Make a new copy of the node.
231  // need to rely on the default copy constructor
232  return ffi::make_object<TNode>(*node);
233  }
234  }
241  Stmt VisitStmt(const Stmt& stmt) override {
242  if (allow_copy_on_write_ && !stmt.unique()) {
243  allow_copy_on_write_ = false;
244  Stmt ret = StmtFunctor::VisitStmt(stmt);
245  allow_copy_on_write_ = true;
246  return ret;
247  } else {
248  return StmtFunctor::VisitStmt(stmt);
249  }
250  }
/*!
 * \brief Visitor to Exprs; can be overridden to do recursive changes to Exprs.
 * \note The default implementation returns e unchanged. StmtExprMutator
 *       overrides it to forward to ExprMutator::VisitExpr.
 * \param e The expression.
 * \return The (possibly new) expression.
 */
258  virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; }
/*!
 * \brief Visit a buffer at its definition site.
 *
 * Per the API docs, this visits shape/strides/elem_offset via VisitExpr.
 * buffer_remap_ is documented as "populated by VisitBufferDef".
 * NOTE(review): presumably a new buffer is created and recorded there when any
 * field changes. Confirm in the out-of-line definition.
 * \param buffer The buffer being defined.
 * \param alloc_data NOTE(review): presumably whether this definition site also
 *        allocates the buffer's data. TODO confirm.
 * \return The buffer to use at this definition site.
 */
267  virtual Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data);
/*!
 * \brief Visit a buffer at a use site (BufferStore, BufferLoad, SBlock reads/writes).
 * NOTE(review): presumably consults buffer_remap_. The default behaviour is
 *       defined out of line; confirm there.
 * \param buffer The buffer being used.
 * \return The buffer to use at this site.
 */
274  virtual Buffer VisitBufferUse(const Buffer& buffer);
275  // statement visitor
276  Stmt VisitStmt_(const BindNode* op) override;
277  Stmt VisitStmt_(const AttrStmtNode* op) override;
278  Stmt VisitStmt_(const IfThenElseNode* op) override;
279  Stmt VisitStmt_(const ForNode* op) override;
280  Stmt VisitStmt_(const WhileNode* op) override;
281  Stmt VisitStmt_(const AllocBufferNode* op) override;
282  Stmt VisitStmt_(const DeclBufferNode* op) override;
283  Stmt VisitStmt_(const BufferStoreNode* op) override;
284  Stmt VisitStmt_(const AssertStmtNode* op) override;
285  Stmt VisitStmt_(const SeqStmtNode* op) override;
286  Stmt VisitStmt_(const EvaluateNode* op) override;
287  Stmt VisitStmt_(const SBlockNode* op) override;
288  Stmt VisitStmt_(const SBlockRealizeNode* op) override;
/*!
 * \brief Alternative advance method for SeqStmtNode.
 * \param op The sequence statement to visit.
 * \param flatten_before_visit NOTE(review): presumably flattens nested SeqStmt
 *        children before visiting them. TODO confirm.
 * \param fmutate Optional per-child mutation function (nullptr by default).
 *        NOTE(review): presumably used instead of VisitStmt when non-null. TODO confirm.
 * \return The mutated statement.
 */
301  Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
302  std::function<Stmt(const Stmt&)> fmutate = nullptr);
303 
304  // internal helper.
305  class Internal;
306 };
307 
/*!
 * \brief Visitor that recursively visits statements and the expressions in them.
 *
 * Combines ExprVisitor and StmtVisitor and exposes both operator() overloads.
 * StmtVisitor::VisitExpr is a no-op by default. Here it is overridden to
 * forward to ExprVisitor::VisitExpr, so expressions handed to VisitExpr by the
 * statement visitors are traversed as well.
 */
311 class StmtExprVisitor : public ExprVisitor, public StmtVisitor {
312  public:
313  using StmtVisitor::operator();
314  using ExprVisitor::operator();
315 
316  protected:
317  using ExprVisitor::VisitExpr;
319  using StmtVisitor::VisitStmt;
320 
/*! \brief Route expressions reached from statements into the ExprVisitor dispatch. */
321  void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); }
/*! \brief Visit a BufferLoad expression (defined out of line). */
322  void VisitExpr_(const BufferLoadNode* op) override;
323 };
324 
/*!
 * \brief Mutator that recursively mutates statements and the expressions in them.
 *
 * Combines ExprMutator and StmtMutator and exposes both operator() overloads.
 * StmtMutator::VisitExpr returns its argument unchanged by default. Here it is
 * overridden to forward to ExprMutator::VisitExpr, so expressions handed to
 * VisitExpr by the statement mutators are rewritten as well.
 */
328 class StmtExprMutator : public ExprMutator, public StmtMutator {
329  public:
330  using StmtMutator::operator();
331  using ExprMutator::operator();
332 
333  protected:
334  using ExprMutator::VisitExpr;
337 
/*! \brief Route expressions reached from statements into the ExprMutator dispatch. */
338  PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); }
/*! \brief Mutate a BufferLoad expression (defined out of line). */
339  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
340 };
341 
/*!
 * \brief Recursively visit the IR nodes in post-DFS order and transform them.
 * \param stmt The statement to transform.
 * \param preorder NOTE(review): presumably invoked on a node before its
 *        children are visited. TODO confirm.
 * \param postorder NOTE(review): presumably invoked on a node after its
 *        children are visited. TODO confirm.
 * \param only_enable NOTE(review): presumably restricts the callbacks to nodes
 *        whose type keys are listed; std::nullopt (the default) enables all. TODO confirm.
 * \return The transformed statement.
 */
357 TVM_DLL Stmt IRTransform(Stmt stmt, const ffi::Function& preorder, const ffi::Function& postorder,
358  ffi::Optional<ffi::Array<ffi::String>> only_enable = std::nullopt);
359 
/*!
 * \brief Recursively visit the IR in post-DFS order, applying fvisit to each node.
 * NOTE(review): the (truncated) API docs say each node is guaranteed to be
 *       visited "o…", presumably "only once". Confirm.
 * \param node The root IR node to visit.
 * \param fvisit The callback applied to each visited node.
 */
366 TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
367 
/*!
 * \brief Substitute variables in a statement.
 * \param stmt The statement to substitute into.
 * \param vmap Maps a Var to an optional replacement expression.
 *        NOTE(review): std::nullopt presumably leaves that Var unchanged. TODO confirm.
 * \return The statement after substitution.
 */
374 TVM_DLL Stmt Substitute(Stmt stmt, std::function<ffi::Optional<PrimExpr>(const Var& var)> vmap);
375 
383  std::function<ffi::Optional<PrimExpr>(const Var& var)> vmap);
384 
391 template <typename T>
392 ffi::Array<T> Substitute(const ffi::Array<T>& arr,
393  std::function<ffi::Optional<PrimExpr>(const Var& var)> vmap) {
394  return arr.Map([&vmap](const auto& elem) { return Substitute(elem, vmap); });
395 }
396 
403 inline Range Substitute(const Range& range,
404  std::function<ffi::Optional<PrimExpr>(const Var& var)> vmap) {
405  return Range::FromMinExtent(Substitute(range->min, vmap), Substitute(range->extent, vmap));
406 }
407 
419 template <typename Obj>
420 auto Substitute(Obj&& obj, const ffi::Map<Var, PrimExpr>& vmap) {
421  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> { return vmap.Get(var); };
422  return Substitute(std::forward<Obj>(obj), func);
423 }
424 
434 template <typename Obj, typename Expr,
435  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
436 auto Substitute(Obj&& obj, const ffi::Map<Var, Expr>& vmap) {
437  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> {
438  if (auto opt = vmap.Get(var)) {
439  return opt.value();
440  } else {
441  return std::nullopt;
442  }
443  };
444  return Substitute(std::forward<Obj>(obj), func);
445 }
446 
456 template <typename Obj, typename Expr,
457  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
458 auto Substitute(Obj&& obj, const std::unordered_map<const VarNode*, Expr>& vmap) {
459  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> {
460  if (auto it = vmap.find(var.get()); it != vmap.end()) {
461  return it->second;
462  } else {
463  return std::nullopt;
464  }
465  };
466  return Substitute(std::forward<Obj>(obj), func);
467 }
468 
478 template <typename Obj, typename Expr, typename Hasher, typename EqualityChecker,
479  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
480 auto Substitute(Obj&& obj, const std::unordered_map<Var, Expr, Hasher, EqualityChecker>& vmap) {
481  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> {
482  if (auto it = vmap.find(var); it != vmap.end()) {
483  return it->second;
484  } else {
485  return std::nullopt;
486  }
487  };
488  return Substitute(std::forward<Obj>(obj), func);
489 }
490 
500 template <typename Obj, typename Expr,
501  typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
502 auto Substitute(Obj&& obj, const std::unordered_map<IterVar, Expr>& iter_vmap) {
503  std::unordered_map<const VarNode*, PrimExpr> vmap;
504  for (const auto& [iter_var, expr] : iter_vmap) {
505  vmap[iter_var->var.get()] = expr;
506  }
507 
508  auto func = [&vmap](const Var& var) -> ffi::Optional<PrimExpr> {
509  if (auto it = vmap.find(var.get()); it != vmap.end()) {
510  return it->second;
511  } else {
512  return std::nullopt;
513  }
514  };
515  return Substitute(std::forward<Obj>(obj), func);
516 }
517 
529  Stmt stmt, std::function<ffi::Optional<PrimExpr>(const Var&)> vmap);
530 
542  PrimExpr expr, std::function<ffi::Optional<PrimExpr>(const Var&)> vmap);
543 
/*!
 * \brief Recursively visit the IR in pre-DFS order, applying fvisit to each node.
 * \param stmt_or_expr The root statement or expression.
 * \param fvisit Callback applied to each node. NOTE(review): per the (truncated)
 *        API docs, behaviour changes when fvisit returns false. Presumably the
 *        node's children are then skipped. Confirm.
 */
551 TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
552  const std::function<bool(const ObjectRef&)>& fvisit);
553 
564 template <typename Node, typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
565 bool ContainsNode(const Stmt& stmt) {
566  struct Visitor : StmtVisitor {
567  // Early bail-out, if we already found the node.
568  void VisitStmt(const Stmt& stmt) final {
569  if (contains_node) {
570  return;
571  }
572  StmtVisitor::VisitStmt(stmt);
573  }
574 
575  void VisitStmt_(const Node* block) override { contains_node = true; }
576 
577  bool contains_node{false};
578  };
579 
580  Visitor visitor;
581  visitor(stmt);
582  return visitor.contains_node;
583 }
584 
585 } // namespace tirx
586 } // namespace tvm
587 
588 #endif // TVM_TIR_STMT_FUNCTOR_H_
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:65
Reference to PrimExprNode.
Definition: expr.h:126
Range container
Definition: expr.h:690
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
Construct a new range with min and extent. The corresponding constructor is removed,...
Allocate a buffer and declare it in scope.
Definition: stmt.h:259
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:159
Define certain auxiliary attributes for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:115
Bind a variable to a value in the enclosing scope.
Definition: stmt.h:77
Load value from the high dimension buffer.
Definition: expr.h:532
Store value to the high dimension buffer.
Definition: stmt.h:201
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
Declare a buffer that can be used in the body.
Definition: stmt.h:238
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:336
ExprMutator that mutates expressions.
Definition: expr_functor.h:253
PrimExpr VisitExpr_(const VarNode *op) override
ExprVisitor.
Definition: expr_functor.h:208
void VisitExpr_(const VarNode *op) override
A for loop, with possible type annotations.
Definition: stmt.h:586
IfThenElse statement.
Definition: stmt.h:516
A block is a basic schedule unit in TIR.
Definition: stmt.h:799
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:864
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:311
Mutator that recursively mutates stmts and exprs on them.
Definition: stmt_functor.h:328
PrimExpr VisitExpr(const PrimExpr &e) override
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:338
PrimExpr VisitExpr_(const BufferLoadNode *op) override
Visitor that recursively visit stmts and exprs on them.
Definition: stmt_functor.h:311
void VisitExpr_(const BufferLoadNode *op) override
void VisitExpr(const PrimExpr &e) override
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:321
virtual R VisitStmt_(const BufferStoreNode *op, Args... args)
Definition: stmt_functor.h:94
virtual ~StmtFunctor()
virtual destructor
Definition: stmt_functor.h:68
R operator()(const Stmt &n, Args... args)
Same as call.
Definition: stmt_functor.h:75
virtual R VisitStmt_(const IfThenElseNode *op, Args... args)
Definition: stmt_functor.h:89
virtual R VisitStmt_(const SBlockRealizeNode *op, Args... args)
Definition: stmt_functor.h:99
virtual R VisitStmt_(const ForNode *op, Args... args)
Definition: stmt_functor.h:90
virtual R VisitStmt_(const SeqStmtNode *op, Args... args)
Definition: stmt_functor.h:96
virtual R VisitStmt_(const SBlockNode *op, Args... args)
Definition: stmt_functor.h:98
virtual R VisitStmtDefault_(const Object *op, Args...)
Definition: stmt_functor.h:100
virtual R VisitStmt(const Stmt &n, Args... args)
The functor call.
Definition: stmt_functor.h:82
virtual R VisitStmt_(const AttrStmtNode *op, Args... args)
Definition: stmt_functor.h:88
virtual R VisitStmt_(const EvaluateNode *op, Args... args)
Definition: stmt_functor.h:97
virtual R VisitStmt_(const DeclBufferNode *op, Args... args)
Definition: stmt_functor.h:93
virtual R VisitStmt_(const AssertStmtNode *op, Args... args)
Definition: stmt_functor.h:95
R result_type
the result type of this functor
Definition: stmt_functor.h:66
virtual R VisitStmt_(const AllocBufferNode *op, Args... args)
Definition: stmt_functor.h:92
virtual R VisitStmt_(const WhileNode *op, Args... args)
Definition: stmt_functor.h:91
virtual R VisitStmt_(const BindNode *op, Args... args)
Definition: stmt_functor.h:87
Same as ExprFunctor except it is applied on statements.
Definition: stmt_functor.h:46
StmtMutator that mutates the statements.
Definition: stmt_functor.h:180
Stmt operator()(Stmt stmt)
Mutate stmt.
Definition: stmt_functor.h:190
Stmt VisitStmt_(const EvaluateNode *op) override
Stmt VisitStmt_(const SBlockNode *op) override
Stmt VisitStmt_(const WhileNode *op) override
Stmt VisitStmt_(const SeqStmtNode *op) override
Stmt VisitStmt_(const SBlockRealizeNode *op) override
Stmt VisitStmt_(const IfThenElseNode *op) override
Stmt VisitStmt_(const AttrStmtNode *op) override
virtual Buffer VisitBufferDef(const Buffer &buffer, bool alloc_data)
Visit buffer at definition site. Visits shape/strides/elem_offset via VisitExpr. If any field changes...
ObjectPtr< TNode > CopyOnWrite(const TNode *node)
Perform copy on write on node.
Definition: stmt_functor.h:219
ffi::Map< Buffer, Buffer > buffer_remap_
Map from old buffer to new buffer, populated by VisitBufferDef.
Definition: stmt_functor.h:197
Stmt VisitStmt(const Stmt &stmt) override
Internal mutator that everyone calls.
Definition: stmt_functor.h:241
Stmt VisitStmt_(const BindNode *op) override
Stmt VisitStmt_(const ForNode *op) override
Stmt VisitSeqStmt_(const SeqStmtNode *op, bool flatten_before_visit, std::function< Stmt(const Stmt &)> fmutate=nullptr)
Alternative advance method for SeqStmtNode.
virtual Buffer VisitBufferUse(const Buffer &buffer)
Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes). By default,...
Stmt VisitStmt_(const AllocBufferNode *op) override
Stmt VisitStmt_(const AssertStmtNode *op) override
virtual PrimExpr VisitExpr(const PrimExpr &e)
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:258
Stmt VisitStmt_(const BufferStoreNode *op) override
Stmt VisitStmt_(const DeclBufferNode *op) override
StmtVisitor.
Definition: stmt_functor.h:133
void VisitStmt_(const ForNode *op) override
void VisitStmt_(const BindNode *op) override
void VisitStmt_(const SBlockNode *op) override
void VisitStmt_(const AttrStmtNode *op) override
void VisitStmt_(const BufferStoreNode *op) override
void VisitStmt_(const SBlockRealizeNode *op) override
void VisitStmt_(const SeqStmtNode *op) override
void VisitStmt_(const AssertStmtNode *op) override
virtual void VisitBufferUse(const Buffer &buffer)
Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes). By default,...
void VisitStmt_(const WhileNode *op) override
void VisitStmt_(const EvaluateNode *op) override
virtual void VisitExpr(const PrimExpr &e)
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:146
void VisitStmt_(const IfThenElseNode *op) override
virtual void VisitBufferDef(const Buffer &buffer, bool alloc_data)
Visit buffer at definition site (AllocBuffer, DeclBuffer, SBlock alloc_buffers). Visits buffer shape,...
void VisitStmt_(const AllocBufferNode *op) override
void VisitStmt_(const DeclBufferNode *op) override
Container of all statements.
Definition: stmt.h:65
a named variable in TIR
Definition: var.h:76
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:123
A While loop.
Definition: stmt.h:661
Defines the Functor data structures.
tvm::relax::Function Function
Definition: transform.h:38
RelaxExpr Expr
Definition: expr.h:39
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
IndexMap Substitute(const IndexMap &index_map, std::function< ffi::Optional< PrimExpr >(const Var &var)> f_subst)
Substitute variables in an index map.
Stmt IRTransform(Stmt stmt, const ffi::Function &preorder, const ffi::Function &postorder, ffi::Optional< ffi::Array< ffi::String >> only_enable=std::nullopt)
Recursively visit the IR nodes in post-DFS order and transform them.
bool ContainsNode(const Stmt &stmt)
Check if the statement contains the specified node type.
Definition: stmt_functor.h:565
void PreOrderVisit(const ObjectRef &stmt_or_expr, const std::function< bool(const ObjectRef &)> &fvisit)
Recursively visit the IR in pre-DFS order, applying fvisit to each node. If fvisit returns false,...
Stmt SubstituteWithDataTypeLegalization(Stmt stmt, std::function< ffi::Optional< PrimExpr >(const Var &)> vmap)
Substitute the var specified by vmap and legalize data types after substitution.
void PostOrderVisit(const ObjectRef &node, std::function< void(const ObjectRef &)> fvisit)
Recursively visit the IR in post-DFS order, applying fvisit to each node. Each node is guaranteed to be visited o...
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
#define IR_STMT_FUNCTOR_DISPATCH(OP)
Definition: stmt_functor.h:53
#define STMT_FUNCTOR_DEFAULT
Definition: stmt_functor.h:48
TIR expressions.
Functors for tirx expressions.
TIR Function.
TIR statements.