tvm
stmt_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef TVM_TIR_STMT_FUNCTOR_H_
27 #define TVM_TIR_STMT_FUNCTOR_H_
28 
29 #include <tvm/node/functor.h>
30 #include <tvm/tir/expr.h>
31 #include <tvm/tir/expr_functor.h>
32 #include <tvm/tir/function.h>
33 #include <tvm/tir/stmt.h>
34 
35 #include <unordered_map>
36 #include <utility>
37 
38 namespace tvm {
39 namespace tir {
45 template <typename FType>
47 
// Default body appended to each VisitStmt_ overload declaration in StmtFunctor:
// forwards the node pointer `op` and the extra arguments to VisitStmtDefault_.
// Relies on `op`, `Args` and `args` being in scope at the expansion site.
48 #define STMT_FUNCTOR_DEFAULT \
49  { return VisitStmtDefault_(op, std::forward<Args>(args)...); }
50 
// Registers a dispatch entry for node type OP on a local named `vtable`:
// the entry is a lambda that casts the ObjectRef's raw pointer to `const OP*`
// and calls self->VisitStmt_ with the forwarded extra arguments.
// Relies on `vtable`, `TSelf` and `Args` being in scope at the expansion site.
51 #define IR_STMT_FUNCTOR_DISPATCH(OP) \
52  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
53  return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
54  });
55 
/*!
 * \brief Same as ExprFunctor except it is applied on statements.
 *
 * A call goes through VisitStmt, which looks the node up in a NodeFunctor
 * table (FType) and passes `this` plus the extra arguments along. Every
 * VisitStmt_ overload declared here defaults to VisitStmtDefault_.
 *
 * \tparam R The result type of the functor.
 * \tparam Args Extra arguments forwarded to every visit function.
 */
56 template <typename R, typename... Args>
57 class StmtFunctor<R(const Stmt& n, Args... args)> {
58  private:
  // Shorthand for this specialization and for the dispatch-table type.
59  using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
60  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;
61 
62  public:
  /*! \brief the result type of this functor */
64  using result_type = R;
  /*! \brief virtual destructor */
66  virtual ~StmtFunctor() {}
  /*! \brief Same as call: forwards to VisitStmt. */
73  R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward<Args>(args)...); }
  /*!
   * \brief The functor call.
   *
   * The table is a function-local static, so InitVTable() runs once per
   * specialization; each call then invokes the table with the node, `this`
   * and the forwarded arguments.
   */
80  virtual R VisitStmt(const Stmt& n, Args... args) {
81  static FType vtable = InitVTable();
82  return vtable(n, this, std::forward<Args>(args)...);
83  }
84  // Functions that can be overridden by subclass
85  virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
86  virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
87  virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
88  virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
89  virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
90  virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
91  virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
92  virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
93  virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
94  virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
95  virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
96  virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
97  virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
98  virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
99  virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
100  virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
101  virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
102  virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
103  virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  /*!
   * \brief Fallback reached by every VisitStmt_ overload that is not overridden.
   *
   * Emits a LOG(FATAL) message containing the node's type key.
   * NOTE(review): the trailing `return R();` presumably only satisfies the
   * signature, assuming LOG(FATAL) does not return -- confirm.
   */
104  virtual R VisitStmtDefault_(const Object* op, Args...) {
105  LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
106  return R();
107  }
108 
109  private:
110  // initialize the vtable.
  // NOTE(review): this listing jumps from line 112 to line 132, so the
  // per-node registrations (presumably IR_STMT_FUNCTOR_DISPATCH(...) calls)
  // are not visible here -- check the original header.
111  static FType InitVTable() {
112  FType vtable;
132  return vtable;
133  }
134 };
135 
136 #undef IR_STMT_FUNCTOR_DISPATCH
137 #undef STMT_FUNCTOR_DEFAULT
138 
/*!
 * \brief StmtVisitor.
 *
 * Specializes StmtFunctor with a void result and no extra arguments, and
 * overrides the VisitStmt_ overload of every statement node listed below.
 * The overrides are only declared here; their bodies live elsewhere.
 * The base is inherited as protected; only operator() is re-exported publicly.
 */
