tvm
expr_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RELAX_EXPR_FUNCTOR_H_
26 #define TVM_RELAX_EXPR_FUNCTOR_H_
27 
28 #include <tvm/node/functor.h>
30 #include <tvm/relax/expr.h>
31 #include <tvm/relax/struct_info.h>
33 #include <tvm/relay/op.h>
34 #include <tvm/tir/function.h>
35 
36 #include <deque>
37 #include <string>
38 #include <unordered_map>
39 #include <utility>
40 #include <vector>
41 namespace tvm {
42 namespace relax {
43 
55 template <typename FType>
57 
58 // Default body for functions to be overridden: forwards to VisitExprDefault_(op, args...).
59 #define EXPR_FUNCTOR_DEFAULT \
60  { return VisitExprDefault_(op, std::forward<Args>(args)...); } // expects `op` and `args` in scope at the use site
61 
62 #define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \
63  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
64  return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
65  }); // registers a `vtable` entry that downcasts n to OP and calls self->VisitExpr_; needs vtable/TSelf/Args in scope
66 
67 #define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \
68  { \
69  if (PY_FUNC != nullptr) \
70  PY_FUNC(N); \
71  else \
72  DEFAULT_FUNC; \
73  } // runs PY_FUNC(N) when PY_FUNC is non-null, otherwise evaluates DEFAULT_FUNC
74 
75 #define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \
76  { \
77  if (PY_FUNC != nullptr) { \
78  RET_TYPE ret = PY_FUNC(N); \
79  return ret; \
80  } else { \
81  return DEFAULT_FUNC; \
82  } \
83  } // returns PY_FUNC(N) (as RET_TYPE) when PY_FUNC is non-null, otherwise returns DEFAULT_FUNC
84 
85 #define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \
86  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
87  if (self->PY_FUNC != nullptr) \
88  self->PY_FUNC(n); \
89  else \
90  self->VisitExpr_(static_cast<const OP*>(n.get())); \
91  }); // vtable entry: calls self->PY_FUNC(n) when set, otherwise self->VisitExpr_ on the node downcast to OP
92 
93 #define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \
94  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
95  if (self->PY_FUNC != nullptr) { \
96  Expr expr = self->PY_FUNC(n); \
97  return expr; \
98  } else { \
99  return self->VisitExpr_(static_cast<const OP*>(n.get())); \
100  } \
101  }); // like PY_EXPR_VISITOR_DISPATCH, but both paths return the resulting Expr
102 
103 #define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP) \
104  post_order_vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
105  return self->VisitExprPostOrder_(static_cast<const OP*>(n.get())); \
106  }); // registers OP in `post_order_vtable`, forwarding to self->VisitExprPostOrder_
107 
108 template <typename R, typename... Args>
109 class ExprFunctor<R(const Expr& n, Args...)> { // dispatches on the runtime node type of the first Expr argument
110  private:
111  using TSelf = ExprFunctor<R(const Expr& n, Args...)>; // self type referenced by RELAX_EXPR_FUNCTOR_DISPATCH
112  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; // vtable type keyed on the type of n
113 
114  public:
115  using result_type = R; // the result type of this functor
116  virtual ~ExprFunctor() {} // virtual destructor
117  R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); } // same as VisitExpr
118  virtual R VisitExpr(const Expr& n, Args... args) { // rejects undefined exprs, then dispatches via the vtable
119  ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
120  "have generated invalid data.";
121  static FType vtable = InitVTable(); // function-local static: built once per specialization, on first use
122  return vtable(n, this, std::forward<Args>(args)...);
123  }
124  // Functions that can be overridden by subclasses
125  // NOTE: cross dialect calls are invoked through global var
126  // We do not expect inline PrimFunc to appear in relax IR.
127  virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
128  virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
129  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
130  virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
131  virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
132  virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
133  virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
134  virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
135  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
136  virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
137  virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
138  virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
139  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
140  virtual R VisitExpr_(const PrimValueNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
141  virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
142  virtual R VisitExpr_(const DataTypeImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
143  virtual R VisitExprDefault_(const Object* op, Args...) { // fallback reached by every non-overridden VisitExpr_
144  LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
145  throw; // NOTE(review): presumably unreachable after LOG(FATAL); a bare rethrow with no active exception calls std::terminate
146  }
147 
148  private:
149  // initialize the vtable.
150  static FType InitVTable() {
151  FType vtable;
152  // Set dispatch
153  return vtable;
154  }
155 };
186 
191 class ExprVisitor : public ExprFunctor<void(const Expr&)> { // simple visitor wrapper around ExprFunctor; recursively visits the content
192  public:
197  void VisitExpr(const Expr& expr) override; // generic dispatcher for Expr
198  // specific leaf level visitor functions
199  void VisitExpr_(const ConstantNode* op) override;
200  void VisitExpr_(const TupleNode* op) override;
201  void VisitExpr_(const VarNode* op) override;
202  void VisitExpr_(const DataflowVarNode* op) override;
203  void VisitExpr_(const ShapeExprNode* op) override;
204  void VisitExpr_(const ExternFuncNode* op) override;
205  void VisitExpr_(const GlobalVarNode* op) override;
206  void VisitExpr_(const FunctionNode* op) override;
207  void VisitExpr_(const CallNode* op) override;
208  void VisitExpr_(const SeqExprNode* op) override;
209  void VisitExpr_(const IfNode* op) override;
210  void VisitExpr_(const OpNode* op) override;
211  void VisitExpr_(const TupleGetItemNode* op) override;
212  void VisitExpr_(const PrimValueNode* op) override;
213  void VisitExpr_(const StringImmNode* op) override;
214  void VisitExpr_(const DataTypeImmNode* op) override;
215 
220  virtual void VisitBinding(const Binding& binding); // generic dispatcher for bindings
221  // specific leaf level visitor functions
222  virtual void VisitBinding_(const VarBindingNode* binding);
223  virtual void VisitBinding_(const MatchCastNode* binding);
224  // second level dispatching based on binding value type.
225  // these dispatching functions get called from first-level dispatch on VarBinding
226  virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
227  virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val);
228  virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val);
229  virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val);
230  virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val);
231  virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val);
232  virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val);
233  virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val);
234  virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val);
235  virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val);
236  virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val);
237  virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val);
238  virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val);
239  virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val);
240  virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val);
241  virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val);
246  virtual void VisitBindingBlock(const BindingBlock& block); // generic dispatcher for binding blocks
247  // specific leaf level visitor functions
248  virtual void VisitBindingBlock_(const BindingBlockNode* block);
249  virtual void VisitBindingBlock_(const DataflowBlockNode* block);
250 
256  virtual void VisitVarDef(const Var& var); // generic dispatcher for visiting the var definition site
257 
272  virtual void VisitExprDepStructInfoField(const StructInfo& struct_info); // visit struct_info that may recursively contain Expr/PrimExpr
273 
274  // specific leaf level visitor functions
275  virtual void VisitVarDef_(const VarNode* var);
276  virtual void VisitVarDef_(const DataflowVarNode* var);
277 
278  virtual void VisitSpan(const Span& span);
279  virtual void VisitPrimExpr(const PrimExpr& expr);
280 
281  private:
282  using TSelf = ExprVisitor;
283  using VisitBindingVTable =
284  tvm::NodeFunctor<void(const ObjectRef& n, ExprVisitor* self, const VarBindingNode* binding)>;
285  // initialize the vtable.
286  static VisitBindingVTable InitVisitBindingVTable(); // NOTE(review): presumably backs the second-level VisitBinding_(binding, val) dispatch
296  class DefaultStructInfoFieldVisitor : public StructInfoVisitor {
297  public:
298  explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent); // parent is stored in parent_
299 
300  // Override defaults in struct info visitor.
301  void VisitStructInfoExprField(const Expr& expr) final;
302  void VisitStructInfoExprField(const PrimExpr& expr) final;
303  void VisitStructInfo_(const FuncStructInfoNode* op) final;
304 
305  private:
306  ExprVisitor* parent_;
307  };
308  // This visitor is not visible to child classes and is only
309  // used to support default visiting behavior.
310  DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this}; // parent is this ExprVisitor
311 };
312 
313 void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit); // NOTE(review): presumably calls fvisit on each sub-expression of node in post-order -- confirm in the definition; std::function is used but <functional> is not among this header's standard includes (relied on transitively)
314 
323 class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> { // a mutator that works in unnormalized form
324  public:
325  Expr VisitExpr(const Expr& expr) override;
326  Expr VisitExpr_(const ConstantNode* op) override;
327  Expr VisitExpr_(const TupleNode* op) override;
328  Expr VisitExpr_(const VarNode* op) override;
329  Expr VisitExpr_(const DataflowVarNode* op) override;
330  Expr VisitExpr_(const ShapeExprNode* op) override;
331  Expr VisitExpr_(const ExternFuncNode* op) override;
332  Expr VisitExpr_(const GlobalVarNode* op) override;
333  Expr VisitExpr_(const FunctionNode* op) override;
334  Expr VisitExpr_(const CallNode* op) override;
335  Expr VisitExpr_(const SeqExprNode* op) override;
336  Expr VisitExpr_(const IfNode* op) override;
337  Expr VisitExpr_(const OpNode* op) override;
338  Expr VisitExpr_(const TupleGetItemNode* op) override;
339  Expr VisitExpr_(const PrimValueNode* op) override;
340  Expr VisitExpr_(const StringImmNode* op) override;
341  Expr VisitExpr_(const DataTypeImmNode* op) override;
342 
349 
355  virtual PrimExpr VisitPrimExpr(const PrimExpr& expr); // used to visit the PrimExpr inside of expressions
356 
372  virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info); // visit struct_info that may recursively contain Expr/PrimExpr
373 
374  protected:
384  if (const StructInfoNode* sinfo = struct_info.as<StructInfoNode>()) { // only StructInfo values are visited
385  return this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)).same_as(struct_info); // unchanged iff the same object comes back
386  } else {
387  return true; // non-StructInfo values are reported as unchanged
388  }
389  }
390 
391  private:
399  class DefaultStructInfoFieldMutator : public StructInfoMutator {
400  public:
401  explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent); // parent is stored in parent_
402 
403  // Override defaults in struct info mutator.
404  Expr VisitStructInfoExprField(const Expr& expr) final;
405  PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final;
406  StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final;
407 
408  private:
409  ExprMutatorBase* parent_;
410  };
411  // This mutator is not visible to child classes and is only
412  // used to support default visiting behavior.
413  DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this}; // parent is this ExprMutatorBase
414 };
415 
423 class ExprMutator : public ExprMutatorBase { // a mutator that works in normal form
424  public:
426 
428  Expr VisitExpr(const Expr& expr) override;
429  Expr VisitExpr_(const VarNode* op) override;
430  Expr VisitExpr_(const DataflowVarNode* op) override;
431  Expr VisitExpr_(const FunctionNode* op) override;
432  Expr VisitExpr_(const SeqExprNode* op) override;
433  Expr VisitExpr_(const IfNode* op) override;
434 
439  virtual void VisitBinding(const Binding& binding); // generic dispatcher for bindings
440  // specific leaf level visitor functions
441  virtual void VisitBinding_(const VarBindingNode* binding);
442  virtual void VisitBinding_(const MatchCastNode* binding);
443  // second level dispatching based on binding value type.
444  // these dispatching functions get called from first-level dispatch on VarBinding
445  virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
446  virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val);
447  virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val);
448  virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val);
449  virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val);
450  virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val);
451  virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val);
452  virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val);
453  virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val);
454  virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val);
455  virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val);
456  virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val);
457  virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val);
458  virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val);
459  virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val);
460  virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val);
466  virtual BindingBlock VisitBindingBlock(const BindingBlock& block) override; // NOLINT(*)
467  // specific leaf level visitor functions
470 
477  virtual Var VisitVarDef(const Var& var); // generic dispatcher for rewriting the var definition site
478  // specific leaf level visitor functions
479  virtual Var VisitVarDef_(const VarNode* var);
481 
482  protected:
494  void ReEmitBinding(const VarBindingNode* binding, Expr new_value); // try to re-emit binding and bind it to new_value
495 
508  Expr VisitWithNewScope(const Expr& body_expr, Optional<Array<Var>> params = NullOpt); // rewrite with a new scope (a Function's body)
509 
524  Expr VisitWithInnerScope(const Expr& body_expr); // rewrite with a new scope (the branches of If)
525 
533 
540  template <typename T>
541  Expr VisitExprPostOrder_(const T* op) { // post-order rewrite: ExprMutator::VisitExpr_ on op, then normalize via builder_
542  return builder_->Normalize(ExprMutator::VisitExpr_(op));
543  }
544 
553 
556 
558  std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_; // remap a var (keyed by its Id) to a new var at use sites
559 
560  private:
561  using TSelf = ExprMutator;
562  using VisitBindingVTable =
563  tvm::NodeFunctor<void(const ObjectRef& n, ExprMutator* self, const VarBindingNode* binding)>;
564  // initialize the vtable.
565  static VisitBindingVTable InitVisitBindingVTable(); // NOTE(review): presumably backs the second-level VisitBinding_(binding, val) dispatch
566 };
567 
568 } // namespace relax
569 } // namespace tvm
570 #endif // TVM_RELAX_EXPR_FUNCTOR_H_
The utility for constructing Relax binding blocks.
Global variable that lives in the top-level module.
Definition: expr.h:456
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Primitive Op(builtin intrinsics)
Definition: op.h:58
Reference to PrimExprNode.
Definition: expr.h:115
Managed reference to RelayExprNode.
Definition: expr.h:442
Definition: source_map.h:120
Definition: expr.h:762
Definition: expr.h:784
Definition: expr.h:681
Definition: block_builder.h:264
static BlockBuilder Create(Optional< IRModule > ctx_mod)
Create a BlockBuilder.
Call corresponds to callable invocation. Corresponds to operation in computational graph terminology.
Definition: expr.h:138
Constant tensor.
Definition: expr.h:480
Represent a data type constant.
Definition: expr.h:628
Definition: expr.h:792
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:437
virtual R VisitExpr_(const TupleNode *op, Args... args)
Definition: expr_functor.h:142
virtual R VisitExpr_(const DataTypeImmNode *op, Args... args)
Definition: expr_functor.h:156
virtual R VisitExpr(const Expr &n, Args... args)
The functor call.
Definition: expr_functor.h:132
virtual R VisitExpr_(const ShapeExprNode *op, Args... args)
Definition: expr_functor.h:145
virtual R VisitExpr_(const SeqExprNode *op, Args... args)
Definition: expr_functor.h:150
virtual R VisitExpr_(const FunctionNode *op, Args... args)
Definition: expr_functor.h:148
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:157
virtual R VisitExpr_(const OpNode *op, Args... args)
Definition: expr_functor.h:152
R result_type
the result type of this functor
Definition: expr_functor.h:116
virtual R VisitExpr_(const IfNode *op, Args... args)
Definition: expr_functor.h:151
virtual R VisitExpr_(const ExternFuncNode *op, Args... args)
Definition: expr_functor.h:146
virtual R VisitExpr_(const GlobalVarNode *op, Args... args)
Definition: expr_functor.h:147
virtual R VisitExpr_(const DataflowVarNode *op, Args... args)
Definition: expr_functor.h:144
virtual R VisitExpr_(const StringImmNode *op, Args... args)
Definition: expr_functor.h:155
R operator()(const Expr &n, Args... args)
Same as call.
Definition: expr_functor.h:125
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:143
virtual R VisitExpr_(const PrimValueNode *op, Args... args)
Definition: expr_functor.h:154
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:149
virtual R VisitExpr_(const TupleGetItemNode *op, Args... args)
Definition: expr_functor.h:153
virtual R VisitExpr_(const ConstantNode *op, Args... args)
Definition: expr_functor.h:141
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:118
A dynamical functor that dispatches on the first Expr argument. You can use this as a more powerful...
Definition: expr_functor.h:56
A mutator that works in unnormalized form.
Definition: expr_functor.h:323
Expr VisitExpr_(const TupleNode *op) override
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const VarNode *op) override
Expr VisitExpr_(const GlobalVarNode *op) override
Expr VisitExpr_(const PrimValueNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block)
Mutate BindingBlock.
Expr VisitExpr_(const ShapeExprNode *op) override
virtual PrimExpr VisitPrimExpr(const PrimExpr &expr)
Used to visit the PrimExpr inside of expressions.
Expr VisitExpr_(const ExternFuncNode *op) override
Expr VisitExpr_(const DataTypeImmNode *op) override
bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef &struct_info)
Check whether VisitExprDepStructInfoField changes struct_info.
Definition: expr_functor.h:383
Expr VisitExpr_(const CallNode *op) override
Expr VisitExpr_(const OpNode *op) override
virtual StructInfo VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info that may recursively contain Expr/PrimExpr.
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const StringImmNode *op) override
Expr VisitExpr_(const TupleGetItemNode *op) override
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const FunctionNode *op) override
Expr VisitExpr_(const DataflowVarNode *op) override
A mutator that works in normal form.
Definition: expr_functor.h:423
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
Expr VisitExpr_(const VarNode *op) override
virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode *block)
void ReEmitBinding(const VarBindingNode *binding, Expr new_value)
Try to re-emit the binding and bind it to new_value.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const FunctionNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
Var WithStructInfo(Var var, StructInfo struct_info)
Create a new var with specified struct_info if the original var's shape or type does not match with t...
virtual BindingBlock VisitBindingBlock_(const BindingBlockNode *block)
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block) override
Generic dispatcher for binding blocks.
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
virtual Var VisitVarDef(const Var &var)
Generic dispatcher for rewriting the var definition site.
ExprMutator(Optional< IRModule > mod=NullOpt)
Definition: expr_functor.h:427
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
virtual Var VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
Expr VisitExprPostOrder_(const T *op)
Post-order rewrite a node and normalize.
Definition: expr_functor.h:541
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
Optional< Expr > LookupBinding(const Var &var)
Look up the value bound to a variable.
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
Expr VisitExpr_(const DataflowVarNode *op) override
Expr VisitWithInnerScope(const Expr &body_expr)
Rewrite the expr with a new scope, used in the branches of If.
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
Expr VisitWithNewScope(const Expr &body_expr, Optional< Array< Var >> params=NullOpt)
Rewrite the expr with a new scope, used in a Function's body.
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
BlockBuilder builder_
Internal block builder to emit bindings during rewriting.
Definition: expr_functor.h:555
virtual Var VisitVarDef_(const VarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
std::unordered_map< Id, Var, ObjectPtrHash, ObjectPtrEqual > var_remap_
Remap a var to a new var in use-site.
Definition: expr_functor.h:558
virtual void VisitBinding_(const VarBindingNode *binding)
A simple visitor wrapper around ExprFunctor. Recursively visits the content.
Definition: expr_functor.h:191
void VisitExpr_(const GlobalVarNode *op) override
void VisitExpr_(const StringImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
void VisitExpr_(const SeqExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const IfNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
void VisitExpr_(const TupleGetItemNode *op) override
void VisitExpr(const Expr &expr) override
Generic dispatcher for Expr.
virtual void VisitBindingBlock_(const DataflowBlockNode *block)
virtual void VisitSpan(const Span &span)
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
void VisitExpr_(const OpNode *op) override
virtual void VisitVarDef(const Var &var)
Generic dispatcher for visiting the var definition site.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
void VisitExpr_(const PrimValueNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
void VisitExpr_(const DataTypeImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
void VisitExpr_(const FunctionNode *op) override
void VisitExpr_(const VarNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
void VisitExpr_(const ExternFuncNode *op) override
virtual void VisitBindingBlock(const BindingBlock &block)
Generic dispatcher for binding blocks.
void VisitExpr_(const DataflowVarNode *op) override
virtual void VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
virtual void VisitPrimExpr(const PrimExpr &expr)
virtual void VisitVarDef_(const VarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
void VisitExpr_(const ShapeExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
void VisitExpr_(const TupleNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding)
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
void VisitExpr_(const ConstantNode *op) override
virtual void VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info that may recursively contain Expr/PrimExpr.
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
virtual void VisitBindingBlock_(const BindingBlockNode *block)
The extern function, which can represent packed function.
Definition: expr.h:1065
Structure information about function.
Definition: struct_info.h:303
A Relax function.
Definition: expr.h:950
Condition expression.
Definition: expr.h:878
Runtime-match the value to the struct info.
Definition: expr.h:700
PrimValue.
Definition: expr.h:534
A sequence of blocks followed by an expression.
Definition: expr.h:817
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:356
Represent a string literal constant.
Definition: expr.h:585
StructInfoMutator that mutates struct info.
Definition: struct_info_functor.h:139
Base type of all structure information.
Definition: expr.h:110
A struct info visitor.
Definition: struct_info_functor.h:120
Managed reference to StructInfoNode.
Definition: expr.h:129
Get index-th field out of a tuple.
Definition: expr.h:282
Tuple container.
Definition: expr.h:219
Definition: expr.h:735
The variable class for all Relax bindings.
Definition: expr.h:389
Definition: expr.h:422
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:530
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Defines the Functor data structures.
void PostOrderVisit(const Expr &node, std::function< void(const Expr &)> fvisit)
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:59
#define RELAX_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:62
Primitive operators(builtin intrinsics).
Functors and visitors for struct info.
TIR Function.