tvm
expr_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RELAY_EXPR_FUNCTOR_H_
26 #define TVM_RELAY_EXPR_FUNCTOR_H_
27 
28 #include <tvm/node/functor.h>
29 #include <tvm/relay/adt.h>
30 #include <tvm/relay/error.h>
31 #include <tvm/relay/expr.h>
32 #include <tvm/relay/function.h>
33 #include <tvm/relay/op.h>
34 
35 #include <deque>
36 #include <string>
37 #include <unordered_map>
38 #include <utility>
39 #include <vector>
40 
41 namespace tvm {
42 namespace relay {
43 
55 template <typename FType>
57 
// EXPR_FUNCTOR_DEFAULT is used as the body of ExprFunctor's VisitExpr_ declarations; it
// forwards the node pointer and any extra arguments to VisitExprDefault_.
58 // functions to be overridden.
59 #define EXPR_FUNCTOR_DEFAULT \
60  { return VisitExprDefault_(op, std::forward<Args>(args)...); }
61 
// RELAY_EXPR_FUNCTOR_DISPATCH(OP) registers in `vtable` a handler for node type OP that
// static_casts the ObjectRef's payload to `const OP*` and calls self->VisitExpr_ with it,
// forwarding the extra arguments. It relies on `vtable`, `TSelf` and `Args` being in scope
// at the expansion site.
62 #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
63  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
64  return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
65  });
66 
// ExprFunctor specialization for signatures of the form R(const Expr&, Args...).
// VisitExpr dispatches on the dynamic node type of its first argument through a
// tvm::NodeFunctor table and calls the matching VisitExpr_ overload; any extra Args are
// forwarded unchanged.
67 template <typename R, typename... Args>
68 class ExprFunctor<R(const Expr& n, Args...)> {
69  private:
70  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
// Dispatch table type: maps a node type to a handler taking (node, this functor, args...).
71  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
72 
73  public:
// The result type of this functor.
75  using result_type = R;
// Virtual destructor so subclasses can be destroyed through a base pointer.
77  virtual ~ExprFunctor() {}
// Same as VisitExpr: forwards n and args to it.
84  R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); }
// Entry point: rejects an undefined (null) Expr, then dispatches on n's node type.
// The table is a function-local static, so InitVTable() runs once per specialization on
// first use and the resulting table is shared by every instance.
91  virtual R VisitExpr(const Expr& n, Args... args) {
92  ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
93  "have generated invalid data.";
94  static FType vtable = InitVTable();
95  return vtable(n, this, std::forward<Args>(args)...);
96  }
97  // Functions that can be overridden by subclass.
// Every overload below defaults (via EXPR_FUNCTOR_DEFAULT) to VisitExprDefault_.
98  virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
99  virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
100  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
101  virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
102  virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
103  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
104  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
105  virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
106  virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
107  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
108  virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
109  virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
110  virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
111  virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
112  virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
// Fallback for node types a subclass did not override: logs a fatal error naming the
// node's type key.
113  virtual R VisitExprDefault_(const Object* op, Args...) {
114  LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
// Not expected to be reached (presumably LOG(FATAL) does not return); the bare rethrow
// leaves this non-void function without a fall-through path. Note that a bare `throw;`
// with no active exception calls std::terminate.
115  throw;
116  }
117 
118  private:
119  // initialize the vtable.
// Builds the node-type -> VisitExpr_ dispatch table. Presumably it registers one
// RELAY_EXPR_FUNCTOR_DISPATCH(...) entry per supported node type; those registration
// lines are not shown in this listing (TODO confirm against the header).
120  static FType InitVTable() {
121  FType vtable;
122  // Set dispatch
138  return vtable;
139  }
140 };
141 
// A simple visitor built on ExprFunctor<void(const Expr&)>: every VisitExpr_ returns void.
// All overrides are only declared here; their definitions live outside this header.
149 class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
150  public:
151  void VisitExpr(const Expr& expr) override;
152  void VisitExpr_(const VarNode* op) override;
153  void VisitExpr_(const GlobalVarNode* op) override;
154  void VisitExpr_(const ConstantNode* op) override;
155  void VisitExpr_(const TupleNode* op) override;
156  void VisitExpr_(const FunctionNode* op) override;
157  void VisitExpr_(const CallNode* op) override;
158  void VisitExpr_(const LetNode* op) override;
159  void VisitExpr_(const IfNode* op) override;
160  void VisitExpr_(const OpNode* op) override;
161  void VisitExpr_(const TupleGetItemNode* op) override;
162  void VisitExpr_(const RefCreateNode* op) override;
163  void VisitExpr_(const RefReadNode* op) override;
164  void VisitExpr_(const RefWriteNode* op) override;
165  void VisitExpr_(const ConstructorNode* op) override;
166  void VisitExpr_(const MatchNode* op) override;
// Extra virtual hooks for non-Expr parts of the IR. Judging by their names, these visit
// types, match clauses, patterns and source spans (TODO confirm against the .cc).
167  virtual void VisitType(const Type& t);
168  virtual void VisitClause(const Clause& c);
169  virtual void VisitPattern(const Pattern& c);
170  virtual void VisitSpan(const Span& span);
171 
172  protected:
173  // Internal visiting counter
// Keyed by raw node pointer (identity, not structural equality). Presumably incremented
// by VisitExpr each time a node is visited; the update site is not in this header.
174  std::unordered_map<const Object*, size_t> visit_counter_;
175 };
176 
// A wrapper around ExprFunctor<Expr(const Expr&)> which functionally updates the AST:
// each VisitExpr_ returns a (possibly new) Expr. Overrides are only declared here.
184 class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
185  public:
// Mutate is an alias for VisitExpr.
190  Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); }
191  Expr VisitExpr(const Expr& expr) override;
192  Expr VisitExpr_(const VarNode* op) override;
193  Expr VisitExpr_(const ConstantNode* op) override;
194  Expr VisitExpr_(const GlobalVarNode* op) override;
195  Expr VisitExpr_(const OpNode* op) override;
196  Expr VisitExpr_(const TupleNode* op) override;
197  Expr VisitExpr_(const FunctionNode* op) override;
198  Expr VisitExpr_(const CallNode* call_node) override;
199  Expr VisitExpr_(const LetNode* op) override;
200  Expr VisitExpr_(const IfNode* op) override;
201  Expr VisitExpr_(const TupleGetItemNode* op) override;
202  Expr VisitExpr_(const RefCreateNode* op) override;
203  Expr VisitExpr_(const RefReadNode* op) override;
204  Expr VisitExpr_(const RefWriteNode* op) override;
205  Expr VisitExpr_(const ConstructorNode* op) override;
206  Expr VisitExpr_(const MatchNode* op) override;
207 
// Used to visit the types inside of expressions; VisitClause / VisitPattern are the
// corresponding hooks for match clauses and patterns.
215  virtual Type VisitType(const Type& t);
216  virtual Clause VisitClause(const Clause& c);
217  virtual Pattern VisitPattern(const Pattern& c);
218 
219  protected:
// Internal map used for memoization, hashed/compared by object identity
// (ObjectPtrHash / ObjectPtrEqual). Presumably maps each visited expression to its
// mutated result (TODO confirm in VisitExpr's definition).
221  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo_;
222 };
223 
234  public:
// Re-expose every VisitExpr_ overload of the functor base; without this, overriding only
// the three overloads below would hide the others during name lookup.
235  using ::tvm::relay::ExprFunctor<void(const Expr& n)>::VisitExpr_;
236 
// The constructor of MixedModeVisitor; visit_limit is the max number of times to visit a
// node. NOTE(review): the parameter is `int` while visit_limit_ is `size_t`, so a negative
// argument would wrap around.
241  explicit MixedModeVisitor(int visit_limit = 1);
242 
244 
// VisitExpr is final to preserve call expansion of dataflow regions. The Call / Tuple /
// TupleGetItem overrides below are the dataflow nodes it handles specially (definitions
// live outside this header).
248  void VisitExpr(const Expr& expr) final;
249  void VisitExpr_(const CallNode* op) override;
250  void VisitExpr_(const TupleNode* op) override;
251  void VisitExpr_(const TupleGetItemNode* op) override;
252 
253  protected:
// Applied when a leaf of the graph is reached during the non-recursive traversal.
257  virtual void VisitLeaf(const Expr& expr);
// Whether an expression has already been visited or needs to be re-visited.
262  virtual bool CheckVisited(const Expr& expr);
// The max number of times to visit a node.
266  size_t visit_limit_;
267 };
268 
283  public:
// Re-expose the base VisitExpr_ overloads so the final overrides below do not hide them.
284  using ::tvm::relay::ExprFunctor<Expr(const Expr&)>::VisitExpr_;
285 
// NOTE(review): this constructor is not `explicit`, so a bool converts implicitly to a
// MixedModeMutator; the meaning of `pre` is implemented outside this header.
286  MixedModeMutator(bool pre = false) : pre_{pre} {};
287  Expr VisitExpr(const Expr& expr) final;
288 
// Overridable dispatch hook called from VisitExpr (definition not in this header).
289  virtual Expr DispatchVisitExpr(const Expr& expr);
// Dataflow nodes are routed through Rewrite(): first ExprMutator's VisitExpr_ rebuilds the
// node from its (already rewritten) inputs, then the matching Rewrite_ hook runs.
290  Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
291  Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
292  Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
// Users override these to implement their pass. `pre` is the original node and `post` the
// node with rewritten inputs; the defaults return `post` unchanged.
301  virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post; }
302  virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
303  virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }
304 
305  protected:
// Value passed to the constructor; read by code outside this header.
306  bool pre_;
// Implements the Rewrite API: ExprMutator::VisitExpr_(op) produces the post node with
// changed inputs, which is then handed to Rewrite_(op, post).
310  template <typename T>
311  Expr Rewrite(const T* op) {
312  Expr post = ExprMutator::VisitExpr_(op);
313  return Rewrite_(op, post);
314  }
315 
316  virtual void VisitLeaf(const Expr& expr);
317  virtual bool CheckVisited(const Expr& expr);
318 };
319 
// RELAY_EXPR_REWRITER_DISPATCH(OP) registers in `vtable` a handler for node type OP that
// static_casts the `pre` node to `const OP*` and forwards it, together with `post`, to
// self->Rewrite_. It relies on `vtable` and `TSelf` being in scope at the expansion site.
320 #define RELAY_EXPR_REWRITER_DISPATCH(OP) \
321  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \
322  return self->Rewrite_(static_cast<const OP*>(n.get()), post); \
323  });
324 
// EXPR_REWRITER_REWRITE_DEFAULT is the default Rewrite_ body: return the already rewritten
// `post` expression unchanged.
325 #define EXPR_REWRITER_REWRITE_DEFAULT \
326  { return post; }
327 
339  private:
340  using TSelf = ExprRewriter;
// Dispatch table type: maps a node type to a handler taking (pre node, this rewriter, post).
341  using FType = tvm::NodeFunctor<Expr(const ObjectRef& n, TSelf* self, const Expr& post)>;
342 
343  public:
// Virtual destructor so subclasses can be destroyed through a base pointer.
345  virtual ~ExprRewriter() {}
// Same as Rewrite(pre, post).
352  Expr operator()(const Expr& pre, const Expr& post) { return Rewrite(pre, post); }
// Entry point: requires `pre` to be defined, then dispatches on pre's node type to the
// matching Rewrite_ overload, passing `post` through. The table is a function-local
// static built once by InitVTable() and shared by all instances.
359  virtual Expr Rewrite(const Expr& pre, const Expr& post) {
360  ICHECK(pre.defined());
361  static FType vtable = InitVTable();
362  return vtable(pre, this, post);
363  }
364  // Functions that can be overridden by subclass, should not recurse
// Each default (EXPR_REWRITER_REWRITE_DEFAULT) returns `post` unchanged.
365  virtual Expr Rewrite_(const VarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
367  virtual Expr Rewrite_(const ConstantNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
368  virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
369  virtual Expr Rewrite_(const FunctionNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
370  virtual Expr Rewrite_(const CallNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
371  virtual Expr Rewrite_(const LetNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
372  virtual Expr Rewrite_(const IfNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
373  virtual Expr Rewrite_(const OpNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
374  virtual Expr Rewrite_(const TupleGetItemNode* pre,
375  const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
377  virtual Expr Rewrite_(const RefReadNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
378  virtual Expr Rewrite_(const RefWriteNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
380  virtual Expr Rewrite_(const MatchNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
381 
382  private:
383  // initialize the vtable.
// Builds the node-type -> Rewrite_ dispatch table. Presumably it registers one
// RELAY_EXPR_REWRITER_DISPATCH(...) entry per node type; those registration lines are not
// shown in this listing (TODO confirm against the header).
384  static FType InitVTable() {
385  FType vtable;
386  // Set dispatch
402  return vtable;
403  }
404 };
405 
// Non-recursive DFS graph traversal for custom rewriting passes: rewrites `expr` in post
// order using `rewriter`'s Rewrite_ hooks and returns the resulting expression
// (definition lives outside this header).
413 Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
414 
// Recursively visits the IR rooted at `node` in post-DFS order and applies `fvisit` to each
// node; per the upstream docs each node is visited only once (definition lives outside
// this header).
421 void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
422 
426 struct v_info {
427  explicit v_info(Expr node_) : node{node_} {}
428  v_info(Expr node_, bool children_expanded_)
429  : node{node_}, children_expanded{children_expanded_} {};
431  bool children_expanded{false};
432 };
433 
452 template <typename FCheckVisited, typename FVisitLeaf, typename FExpandExpr>
453 void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf,
454  FExpandExpr fexpand_expr) {
455  std::deque<v_info> stack;
456  auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
457  if (!fcheck_visited(expr)) {
458  stack.emplace_front(v_info(expr));
459  }
460  };
461 
462  fpush_to_stack(expr);
463  while (stack.size() > 0) {
464  v_info* front = &stack.front();
465  if (fcheck_visited(front->node)) {
466  stack.pop_front();
467  } else if (front->children_expanded) {
468  fvisit_leaf(front->node);
469  // TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor
470  stack.pop_front();
471  } else {
472  front->children_expanded = true;
473  for (auto e : fexpand_expr(front->node)) {
474  fpush_to_stack(e);
475  }
476  }
477  }
478 }
479 
480 template <typename FCheckVisited, typename FVisitLeaf>
481 void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
482  auto fexpand_expr = [](const Expr& expr) {
483  std::vector<Expr> result;
484  if (const CallNode* op = expr.as<CallNode>()) {
485  if (op->op == Op::Get("call_lowered")) {
486  // Ignore the intermediate tuple since this is purely a calling-convention detail
487  const auto* tuple_args = op->args[1].as<TupleNode>();
488  ICHECK(tuple_args)
489  << "Expected second arg to call_lowered to be a Tuple of input arguments.";
490  for (auto it = tuple_args->fields.rbegin(); it != tuple_args->fields.rend(); ++it) {
491  result.push_back(*it);
492  }
493  result.push_back(op->args[0]);
494  } else {
495  for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
496  result.push_back(*it);
497  }
498  }
499  result.push_back(op->op);
500  } else if (const TupleNode* op = expr.as<TupleNode>()) {
501  for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
502  result.push_back(*it);
503  }
504  } else if (const TupleGetItemNode* op = expr.as<TupleGetItemNode>()) {
505  result.push_back(op->tuple);
506  }
507  return result;
508  };
509  ExpandDataflow(expr, fcheck_visited, fvisit_leaf, fexpand_expr);
510 }
511 
// Walks the let-binding chain starting at `op`, calling `pre_visit` and `post_visit` on
// LetNodes. Presumably the traversal is iterative (no recursion per binding), with
// pre_visit called on the way down and post_visit on the way back; confirm against the
// definition, which is not in this header.
512 void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
513  std::function<void(const LetNode*)> post_visit);
514 
515 } // namespace relay
516 } // namespace tvm
517 #endif // TVM_RELAY_EXPR_FUNCTOR_H_
ADT constructor. Constructors compare by pointer equality.
Definition: adt.h:47
Global variable that lives in the top-level module.
Definition: expr.h:456
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Primitive Op (builtin intrinsics)
Definition: op.h:58
static const Op & Get(const String &op_name)
Get an Op for a given operator name. Will raise an error if the op has not been registered.
Managed reference to RelayExprNode.
Definition: expr.h:442
Definition: source_map.h:120
Managed reference to TypeNode.
Definition: type.h:93
Call container.
Definition: expr.h:282
Definition: adt.h:253
Constant tensor type.
Definition: expr.h:71
virtual R VisitExpr_(const TupleGetItemNode *op, Args... args)
Definition: expr_functor.h:107
virtual R VisitExpr_(const ConstantNode *op, Args... args)
Definition: expr_functor.h:98
virtual R VisitExpr_(const OpNode *op, Args... args)
Definition: expr_functor.h:106
R result_type
the result type of this functor
Definition: expr_functor.h:75
virtual R VisitExpr_(const GlobalVarNode *op, Args... args)
Definition: expr_functor.h:101
virtual R VisitExpr_(const RefReadNode *op, Args... args)
Definition: expr_functor.h:109
virtual R VisitExpr_(const ConstructorNode *op, Args... args)
Definition: expr_functor.h:111
R operator()(const Expr &n, Args... args)
Same as call.
Definition: expr_functor.h:84
virtual R VisitExpr_(const LetNode *op, Args... args)
Definition: expr_functor.h:104
virtual R VisitExpr_(const TupleNode *op, Args... args)
Definition: expr_functor.h:99
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:103
virtual R VisitExpr_(const RefWriteNode *op, Args... args)
Definition: expr_functor.h:110
virtual R VisitExpr_(const RefCreateNode *op, Args... args)
Definition: expr_functor.h:108
virtual R VisitExpr_(const IfNode *op, Args... args)
Definition: expr_functor.h:105
virtual R VisitExpr(const Expr &n, Args... args)
The functor call.
Definition: expr_functor.h:91
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:113
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:77
virtual R VisitExpr_(const MatchNode *op, Args... args)
Definition: expr_functor.h:112
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:100
virtual R VisitExpr_(const FunctionNode *op, Args... args)
Definition: expr_functor.h:102
A dynamic functor that dispatches on the first Expr argument. You can use this as a more powerful visitor.
Definition: expr_functor.h:56
A wrapper around ExprFunctor which functionally updates the AST.
Definition: expr_functor.h:184
Expr VisitExpr_(const RefWriteNode *op) override
Expr VisitExpr_(const OpNode *op) override
Expr VisitExpr_(const TupleGetItemNode *op) override
Expr VisitExpr_(const RefReadNode *op) override
virtual Pattern VisitPattern(const Pattern &c)
Expr VisitExpr_(const TupleNode *op) override
std::unordered_map< Expr, Expr, ObjectPtrHash, ObjectPtrEqual > memo_
Internal map used for memoization.
Definition: expr_functor.h:221
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const MatchNode *op) override
Expr VisitExpr_(const ConstantNode *op) override
virtual Type VisitType(const Type &t)
Used to visit the types inside of expressions.
Expr Mutate(const Expr &expr)
Mutate is an alias for VisitExpr.
Definition: expr_functor.h:190
Expr VisitExpr_(const RefCreateNode *op) override
Expr VisitExpr_(const ConstructorNode *op) override
Expr VisitExpr_(const VarNode *op) override
Expr VisitExpr_(const FunctionNode *op) override
virtual Clause VisitClause(const Clause &c)
Expr VisitExpr_(const LetNode *op) override
Expr VisitExpr_(const GlobalVarNode *op) override
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const CallNode *call_node) override
A non-iterating Expression Rewriter.
Definition: expr_functor.h:338
virtual Expr Rewrite_(const MatchNode *pre, const Expr &post)
Definition: expr_functor.h:380
virtual Expr Rewrite_(const TupleNode *pre, const Expr &post)
Definition: expr_functor.h:368
virtual Expr Rewrite(const Expr &pre, const Expr &post)
The functor call.
Definition: expr_functor.h:359
virtual Expr Rewrite_(const RefReadNode *pre, const Expr &post)
Definition: expr_functor.h:377
virtual Expr Rewrite_(const LetNode *pre, const Expr &post)
Definition: expr_functor.h:371
virtual Expr Rewrite_(const IfNode *pre, const Expr &post)
Definition: expr_functor.h:372
virtual Expr Rewrite_(const TupleGetItemNode *pre, const Expr &post)
Definition: expr_functor.h:374
virtual Expr Rewrite_(const RefCreateNode *pre, const Expr &post)
Definition: expr_functor.h:376
virtual Expr Rewrite_(const RefWriteNode *pre, const Expr &post)
Definition: expr_functor.h:378
virtual Expr Rewrite_(const CallNode *pre, const Expr &post)
Definition: expr_functor.h:370
virtual Expr Rewrite_(const VarNode *pre, const Expr &post)
Definition: expr_functor.h:365
virtual Expr Rewrite_(const OpNode *pre, const Expr &post)
Definition: expr_functor.h:373
virtual Expr Rewrite_(const FunctionNode *pre, const Expr &post)
Definition: expr_functor.h:369
Expr operator()(const Expr &pre, const Expr &post)
Same as call.
Definition: expr_functor.h:352
virtual Expr Rewrite_(const ConstantNode *pre, const Expr &post)
Definition: expr_functor.h:367
virtual ~ExprRewriter()
virtual destructor
Definition: expr_functor.h:345
virtual Expr Rewrite_(const GlobalVarNode *pre, const Expr &post)
Definition: expr_functor.h:366
virtual Expr Rewrite_(const ConstructorNode *pre, const Expr &post)
Definition: expr_functor.h:379
A simple visitor wrapper around ExprFunctor. Recursively visit the content.
Definition: expr_functor.h:149
void VisitExpr_(const GlobalVarNode *op) override
void VisitExpr_(const ConstantNode *op) override
void VisitExpr_(const FunctionNode *op) override
std::unordered_map< const Object *, size_t > visit_counter_
Definition: expr_functor.h:174
void VisitExpr_(const IfNode *op) override
virtual void VisitSpan(const Span &span)
void VisitExpr_(const TupleNode *op) override
void VisitExpr_(const VarNode *op) override
void VisitExpr(const Expr &expr) override
void VisitExpr_(const RefCreateNode *op) override
void VisitExpr_(const TupleGetItemNode *op) override
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const RefReadNode *op) override
virtual void VisitPattern(const Pattern &c)
virtual void VisitClause(const Clause &c)
void VisitExpr_(const ConstructorNode *op) override
void VisitExpr_(const MatchNode *op) override
void VisitExpr_(const OpNode *op) override
void VisitExpr_(const RefWriteNode *op) override
virtual void VisitType(const Type &t)
void VisitExpr_(const LetNode *op) override
Relay Function container.
Definition: function.h:39
Container of If.
Definition: expr.h:491
A binding of a sub-network.
Definition: expr.h:404
Match container node.
Definition: adt.h:277
Non-recursive DFS Graph Traversal for Custom Rewriting Passes.
Definition: expr_functor.h:282
virtual Expr Rewrite_(const TupleGetItemNode *pre, const Expr &post)
Definition: expr_functor.h:303
Expr VisitExpr(const Expr &expr) final
virtual Expr Rewrite_(const TupleNode *pre, const Expr &post)
Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be able to re...
Definition: expr_functor.h:301
Expr Rewrite(const T *op)
Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a post node with changed inputs.
Definition: expr_functor.h:311
Expr VisitExpr_(const TupleGetItemNode *op) final
Definition: expr_functor.h:292
bool pre_
Definition: expr_functor.h:306
Expr VisitExpr_(const TupleNode *op) final
Definition: expr_functor.h:290
Expr VisitExpr_(const CallNode *call_node) final
Definition: expr_functor.h:291
virtual Expr DispatchVisitExpr(const Expr &expr)
MixedModeMutator(bool pre=false)
Definition: expr_functor.h:286
virtual void VisitLeaf(const Expr &expr)
virtual bool CheckVisited(const Expr &expr)
virtual Expr Rewrite_(const CallNode *pre, const Expr &post)
Definition: expr_functor.h:302
A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
Definition: expr_functor.h:233
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const VarNode *op) override
void VisitExpr_(const TupleGetItemNode *op) override
virtual bool CheckVisited(const Expr &expr)
A function to determine if an expression has already been visited or needs to be re-visited.
void VisitExpr_(const TupleNode *op) override
size_t visit_limit_
The max number of times to visit a node.
Definition: expr_functor.h:266
virtual void VisitLeaf(const Expr &expr)
A function to apply when reaching a leaf of the graph non-recursively.
MixedModeVisitor(int visit_limit=1)
The constructor of MixedModeVisitor.
void VisitExpr(const Expr &expr) final
VisitExpr is finalized to preserve call expansion of dataflow regions.
Pattern is the base type for an ADT match pattern in Relay.
Definition: adt.h:63
Definition: expr.h:608
Definition: expr.h:658
Definition: expr.h:708
Definition: expr.h:554
Tuple container.
Definition: expr.h:123
Container for Var.
Definition: expr.h:188
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
Defines the Functor data structures.
void ExpandANormalForm(const LetNode *op, std::function< void(const LetNode *)> pre_visit, std::function< void(const LetNode *)> post_visit)
tvm::RelayExpr Expr
Definition: expr.h:54
void PostOrderVisit(const Expr &node, std::function< void(const Expr &)> fvisit)
Recursively visit the IR in post-DFS order, applying fvisit to each node. Each node is guaranteed to be visited only once.
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf, FExpandExpr fexpand_expr)
A function to iteratively traverse dataflow regions of a graph.
Definition: expr_functor.h:453
Expr PostOrderRewrite(const Expr &expr, ExprRewriter *rewriter)
Non-recursive DFS Graph Traversal for Custom Rewriting Passes.
Tensor stack(const Array< Tensor > &inputs, int axis=0, std::string name="T_stack", std::string tag=kInjective)
Join a sequence of tensors along a new axis.
Definition: transform.h:532
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Algebraic data types for Relay.
Relay expression language.
#define RELAY_EXPR_REWRITER_DISPATCH(OP)
Definition: expr_functor.h:320
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:62
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:59
#define EXPR_REWRITER_REWRITE_DEFAULT
Definition: expr_functor.h:325
Relay Function.
Primitive operators (builtin intrinsics).
A struct that keeps information about an expression traversed by the ExpandDataflow function.
Definition: expr_functor.h:426
v_info(Expr node_)
Definition: expr_functor.h:427
Expr node
Definition: expr_functor.h:430
bool children_expanded
Definition: expr_functor.h:431
v_info(Expr node_, bool children_expanded_)
Definition: expr_functor.h:428