142 class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
143  public:
144  using StmtFunctor::operator();
145 
146  protected:
147  using StmtFunctor::VisitStmt;
  /*!
   * \brief Visitor to Exprs, can be overridden to do recursive changes to Exprs.
   * \param e The expression to be visited.
   * \note The default implementation has an empty body and does nothing.
   */
155  virtual void VisitExpr(const PrimExpr& e) {}
156  // statement visitor
157  void VisitStmt_(const AttrStmtNode* op) override;
158  void VisitStmt_(const IfThenElseNode* op) override;
159  void VisitStmt_(const LetStmtNode* op) override;
160  void VisitStmt_(const ForNode* op) override;
161  void VisitStmt_(const WhileNode* op) override;
162  void VisitStmt_(const AllocateNode* op) override;
163  void VisitStmt_(const AllocateConstNode* op) override;
164  void VisitStmt_(const DeclBufferNode* op) override;
165  void VisitStmt_(const StoreNode* op) override;
166  void VisitStmt_(const BufferStoreNode* op) override;
167  void VisitStmt_(const BufferRealizeNode* op) override;
168  void VisitStmt_(const AssertStmtNode* op) override;
169  void VisitStmt_(const ProducerStoreNode* op) override;
170  void VisitStmt_(const ProducerRealizeNode* op) override;
171  void VisitStmt_(const PrefetchNode* op) override;
172  void VisitStmt_(const SeqStmtNode* op) override;
173  void VisitStmt_(const EvaluateNode* op) override;
174  void VisitStmt_(const BlockNode* op) override;
175  void VisitStmt_(const BlockRealizeNode* op) override;
176 };
177 
/*!
 * \brief StmtMutator that mutates the statements.
 *
 * Specializes StmtFunctor with a Stmt result and no extra arguments. The
 * VisitStmt_ overrides are only declared here; their bodies live elsewhere.
 */
181 class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
182  public:
  // Mutate stmt: enables copy-on-write, then runs VisitStmt on the statement.
  // NOTE(review): the opening line of this operator() (stmt_functor.h:191,
  // indexed elsewhere in this page as `Stmt operator()(Stmt stmt)`) is absent
  // from this listing; only its body is shown below.
192  allow_copy_on_write_ = true;
193  return VisitStmt(stmt);
194  }
195 
196  protected:
197  // We perform copy on write optimizations on the StmtMutator
198  // so that a unique copy of parent can be mutated in place
199  // when some of its children changed.
200  // We only do such optimization for Stmt nests (instead of Exprs) for now
201  // as Stmt's parent state is more likely to remain unchanged when one of
202  // its child blocks changes.
  /*!
   * \brief Internal state: when true, CopyOnWrite returns the node it was
   *  given instead of a fresh copy. Starts out false.
   */
207  bool allow_copy_on_write_{false};
  /*!
   * \brief Perform copy on write on node.
   *
   * If allow_copy_on_write_ is set, returns an ObjectPtr to the very same
   * node (constness is cast away); otherwise returns a new object made with
   * the node's copy constructor.
   *
   * \param node The node to obtain a writable pointer for.
   * \tparam TNode The node type; must derive from StmtNode (static_assert).
   * \return The ObjectPtr described above.
   */
