tvm
expr_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RELAX_EXPR_FUNCTOR_H_
26 #define TVM_RELAX_EXPR_FUNCTOR_H_
27 
28 #include <tvm/node/functor.h>
30 #include <tvm/relax/expr.h>
31 #include <tvm/relax/struct_info.h>
33 #include <tvm/tirx/function.h>
34 
35 #include <unordered_map>
36 #include <utility>
37 namespace tvm {
38 namespace relax {
39 
// Primary template for the Relax expression functor: a dynamic functor that
// dispatches on the first Expr argument. FType is a function signature; only
// the partial specialization for R(const Expr&, Args...) is defined.
template <typename FType>
class ExprFunctor;

// Body shared by the overridable VisitExpr_ hooks of ExprFunctor: hand the
// node and any extra arguments over to VisitExprDefault_. The expansion site
// must have `op`, the parameter pack `args`, and its type pack `Args` in scope.
#define EXPR_FUNCTOR_DEFAULT \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }
59 
// Registers a vtable entry for node type OP: the entry downcasts the ObjectRef
// to `const OP*` and calls the matching VisitExpr_ overload on `self`.
// Requires `vtable`, `TSelf`, and the type pack `Args` at the expansion site.
#define RELAX_EXPR_FUNCTOR_DISPATCH(OP)                                                  \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* self, Args... args) { \
    const OP* typed_node = static_cast<const OP*>(node.get());                            \
    return self->VisitExpr_(typed_node, std::forward<Args>(args)...);                     \
  });
64 
// Function body for a visitor hook that a callback may override (presumably a
// Python override, per the PY_ prefix): invoke PY_CALLBACK on NODE when it is
// set, otherwise evaluate FALLBACK as a statement.
#define PY_EXPR_VISITOR_DEFAULT(NODE, PY_CALLBACK, FALLBACK) \
  {                                                          \
    if (PY_CALLBACK != nullptr) {                            \
      PY_CALLBACK(NODE);                                     \
    } else {                                                 \
      FALLBACK;                                              \
    }                                                        \
  }
72 
// Function body for a mutator hook that a callback may override (presumably a
// Python override, per the PY_ prefix). When PY_CALLBACK is set, its result on
// NODE is converted with .cast<RET_TYPE>() and returned; otherwise FALLBACK is
// returned.
#define PY_EXPR_MUTATOR_DEFAULT(NODE, PY_CALLBACK, FALLBACK, RET_TYPE) \
  {                                                                    \
    if (PY_CALLBACK != nullptr) {                                      \
      RET_TYPE converted = PY_CALLBACK(NODE).cast<RET_TYPE>();         \
      return converted;                                                \
    }                                                                  \
    return FALLBACK;                                                   \
  }
82 
// Registers a vtable entry for node type OP that calls the callback member
// PY_MEMBER on `self` when it is set, and otherwise calls the C++ VisitExpr_
// overload for `const OP*`. Requires `vtable` and `TSelf` at the expansion site.
#define PY_EXPR_VISITOR_DISPATCH(OP, PY_MEMBER)                             \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* self) { \
    if (self->PY_MEMBER != nullptr) {                                        \
      self->PY_MEMBER(node);                                                 \
      return;                                                                \
    }                                                                        \
    self->VisitExpr_(static_cast<const OP*>(node.get()));                    \
  });
90 
// Registers a vtable entry for node type OP. When the callback member
// PY_MEMBER on `self` is set, its result is converted with .cast<Expr>() and
// returned; otherwise the C++ VisitExpr_ overload for `const OP*` runs.
// Requires `vtable` and `TSelf` at the expansion site.
#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_MEMBER)                             \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* self) { \
    if (self->PY_MEMBER != nullptr) {                                        \
      Expr rewritten = self->PY_MEMBER(node).cast<Expr>();                   \
      return rewritten;                                                      \
    }                                                                        \
    return self->VisitExpr_(static_cast<const OP*>(node.get()));             \
  });