217  template <typename TNode>
218  ObjectPtr<TNode> CopyOnWrite(const TNode* node) {
219  static_assert(std::is_base_of<StmtNode, TNode>::value,
220  "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent "
221  "nodes during the recursion. Because the child classes do not necessarily "
222  "check the Array, Expr and other structures during the visit, it is only safe to "
223  "call this function with StmtNodes for now. "
224  "Please create a new node directly in other cases.");
225  if (allow_copy_on_write_) {
226  // return the old node.
227  return runtime::GetObjectPtr<TNode>(const_cast<TNode*>(node));
228  } else {
229  // Make a new copy of the node.
230  // need to rely on the default copy constructor
231  return runtime::make_object<TNode>(*node);
232  }
233  }
  /*!
   * \brief Internal mutator that everyone calls.
   *
   * When copy-on-write is allowed but `stmt` is not unique, the flag is
   * cleared for the nested visit and set back to true afterwards; otherwise
   * the call goes straight to StmtFunctor::VisitStmt.
   *
   * NOTE(review): the flag is restored by a plain assignment after the nested
   * call, so it would stay false if that call exits by exception -- confirm
   * whether that matters for callers.
   *
   * \param stmt The input statement.
   * \return The result of the nested StmtFunctor::VisitStmt call.
   */
240  Stmt VisitStmt(const Stmt& stmt) override {
241  if (allow_copy_on_write_ && !stmt.unique()) {
242  allow_copy_on_write_ = false;
243  Stmt ret = StmtFunctor::VisitStmt(stmt);
244  allow_copy_on_write_ = true;
245  return ret;
246  } else {
247  return StmtFunctor::VisitStmt(stmt);
248  }
249  }
  /*!
   * \brief Visitor to Exprs, can be overridden to do recursive changes to Exprs.
   * \param e The expression to be visited.
   * \return The default implementation returns `e` unchanged.
   */
257  virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; }
258  // statement visitor
259  Stmt VisitStmt_(const AttrStmtNode* op) override;
260  Stmt VisitStmt_(const IfThenElseNode* op) override;
261  Stmt VisitStmt_(const LetStmtNode* op) override;
262  Stmt VisitStmt_(const ForNode* op) override;
263  Stmt VisitStmt_(const WhileNode* op) override;
264  Stmt VisitStmt_(const AllocateNode* op) override;
265  Stmt VisitStmt_(const AllocateConstNode* op) override;
266  Stmt VisitStmt_(const DeclBufferNode* op) override;
267  Stmt VisitStmt_(const StoreNode* op) override;
268  Stmt VisitStmt_(const BufferStoreNode* op) override;
269  Stmt VisitStmt_(const BufferRealizeNode* op) override;
270  Stmt VisitStmt_(const AssertStmtNode* op) override;
271  Stmt VisitStmt_(const ProducerStoreNode* op) override;
272  Stmt VisitStmt_(const ProducerRealizeNode* op) override;
273  Stmt VisitStmt_(const PrefetchNode* op) override;
274  Stmt VisitStmt_(const SeqStmtNode* op) override;
275  Stmt VisitStmt_(const EvaluateNode* op) override;
276  Stmt VisitStmt_(const BlockNode* op) override;
277  Stmt VisitStmt_(const BlockRealizeNode* op) override;
  // Alternative entry point for SeqStmt nodes; declared here, defined elsewhere.
  // NOTE(review): the meaning of flatten_before_visit and of the optional
  // fmutate callback (default nullptr) is not visible in this listing --
  // document from the implementation.
290  Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
291  std::function<Stmt(const Stmt&)> fmutate = nullptr);
292 
293  // internal helper.
294  class Internal;
295 };
296 
/*!
 * \brief Visitor that recursively visits stmts and exprs on them.
 *
 * Combines StmtVisitor and ExprVisitor and re-exports the call operators and
 * visit entry points of both bases.
 */
300 class StmtExprVisitor : public StmtVisitor, public ExprVisitor {
301  public:
302  using StmtVisitor::operator();
303  using ExprVisitor::operator();
304 
305  protected:
306  using ExprVisitor::VisitExpr;
307  using StmtVisitor::VisitStmt;
308 
  // Forwards expression visits to ExprVisitor::VisitExpr, replacing the
  // empty default body that StmtVisitor::VisitExpr has.
309  void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); }
310 };
311 
/*!
 * \brief Mutator that recursively mutates stmts and exprs on them.
 *
 * Combines StmtMutator and ExprMutator and re-exports the call operators of
 * both bases.
 */
315 class StmtExprMutator : public StmtMutator, public ExprMutator {
316  public:
317  using StmtMutator::operator();
318  using ExprMutator::operator();
319 
320  protected:
321  using ExprMutator::VisitExpr;
  // NOTE(review): listing line 322 is absent from this rendering; check the
  // original header for a further using-declaration at this position.
323 
  // Forwards expression visits to ExprMutator::VisitExpr, replacing the
  // identity default that StmtMutator::VisitExpr has.
324  PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); }
325 };
326 
/*!
 * \brief Recursively visit the IR nodes in post DFS order, and transform it.
 *
 * \param stmt The statement to be transformed.
 * \param preorder Packed function taking part in the traversal.
 * \param postorder Packed function taking part in the traversal.
 * \param only_enable Optional list of strings; defaults to NullOpt.
 * \return The resulting statement.
 *
 * NOTE(review): the exact calling contract of preorder/postorder and the
 * meaning of the only_enable entries are not visible in this listing --
 * confirm against the implementation.
 */
342 TVM_DLL Stmt IRTransform(Stmt stmt, const runtime::PackedFunc& preorder,
343  const runtime::PackedFunc& postorder,
344  Optional<Array<String>> only_enable = NullOpt);
345 
/*!
 * \brief Recursively visit the IR in post DFS order, applying fvisit to each node.
 * \param node The IR node to be visited.
 * \param fvisit The callback applied to the visited nodes.
 */
352 TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
353 
/*!
 * \brief Substitute the var specified by vmap.
 * \param stmt The source statement to be substituted.
 * \param vmap Callback returning an Optional<PrimExpr> for a given Var.
 *        NOTE(review): presumably an empty Optional means "leave the var
 *        unchanged" -- confirm against the implementation.
 * \return The resulting statement.
 */
360 TVM_DLL Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var& var)> vmap);
361 
/*!
 * \brief Overload of Substitute taking and returning a PrimExpr.
 * \param expr The source expression to be substituted.
 * \param vmap Callback returning an Optional<PrimExpr> for a given Var.
 * \return The resulting expression.
 */
368 TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var& var)> vmap);
369 
/*!
 * \brief Overload of Substitute taking and returning an array of ranges.
 * \param region The source ranges.
 * \param vmap Map from Var to PrimExpr used for the substitution.
 * \return The resulting ranges.
 */
376 TVM_DLL Array<Range> Substitute(const Array<Range>& region, const Map<Var, PrimExpr>& vmap);
377 
/*!
 * \brief Substitute overload driven by a Map<Var, PrimExpr>.
 *
 * Wraps value_map in a lookup lambda and calls Substitute(std::move(input), vmap).
 * The lambda returns the mapped expression when the var is a key of the map
 * and an Optional constructed from nullptr otherwise.
 *
 * \param input The value to run the substitution on; moved into the nested call.
 * \param value_map The lookup table; captured by reference by the lambda.
 * \tparam T The type of `input`.
 * \return Whatever the nested Substitute call returns (return type is deduced).
 */
385 template <typename T>
386 inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
387  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
388  auto it = value_map.find(var);
389  if (it != value_map.end()) return (*it).second;
390  return Optional<PrimExpr>(nullptr);
391  };
392  return Substitute(std::move(input), vmap);
393 }
394 
/*!
 * \brief Substitute overload driven by an unordered_map keyed by VarNode pointer.
 *
 * Wraps value_map in a lookup lambda and calls Substitute(std::move(input), vmap).
 * The lambda looks the var up by its raw node pointer (var.get()); it returns
 * the mapped expression on a hit and an Optional constructed from nullptr otherwise.
 *
 * \param input The value to run the substitution on; moved into the nested call.
 * \param value_map The lookup table; captured by reference by the lambda.
 * \tparam T The type of `input`, also used as the return type.
 * \return The result of the nested Substitute call.
 */