100 
// Registers an entry in `post_order_vtable` for node type OP that downcasts the
// ObjectRef to `const OP*` and forwards it to self->VisitExprPostOrder_.
// Requires `post_order_vtable` and `TSelf` at the expansion site.
#define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP)                           \
  post_order_vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* self) { \
    const OP* typed_node = static_cast<const OP*>(node.get());                         \
    return self->VisitExprPostOrder_(typed_node);                                      \
  });
105 
106 template <typename R, typename... Args>
107 class ExprFunctor<R(const Expr& n, Args...)> {  // dispatches VisitExpr_ on the runtime node type of the first argument
108  private:
109  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
110  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
111 
112  public:
114  using result_type = R;
116  virtual ~ExprFunctor() {}
123  R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); }
130  virtual R VisitExpr(const Expr& n, Args... args) {
131  TVM_FFI_ICHECK(n.defined())  // reject undefined (null) nodes before dispatching
132  << "Found null pointer node while traversing AST. The previous pass may "
133  "have generated invalid data.";
134  static FType vtable = InitVTable();  // built once, on first use (function-local static)
135  return vtable(n, this, std::forward<Args>(args)...);
136  }
137  // Functions that can be overridden by subclasses.
138  // NOTE: cross dialect calls are invoked through global var
139  // We do not expect inline PrimFunc to appear in relax IR.
140  virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
141  virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
142  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
143  virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
144  virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
145  virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
146  virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
147  virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
148  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
149  virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
150  virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
151  virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
152  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
153  virtual R VisitExpr_(const PrimValueNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
154  virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
155  virtual R VisitExpr_(const DataTypeImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
156  virtual R VisitExprDefault_(const Object* op, Args...) {  // fallback reached by every VisitExpr_ that is not overridden
157  TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey();
158  throw;  // presumably unreachable since TVM_FFI_THROW throws; avoids a missing-return warning (a bare throw; with no active exception would call std::terminate)
159  }
160 
161  private:
162  // initialize the vtable.
163  static FType InitVTable() {
164  FType vtable;
165  // Set dispatch. NOTE(review): listing lines 166-181 are missing here; presumably one RELAX_EXPR_FUNCTOR_DISPATCH(OP) per VisitExpr_ overload above -- confirm against the original header.
182  vtable.Finalize();
183  return vtable;
184  }
185 };
186 
191 class ExprVisitor : public ExprFunctor<void(const Expr&)> {  // simple visitor that recursively visits the content of an Expr
192  public:
197  void VisitExpr(const Expr& expr) override;
198  // specific leaf level visitor functions
199  void VisitExpr_(const ConstantNode* op) override;
200  void VisitExpr_(const TupleNode* op) override;
201  void VisitExpr_(const VarNode* op) override;
202  void VisitExpr_(const DataflowVarNode* op) override;
203  void VisitExpr_(const ShapeExprNode* op) override;
204  void VisitExpr_(const ExternFuncNode* op) override;
205  void VisitExpr_(const GlobalVarNode* op) override;
206  void VisitExpr_(const FunctionNode* op) override;
207  void VisitExpr_(const CallNode* op) override;
208  void VisitExpr_(const SeqExprNode* op) override;
209  void VisitExpr_(const IfNode* op) override;
210  void VisitExpr_(const OpNode* op) override;
211  void VisitExpr_(const TupleGetItemNode* op) override;
212  void VisitExpr_(const PrimValueNode* op) override;
213  void VisitExpr_(const StringImmNode* op) override;
214  void VisitExpr_(const DataTypeImmNode* op) override;
215 
220  virtual void VisitBinding(const Binding& binding);  // generic dispatcher for bindings
221  // specific leaf level visitor functions
222  virtual void VisitBinding_(const VarBindingNode* binding);
223  virtual void VisitBinding_(const MatchCastNode* binding);
224  // second level dispatching based on binding value type.
225  // these dispatching functions get called from first-level dispatch on VarBinding
226  virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
227  virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val);
228  virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val);
229  virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val);
230  virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val);
231  virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val);
232  virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val);
233  virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val);
234  virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val);
235  virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val);
236  virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val);
237  virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val);
238  virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val);
239  virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val);
240  virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val);
241  virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val);
246  virtual void VisitBindingBlock(const BindingBlock& block);  // generic dispatcher for binding blocks
247  // specific leaf level visitor functions
248  virtual void VisitBindingBlock_(const BindingBlockNode* block);
249  virtual void VisitBindingBlock_(const DataflowBlockNode* block);
250 
256  virtual void VisitVarDef(const Var& var);  // generic dispatcher for visiting a var definition site
257 
272  virtual void VisitExprDepStructInfoField(const StructInfo& struct_info);  // visit struct_info that may recursively contain Expr/PrimExpr
273 
274  // specific leaf level visitor functions
275  virtual void VisitVarDef_(const VarNode* var);
276  virtual void VisitVarDef_(const DataflowVarNode* var);
277 
278  virtual void VisitSpan(const Span& span);
279  virtual void VisitPrimExpr(const PrimExpr& expr);
280 
281  private:
282  using TSelf = ExprVisitor;
283  using VisitBindingVTable =  // second-level dispatch table for VarBinding, keyed on the bound value's node type
284  tvm::NodeFunctor<void(const ObjectRef& n, ExprVisitor* self, const VarBindingNode* binding)>;
285  // initialize the vtable.
286  static VisitBindingVTable InitVisitBindingVTable();
296  class DefaultStructInfoFieldVisitor : public StructInfoVisitor {
297  public:
298  explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent);
299 
300  // Override defaults in struct info visitor.
301  void VisitStructInfoExprField(const Expr& expr) final;
302  void VisitStructInfoExprField(const PrimExpr& expr) final;
303  void VisitStructInfo_(const FuncStructInfoNode* op) final;
304 
305  private:
306  ExprVisitor* parent_;
307  };
308  // This visitor is not visible to child classes and is only
309  // used to support the default visiting behavior.
310  DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};
311 };
312 
313 void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);  // NOTE(review): presumably calls fvisit on each sub-expression of node in post-order (children first) -- confirm in the .cc; std::function needs <functional>, which is not among the includes visible in this listing
314 
323 class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {  // mutator that works in unnormalized form
324  public:
325  Expr VisitExpr(const Expr& expr) override;
326  Expr VisitExpr_(const ConstantNode* op) override;
327  Expr VisitExpr_(const TupleNode* op) override;
328  Expr VisitExpr_(const VarNode* op) override;
329  Expr VisitExpr_(const DataflowVarNode* op) override;
330  Expr VisitExpr_(const ShapeExprNode* op) override;
331  Expr VisitExpr_(const ExternFuncNode* op) override;
332  Expr VisitExpr_(const GlobalVarNode* op) override;
333  Expr VisitExpr_(const FunctionNode* op) override;
334  Expr VisitExpr_(const CallNode* op) override;
335  Expr VisitExpr_(const SeqExprNode* op) override;
336  Expr VisitExpr_(const IfNode* op) override;
337  Expr VisitExpr_(const OpNode* op) override;
338  Expr VisitExpr_(const TupleGetItemNode* op) override;
339  Expr VisitExpr_(const PrimValueNode* op) override;
340  Expr VisitExpr_(const StringImmNode* op) override;
341  Expr VisitExpr_(const DataTypeImmNode* op) override;
342 
349 // NOTE(review): listing lines 343-348 are missing; presumably they declare the virtual VisitBindingBlock that ExprMutator overrides.
355  virtual PrimExpr VisitPrimExpr(const PrimExpr& expr);  // used to visit the PrimExpr inside of expressions
356 
372  virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info);  // visit struct_info that may recursively contain Expr/PrimExpr
373 
374  protected:
383  bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) {  // true when struct_info is not a StructInfo, or visiting it yields the same object
384  if (const StructInfoNode* sinfo = struct_info.as<StructInfoNode>()) {
385  return this->VisitExprDepStructInfoField(ffi::GetRef<StructInfo>(sinfo)).same_as(struct_info);
386  } else {
387  return true;
388  }
389  }
390 
391  private:
399  class DefaultStructInfoFieldMutator : public StructInfoMutator {
400  public:
401  explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent);
402 
403  // Override defaults in struct info mutator.
404  Expr VisitStructInfoExprField(const Expr& expr) final;
405  PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final;
406  StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final;
407 
408  private:
409  ExprMutatorBase* parent_;
410  };
411  // This mutator is not visible to child classes and is only
412  // used to support the default visiting behavior.
413  DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this};
414 };
415 
423 class ExprMutator : public ExprMutatorBase {  // mutator that works in normal form
424  public:
426 // NOTE(review): listing line 425 is missing. The VisitExpr_ overloads declared below hide ExprMutatorBase's other VisitExpr_ overloads unless `using ExprMutatorBase::VisitExpr_;` is present, and VisitExprPostOrder_ relies on finding them -- confirm against the original header.
427  ExprMutator(ffi::Optional<IRModule> mod = std::nullopt) { builder_ = BlockBuilder::Create(mod); }  // NOTE(review): not explicit, so an ffi::Optional<IRModule> converts implicitly to ExprMutator
428  Expr VisitExpr(const Expr& expr) override;
429  Expr VisitExpr_(const VarNode* op) override;
430  Expr VisitExpr_(const DataflowVarNode* op) override;
431  Expr VisitExpr_(const FunctionNode* op) override;
432  Expr VisitExpr_(const SeqExprNode* op) override;
433  Expr VisitExpr_(const IfNode* op) override;
434 
439  virtual void VisitBinding(const Binding& binding);  // generic dispatcher for bindings
440  // specific leaf level visitor functions
441  virtual void VisitBinding_(const VarBindingNode* binding);
442  virtual void VisitBinding_(const MatchCastNode* binding);
443  // second level dispatching based on binding value type.
444  // these dispatching functions get called from first-level dispatch on VarBinding
445  virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
446  virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val);
447  virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val);
448  virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val);
449  virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val);
450  virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val);
451  virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val);
452  virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val);
453  virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val);
454  virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val);
455  virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val);
456  virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val);
457  virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val);
458  virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val);
459  virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val);
460  virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val);
466  virtual BindingBlock VisitBindingBlock(const BindingBlock& block) override; // NOLINT(*)
467  // specific leaf level visitor functions (NOTE(review): listing lines 468-469, presumably the VisitBindingBlock_ overloads, are missing)
470 
477  virtual Var VisitVarDef(const Var& var);  // generic dispatcher for rewriting a var definition site
478  // specific leaf level visitor functions
479  virtual Var VisitVarDef_(const VarNode* var);
481 // NOTE(review): listing line 480 (presumably VisitVarDef_(const DataflowVarNode*)) is missing.
482  protected:
494  void ReEmitBinding(const VarBindingNode* binding, Expr new_value);  // try to re-emit the binding, bound to new_value
495 
508  Expr VisitWithNewScope(const Expr& body_expr,  // rewrite the expr in a new scope; used for a Function's body
509  ffi::Optional<ffi::Array<Var>> params = std::nullopt);
510 
525  Expr VisitWithInnerScope(const Expr& body_expr);  // rewrite the expr in a new scope; used for the branches of If
526 
533  ffi::Optional<Expr> LookupBinding(const Var& var);  // look up the value bound to a variable
534 
541  template <typename T>
542  Expr VisitExprPostOrder_(const T* op) {  // post-order rewrite of a node, then normalize
543  return builder_->Normalize(ExprMutator::VisitExpr_(op));  // qualified name suppresses virtual dispatch on VisitExpr_
544  }
545 
554 
557 // NOTE(review): listing lines 546-553 and 555-556 are missing; builder_ (used by the constructor and VisitExprPostOrder_) is declared on line 556 per the generated docs.
559  std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;  // remaps a var (keyed by its Id) to a new var at use sites
560 
561  private:
562  using TSelf = ExprMutator;
563  using VisitBindingVTable =
564  tvm::NodeFunctor<void(const ObjectRef& n, ExprMutator* self, const VarBindingNode* binding)>;
565  // initialize the vtable.
566  static VisitBindingVTable InitVisitBindingVTable();
567 };
568 
569 } // namespace relax
570 } // namespace tvm
571 #endif // TVM_RELAX_EXPR_FUNCTOR_H_
The utility for constructing Relax binding blocks.
Global variable that lives in the top-level module.
Definition: expr.h:455
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:65
Primitive Op (builtin intrinsics)
Definition: op.h:59
Reference to PrimExprNode.
Definition: expr.h:126
Managed reference to RelaxExprNode.
Definition: expr.h:441
Definition: source_map.h:111
Definition: expr.h:660
Definition: expr.h:677
Definition: expr.h:585
Definition: block_builder.h:264
static BlockBuilder Create(ffi::Optional< IRModule > ctx_mod)
Create a BlockBuilder.
Call corresponds to callable invocation. Corresponds to operation in computational graph terminology.
Definition: expr.h:141
Constant tensor.
Definition: expr.h:425
Represent a data type constant.
Definition: expr.h:537
Definition: expr.h:685
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:396
virtual R VisitExpr_(const TupleNode *op, Args... args)
Definition: expr_functor.h:141
virtual R VisitExpr_(const DataTypeImmNode *op, Args... args)
Definition: expr_functor.h:155
virtual R VisitExpr(const Expr &n, Args... args)
The functor call.
Definition: expr_functor.h:130
virtual R VisitExpr_(const ShapeExprNode *op, Args... args)
Definition: expr_functor.h:144
virtual R VisitExpr_(const SeqExprNode *op, Args... args)
Definition: expr_functor.h:149
virtual R VisitExpr_(const FunctionNode *op, Args... args)
Definition: expr_functor.h:147
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:156
virtual R VisitExpr_(const OpNode *op, Args... args)
Definition: expr_functor.h:151
R result_type
the result type of this functor
Definition: expr_functor.h:114
virtual R VisitExpr_(const IfNode *op, Args... args)
Definition: expr_functor.h:150
virtual R VisitExpr_(const ExternFuncNode *op, Args... args)
Definition: expr_functor.h:145
virtual R VisitExpr_(const GlobalVarNode *op, Args... args)
Definition: expr_functor.h:146
virtual R VisitExpr_(const DataflowVarNode *op, Args... args)
Definition: expr_functor.h:143
virtual R VisitExpr_(const StringImmNode *op, Args... args)
Definition: expr_functor.h:154
R operator()(const Expr &n, Args... args)
Same as call.
Definition: expr_functor.h:123
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:142
virtual R VisitExpr_(const PrimValueNode *op, Args... args)
Definition: expr_functor.h:153
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:148
virtual R VisitExpr_(const TupleGetItemNode *op, Args... args)
Definition: expr_functor.h:152
virtual R VisitExpr_(const ConstantNode *op, Args... args)
Definition: expr_functor.h:140
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:116
A dynamic functor that dispatches on the first Expr argument. You can use this as a more powerfu...
Definition: expr_functor.h:52
A mutator works in unnormalized form.
Definition: expr_functor.h:323
Expr VisitExpr_(const TupleNode *op) override
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const VarNode *op) override
Expr VisitExpr_(const GlobalVarNode *op) override
Expr VisitExpr_(const PrimValueNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block)
Mutate BindingBlock.
Expr VisitExpr_(const ShapeExprNode *op) override
virtual PrimExpr VisitPrimExpr(const PrimExpr &expr)
Used to visit the PrimExpr inside of expressions.
Expr VisitExpr_(const ExternFuncNode *op) override
Expr VisitExpr_(const DataTypeImmNode *op) override
bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef &struct_info)
Check whether VisitExprDepStructInfoField changes struct_info.
Definition: expr_functor.h:383
Expr VisitExpr_(const CallNode *op) override
Expr VisitExpr_(const OpNode *op) override
virtual StructInfo VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info that may recursively contain Expr/PrimExpr.
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const StringImmNode *op) override
Expr VisitExpr_(const TupleGetItemNode *op) override
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const FunctionNode *op) override
Expr VisitExpr_(const DataflowVarNode *op) override
A mutator works in normal form.
Definition: expr_functor.h:423
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
Expr VisitExpr_(const VarNode *op) override
virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode *block)
void ReEmitBinding(const VarBindingNode *binding, Expr new_value)
Try to re-emit the binding and bind it to new_value.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const FunctionNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
Var WithStructInfo(Var var, StructInfo struct_info)
Create a new var with specified struct_info if the original var's shape or type does not match with t...
Expr VisitWithNewScope(const Expr &body_expr, ffi::Optional< ffi::Array< Var >> params=std::nullopt)
Rewrite the expr with a new scope, used in a Function's body.
virtual BindingBlock VisitBindingBlock_(const BindingBlockNode *block)
ffi::Optional< Expr > LookupBinding(const Var &var)
Look up the value bound to a variable.
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block) override
Generic dispatcher for binding blocks.
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
virtual Var VisitVarDef(const Var &var)
Generic dispatcher for rewriting the var definition site.
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
virtual Var VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
Expr VisitExprPostOrder_(const T *op)
Post-order rewrite a node and normalize.
Definition: expr_functor.h:542
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
Expr VisitExpr_(const DataflowVarNode *op) override
Expr VisitWithInnerScope(const Expr &body_expr)
Rewrite the expr with a new scope, used in the branches of If.
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
BlockBuilder builder_
Internal block builder to emit bindings during rewriting.
Definition: expr_functor.h:556
virtual Var VisitVarDef_(const VarNode *var)
ExprMutator(ffi::Optional< IRModule > mod=std::nullopt)
Definition: expr_functor.h:427
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
std::unordered_map< Id, Var, ObjectPtrHash, ObjectPtrEqual > var_remap_
Remap a var to a new var in use-site.
Definition: expr_functor.h:559
virtual void VisitBinding_(const VarBindingNode *binding)
A simple visitor wrapper around ExprFunctor. Recursively visit the content.
Definition: expr_functor.h:191
void VisitExpr_(const GlobalVarNode *op) override
void VisitExpr_(const StringImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
void VisitExpr_(const SeqExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const IfNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
void VisitExpr_(const TupleGetItemNode *op) override
void VisitExpr(const Expr &expr) override
Generic dispatcher for Expr.
virtual void VisitBindingBlock_(const DataflowBlockNode *block)
virtual void VisitSpan(const Span &span)
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
void VisitExpr_(const OpNode *op) override
virtual void VisitVarDef(const Var &var)
Generic dispatcher for visiting the var definition site.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
void VisitExpr_(const PrimValueNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
void VisitExpr_(const DataTypeImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
void VisitExpr_(const FunctionNode *op) override
void VisitExpr_(const VarNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
void VisitExpr_(const ExternFuncNode *op) override
virtual void VisitBindingBlock(const BindingBlock &block)
Generic dispatcher for binding blocks.
void VisitExpr_(const DataflowVarNode *op) override
virtual void VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
virtual void VisitPrimExpr(const PrimExpr &expr)
virtual void VisitVarDef_(const VarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
void VisitExpr_(const ShapeExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
void VisitExpr_(const TupleNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding)
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
void VisitExpr_(const ConstantNode *op) override
virtual void VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info that may recursively contain Expr/PrimExpr.
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
virtual void VisitBindingBlock_(const BindingBlockNode *block)
The extern function, which can represent packed function.
Definition: expr.h:903
Structure information about function.
Definition: struct_info.h:263
A Relax function.
Definition: expr.h:808
Condition expression.
Definition: expr.h:751
Runtime-match the value to the struct info.
Definition: expr.h:605
PrimValue.
Definition: expr.h:465
A sequence of blocks followed by an expression.
Definition: expr.h:706
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:324
Represent a string literal constant.
Definition: expr.h:505
StructInfoMutator that mutates struct info.
Definition: struct_info_functor.h:142
Base type of all structure information.
Definition: expr.h:108
A struct info visitor.
Definition: struct_info_functor.h:123
Managed reference to StructInfoNode.
Definition: expr.h:132
Get index-th field out of a tuple.
Definition: expr.h:263
Tuple container.
Definition: expr.h:210
Definition: expr.h:633
The variable class for all Relax bindings.
Definition: expr.h:344
Definition: expr.h:380
Defines the Functor data structures.
void PostOrderVisit(const Expr &node, std::function< void(const Expr &)> fvisit)
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:55
#define RELAX_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:60
Functors and visitors for struct info.
TIR Function.