402 template <typename T>
403 inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
404  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
405  auto it = value_map.find(var.get());
406  if (it != value_map.end()) return (*it).second;
407  return Optional<PrimExpr>(nullptr);
408  };
409  return Substitute(std::move(input), vmap);
410 }
411 
/*!
 * \brief Recursively visit the IR in pre DFS order, applying fvisit to each node.
 *        If fvisit returns false, the children of that node are not visited.
 * \param stmt_or_expr The IR to be visited.
 * \param fvisit The callback; its boolean result controls descent into children.
 */
419 TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
420  const std::function<bool(const ObjectRef&)>& fvisit);
421 
/*!
 * \brief Renew the definition nodes for a TIR, including Var, Buffer and IterVar.
 * \param func The input PrimFunc.
 * \return The resulting PrimFunc.
 */
429 TVM_DLL PrimFunc RenewDefs(const PrimFunc& func);
430 } // namespace tir
431 } // namespace tvm
432 
433 #endif // TVM_TIR_STMT_FUNCTOR_H_
virtual R VisitStmt_(const ProducerRealizeNode *op, Args... args)
Definition: stmt_functor.h:98
virtual R VisitStmt_(const LetStmtNode *op, Args... args)
Definition: stmt_functor.h:85
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Define certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:117
Store value into multi-dimensional array that will be read by the consumer of the producer.
Definition: stmt.h:404
A prefetch hint for a buffer.
Definition: stmt.h:1062
Declare a buffer that can be used in the body.
Definition: stmt.h:685
virtual R VisitStmt_(const SeqStmtNode *op, Args... args)
Definition: stmt_functor.h:100
void VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:309
A custom smart pointer for Object.
Definition: object.h:358
StmtMutator that mutates the statements.
Definition: stmt_functor.h:181
A block realization node represents execution of the block at the binding values. ...
Definition: stmt.h:1312
virtual PrimExpr VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:257
Visitor that recursively visits stmts and exprs on them.
Definition: stmt_functor.h:300
virtual R VisitStmt_(const WhileNode *op, Args... args)
Definition: stmt_functor.h:89
virtual R VisitStmt_(const AssertStmtNode *op, Args... args)
Definition: stmt_functor.h:96
virtual R VisitStmt_(const AllocateConstNode *op, Args... args)
Definition: stmt_functor.h:91
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
ExprVisitor.
Definition: expr_functor.h:209
virtual void VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:155
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:723
a named variable in TIR
Definition: var.h:88
IfThenElse statement.
Definition: stmt.h:820
PrimExpr VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:324
Same as ExprFunctor except it is applied on statements.
Definition: stmt_functor.h:46
virtual ~StmtFunctor()
virtual destructor
Definition: stmt_functor.h:66
virtual R VisitStmt_(const StoreNode *op, Args... args)
Definition: stmt_functor.h:93
TIR Function.
base class of all object containers.
Definition: object.h:167
Stmt operator()(Stmt stmt)
Mutate stmt.
Definition: stmt_functor.h:191
void PostOrderVisit(const ObjectRef &node, std::function< void(const ObjectRef &)> fvisit)
Recursively visit the IR in post DFS order node, apply fvisit. Each node is guaranteed to be visited o...
Stmt VisitStmt(const Stmt &stmt) override
Internal mutator that everyone calls.
Definition: stmt_functor.h:240
PrimFunc RenewDefs(const PrimFunc &func)
Renew the definition nodes for a TIR, including Var, Buffer and IterVar. This pass works as a simple ...
#define IR_STMT_FUNCTOR_DISPATCH(OP)
Definition: stmt_functor.h:51
Functors for tir expressions.
R operator()(const Stmt &n, Args... args)
Same as call.
Definition: stmt_functor.h:73
Annotate the bounds where the data produced by the producer need to be written and read in body...
Definition: stmt.h:458
virtual R VisitStmt_(const ProducerStoreNode *op, Args... args)
Definition: stmt_functor.h:97
TIR statements.
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Managed reference to PrimFuncNode.
Definition: function.h:156
A While loop.
Definition: stmt.h:1022
virtual R VisitStmt(const Stmt &n, Args... args)
The functor call.
Definition: stmt_functor.h:80
iterator find(const K &key) const
Definition: map.h:1383
virtual R VisitStmtDefault_(const Object *op, Args...)
Definition: stmt_functor.h:104
Container of all statements.
Definition: stmt.h:57
R result_type
the result type of this functor
Definition: stmt_functor.h:64
virtual R VisitStmt_(const DeclBufferNode *op, Args... args)
Definition: stmt_functor.h:92
ObjectPtr< TNode > CopyOnWrite(const TNode *node)
Perform copy on write on node.
Definition: stmt_functor.h:218
Allocate a buffer that can be used in body.
Definition: stmt.h:513
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:869
bool unique() const
Definition: object.h:550
virtual R VisitStmt_(const BufferRealizeNode *op, Args... args)
Definition: stmt_functor.h:95
virtual R VisitStmt_(const PrefetchNode *op, Args... args)
Definition: stmt_functor.h:99
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:128
A block is a basic schedule unit in TIR.
Definition: stmt.h:1228
virtual R VisitStmt_(const AttrStmtNode *op, Args... args)
Definition: stmt_functor.h:86
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Allocate a buffer that can be used in body.
Definition: stmt.h:595
Defines the Functor data structures.
Base class of all object reference.
Definition: object.h:511
Store value to the buffer.
Definition: stmt.h:229
std::string GetTypeKey() const
Definition: object.h:180
virtual R VisitStmt_(const IfThenElseNode *op, Args... args)
Definition: stmt_functor.h:87
Stmt Substitute(Stmt stmt, std::function< Optional< PrimExpr >(const Var &var)> vmap)
Substitute the var specified by vmap.
Mutator that recursively mutates stmts and exprs on them.
Definition: stmt_functor.h:315
Store value to the high dimension buffer.
Definition: stmt.h:286
iterator end() const
Definition: map.h:1381
virtual R VisitStmt_(const BlockRealizeNode *op, Args... args)
Definition: stmt_functor.h:103
virtual R VisitStmt_(const ForNode *op, Args... args)
Definition: stmt_functor.h:88
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:166
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Stmt IRTransform(Stmt stmt, const runtime::PackedFunc &preorder, const runtime::PackedFunc &postorder, Optional< Array< String >> only_enable=NullOpt)
recursively visit the ir nodes in post DFS order, and transform it
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:138
Optional container that represents a nullable variant of T.
Definition: optional.h:51
void PreOrderVisit(const ObjectRef &stmt_or_expr, const std::function< bool(const ObjectRef &)> &fvisit)
Recursively visit the IR in pre DFS order node, apply fvisit. If fvisit returns false, it won't visit the children of the node.
A for loop, with possible type annotations.
Definition: stmt.h:940
#define STMT_FUNCTOR_DEFAULT
Definition: stmt_functor.h:48
virtual R VisitStmt_(const BufferStoreNode *op, Args... args)
Definition: stmt_functor.h:94
Let binding, bind var to value, then run body.
Definition: stmt.h:65
Reference to PrimExprNode.
Definition: expr.h:112
StmtVisitor.
Definition: stmt_functor.h:142
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
Annotate the region where the buffer needs to be read and written in the body. We only need to allocate ...
Definition: stmt.h:341
virtual R VisitStmt_(const EvaluateNode *op, Args... args)
Definition: stmt_functor.h:101
virtual R VisitStmt_(const BlockNode *op, Args... args)
Definition: stmt_functor.h:102
virtual R VisitStmt_(const AllocateNode *op, Args... args)
Definition: stmt_functor.h:90
ExprMutator that mutates expressions.
Definition: expr_functor.h:256