tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
expr_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RELAX_EXPR_FUNCTOR_H_
26 #define TVM_RELAX_EXPR_FUNCTOR_H_
27 
28 #include <tvm/node/functor.h>
30 #include <tvm/relax/expr.h>
31 #include <tvm/relax/struct_info.h>
33 #include <tvm/tir/function.h>
34 
35 #include <unordered_map>
36 #include <utility>
37 namespace tvm {
38 namespace relax {
39 
51 template <typename FType>
53 
54 // Default body for the VisitExpr_ overloads that subclasses are expected to override: forwards op and args to VisitExprDefault_.
55 #define EXPR_FUNCTOR_DEFAULT \
56  { return VisitExprDefault_(op, std::forward<Args>(args)...); }
57 
58 #define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \
59  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
60  return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
61  });
62 
63 #define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \
64  { \
65  if (PY_FUNC != nullptr) \
66  PY_FUNC(N); \
67  else \
68  DEFAULT_FUNC; \
69  }
70 
71 #define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \
72  { \
73  if (PY_FUNC != nullptr) { \
74  RET_TYPE ret = PY_FUNC(N); \
75  return ret; \
76  } else { \
77  return DEFAULT_FUNC; \
78  } \
79  }
80 
81 #define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \
82  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
83  if (self->PY_FUNC != nullptr) \
84  self->PY_FUNC(n); \
85  else \
86  self->VisitExpr_(static_cast<const OP*>(n.get())); \
87  });
88 
89 #define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \
90  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
91  if (self->PY_FUNC != nullptr) { \
92  Expr expr = self->PY_FUNC(n); \
93  return expr; \
94  } else { \
95  return self->VisitExpr_(static_cast<const OP*>(n.get())); \
96  } \
97  });
98 
99 #define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP) \
100  post_order_vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
101  return self->VisitExprPostOrder_(static_cast<const OP*>(n.get())); \
102  });
103 
104 template <typename R, typename... Args>
105 class ExprFunctor<R(const Expr& n, Args...)> {
106  private:
107  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
108  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
109 
110  public:
112  using result_type = R; // the result type of this functor
114  virtual ~ExprFunctor() {}
121  R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); } // same as calling VisitExpr
128  virtual R VisitExpr(const Expr& n, Args... args) {
129  ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
130  "have generated invalid data.";
131  static FType vtable = InitVTable(); // function-local static: built on the first call, then reused
132  return vtable(n, this, std::forward<Args>(args)...);
133  }
134  // Functions that can be overridden by subclasses; any overload not overridden falls back to VisitExprDefault_.
135  // NOTE: cross dialect calls are invoked through global var
136  // We do not expect inline PrimFunc to appear in relax IR.
137  virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
138  virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
139  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
140  virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
141  virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
142  virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
143  virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
144  virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
145  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
146  virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
147  virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
148  virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
149  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
150  virtual R VisitExpr_(const PrimValueNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
151  virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
152  virtual R VisitExpr_(const DataTypeImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
153  virtual R VisitExprDefault_(const Object* op, Args...) {
154  LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
155  throw; // NOTE(review): only reached if LOG(FATAL) returns; a bare throw with no active exception calls std::terminate
156  }
157 
158  private:
159  // Build the type-indexed dispatch table; VisitExpr calls this once and caches the result in a function-local static.
160  static FType InitVTable() {
161  FType vtable;
162  // Set dispatch (presumably one RELAX_EXPR_FUNCTOR_DISPATCH entry per node type; those lines are elided from this listing)
179  vtable.Finalize();
180  return vtable;
181  }
182 };
183 
188 class ExprVisitor : public ExprFunctor<void(const Expr&)> {
189  public:
194  void VisitExpr(const Expr& expr) override; // generic dispatcher for Expr
195  // specific leaf level visitor functions
196  void VisitExpr_(const ConstantNode* op) override;
197  void VisitExpr_(const TupleNode* op) override;
198  void VisitExpr_(const VarNode* op) override;
199  void VisitExpr_(const DataflowVarNode* op) override;
200  void VisitExpr_(const ShapeExprNode* op) override;
201  void VisitExpr_(const ExternFuncNode* op) override;
202  void VisitExpr_(const GlobalVarNode* op) override;
203  void VisitExpr_(const FunctionNode* op) override;
204  void VisitExpr_(const CallNode* op) override;
205  void VisitExpr_(const SeqExprNode* op) override;
206  void VisitExpr_(const IfNode* op) override;
207  void VisitExpr_(const OpNode* op) override;
208  void VisitExpr_(const TupleGetItemNode* op) override;
209  void VisitExpr_(const PrimValueNode* op) override;
210  void VisitExpr_(const StringImmNode* op) override;
211  void VisitExpr_(const DataTypeImmNode* op) override;
212 
217  virtual void VisitBinding(const Binding& binding); // generic dispatcher for bindings
218  // specific leaf level visitor functions
219  virtual void VisitBinding_(const VarBindingNode* binding);
220  virtual void VisitBinding_(const MatchCastNode* binding);
221  // second level dispatching based on binding value type.
222  // these dispatching functions get called from first-level dispatch on VarBinding
223  virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
224  virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val);
225  virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val);
226  virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val);
227  virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val);
228  virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val);
229  virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val);
230  virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val);
231  virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val);
232  virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val);
233  virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val);
234  virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val);
235  virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val);
236  virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val);
237  virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val);
238  virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val);
243  virtual void VisitBindingBlock(const BindingBlock& block); // generic dispatcher for binding blocks
244  // specific leaf level visitor functions
245  virtual void VisitBindingBlock_(const BindingBlockNode* block);
246  virtual void VisitBindingBlock_(const DataflowBlockNode* block);
247 
253  virtual void VisitVarDef(const Var& var); // generic dispatcher for visiting the var definition site
254 
269  virtual void VisitExprDepStructInfoField(const StructInfo& struct_info); // visit struct_info that may recursively contain Expr/PrimExpr
270 
271  // specific leaf level visitor functions
272  virtual void VisitVarDef_(const VarNode* var);
273  virtual void VisitVarDef_(const DataflowVarNode* var);
274 
275  virtual void VisitSpan(const Span& span);
276  virtual void VisitPrimExpr(const PrimExpr& expr);
277 
278  private:
279  using TSelf = ExprVisitor;
280  using VisitBindingVTable =
281  tvm::NodeFunctor<void(const ObjectRef& n, ExprVisitor* self, const VarBindingNode* binding)>;
282  // Build the second-level dispatch table that routes a VarBinding to the VisitBinding_ overload for its value's node type.
283  static VisitBindingVTable InitVisitBindingVTable();
293  class DefaultStructInfoFieldVisitor : public StructInfoVisitor {
294  public:
295  explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent);
296 
297  // Override defaults in struct info visitor.
298  void VisitStructInfoExprField(const Expr& expr) final;
299  void VisitStructInfoExprField(const PrimExpr& expr) final;
300  void VisitStructInfo_(const FuncStructInfoNode* op) final;
301 
302  private:
303  ExprVisitor* parent_;
304  };
305  // This visitor is not visible to child classes and is only
306  // used to support the default visiting behavior.
307  DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};
308 };
309 
310 void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit); // NOTE(review): defined out of line; presumably applies fvisit to each sub-expression of node in post-order — confirm against the implementation
311 
320 class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
321  public:
322  Expr VisitExpr(const Expr& expr) override;
323  Expr VisitExpr_(const ConstantNode* op) override;
324  Expr VisitExpr_(const TupleNode* op) override;
325  Expr VisitExpr_(const VarNode* op) override;
326  Expr VisitExpr_(const DataflowVarNode* op) override;
327  Expr VisitExpr_(const ShapeExprNode* op) override;
328  Expr VisitExpr_(const ExternFuncNode* op) override;
329  Expr VisitExpr_(const GlobalVarNode* op) override;
330  Expr VisitExpr_(const FunctionNode* op) override;
331  Expr VisitExpr_(const CallNode* op) override;
332  Expr VisitExpr_(const SeqExprNode* op) override;
333  Expr VisitExpr_(const IfNode* op) override;
334  Expr VisitExpr_(const OpNode* op) override;
335  Expr VisitExpr_(const TupleGetItemNode* op) override;
336  Expr VisitExpr_(const PrimValueNode* op) override;
337  Expr VisitExpr_(const StringImmNode* op) override;
338  Expr VisitExpr_(const DataTypeImmNode* op) override;
339 
346 
352  virtual PrimExpr VisitPrimExpr(const PrimExpr& expr); // used to visit the PrimExpr inside of expressions
353 
369  virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info); // visit struct_info that may recursively contain Expr/PrimExpr
370 
371  protected:
381  if (const StructInfoNode* sinfo = struct_info.as<StructInfoNode>()) { // body of VisitAndCheckStructInfoFieldUnchanged (header elided in this listing)
382  return this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)).same_as(struct_info);
383  } else {
384  return true; // not a StructInfoNode: treated as unchanged
385  }
386  }
387 
388  private:
396  class DefaultStructInfoFieldMutator : public StructInfoMutator {
397  public:
398  explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent);
399 
400  // Override defaults in struct info mutator.
401  Expr VisitStructInfoExprField(const Expr& expr) final;
402  PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final;
403  StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final;
404 
405  private:
406  ExprMutatorBase* parent_;
407  };
408  // This mutator is not visible to child classes and is only
409  // used to support the default struct info mutation behavior.
410  DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this};
411 };
412 
420 class ExprMutator : public ExprMutatorBase {
421  public:
423 
425  Expr VisitExpr(const Expr& expr) override;
426  Expr VisitExpr_(const VarNode* op) override;
427  Expr VisitExpr_(const DataflowVarNode* op) override;
428  Expr VisitExpr_(const FunctionNode* op) override;
429  Expr VisitExpr_(const SeqExprNode* op) override;
430  Expr VisitExpr_(const IfNode* op) override;
431 
436  virtual void VisitBinding(const Binding& binding); // generic dispatcher for bindings
437  // specific leaf level visitor functions
438  virtual void VisitBinding_(const VarBindingNode* binding);
439  virtual void VisitBinding_(const MatchCastNode* binding);
440  // second level dispatching based on binding value type.
441  // these dispatching functions get called from first-level dispatch on VarBinding
442  virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
443  virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val);
444  virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val);
445  virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val);
446  virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val);
447  virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val);
448  virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val);
449  virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val);
450  virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val);
451  virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val);
452  virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val);
453  virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val);
454  virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val);
455  virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val);
456  virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val);
457  virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val);
463  virtual BindingBlock VisitBindingBlock(const BindingBlock& block) override; // NOLINT(*)
464  // specific leaf level visitor functions
467 
474  virtual Var VisitVarDef(const Var& var); // generic dispatcher for rewriting the var definition site
475  // specific leaf level visitor functions
476  virtual Var VisitVarDef_(const VarNode* var);
478 
479  protected:
491  void ReEmitBinding(const VarBindingNode* binding, Expr new_value); // try to re-emit the binding and bind it to new_value
492 
505  Expr VisitWithNewScope(const Expr& body_expr, Optional<Array<Var>> params = NullOpt); // rewrite the expr with a new scope, used in a Function's body
506 
521  Expr VisitWithInnerScope(const Expr& body_expr); // rewrite the expr with a new scope, used in the branches of If
522 
530 
537  template <typename T>
538  Expr VisitExprPostOrder_(const T* op) { // post-order rewrite: apply ExprMutator::VisitExpr_ to op, then normalize via builder_
539  return builder_->Normalize(ExprMutator::VisitExpr_(op));
540  }
541 
550 
553 
555  std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_; // use-site remapping: Id of an original var -> replacement Var
556 
557  private:
558  using TSelf = ExprMutator;
559  using VisitBindingVTable =
560  tvm::NodeFunctor<void(const ObjectRef& n, ExprMutator* self, const VarBindingNode* binding)>;
561  // Build the second-level dispatch table that routes a VarBinding to the VisitBinding_ overload for its value's node type.
562  static VisitBindingVTable InitVisitBindingVTable();
563 };
564 
565 } // namespace relax
566 } // namespace tvm
567 #endif // TVM_RELAX_EXPR_FUNCTOR_H_
The utility for constructing Relax binding blocks.
Global variable that lives in the top-level module.
Definition: expr.h:419
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:65
Primitive Op(builtin intrinsics)
Definition: op.h:58
Reference to PrimExprNode.
Definition: expr.h:115
Managed reference to RelaxExprNode.
Definition: expr.h:405
Definition: source_map.h:120
Definition: expr.h:763
Definition: expr.h:785
Definition: expr.h:682
Definition: block_builder.h:265
static BlockBuilder Create(Optional< IRModule > ctx_mod)
Create a BlockBuilder.
Call corresponds to callable invocation. Corresponds to operation in computational graph terminology.
Definition: expr.h:139
Constant tensor.
Definition: expr.h:481
Represent a data type constant.
Definition: expr.h:629
Definition: expr.h:793
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:438
virtual R VisitExpr_(const TupleNode *op, Args... args)
Definition: expr_functor.h:138
virtual R VisitExpr_(const DataTypeImmNode *op, Args... args)
Definition: expr_functor.h:152
virtual R VisitExpr(const Expr &n, Args... args)
The functor call.
Definition: expr_functor.h:128
virtual R VisitExpr_(const ShapeExprNode *op, Args... args)
Definition: expr_functor.h:141
virtual R VisitExpr_(const SeqExprNode *op, Args... args)
Definition: expr_functor.h:146
virtual R VisitExpr_(const FunctionNode *op, Args... args)
Definition: expr_functor.h:144
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:153
virtual R VisitExpr_(const OpNode *op, Args... args)
Definition: expr_functor.h:148
R result_type
the result type of this functor
Definition: expr_functor.h:112
virtual R VisitExpr_(const IfNode *op, Args... args)
Definition: expr_functor.h:147
virtual R VisitExpr_(const ExternFuncNode *op, Args... args)
Definition: expr_functor.h:142
virtual R VisitExpr_(const GlobalVarNode *op, Args... args)
Definition: expr_functor.h:143
virtual R VisitExpr_(const DataflowVarNode *op, Args... args)
Definition: expr_functor.h:140
virtual R VisitExpr_(const StringImmNode *op, Args... args)
Definition: expr_functor.h:151
R operator()(const Expr &n, Args... args)
Same as call.
Definition: expr_functor.h:121
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:139
virtual R VisitExpr_(const PrimValueNode *op, Args... args)
Definition: expr_functor.h:150
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:145
virtual R VisitExpr_(const TupleGetItemNode *op, Args... args)
Definition: expr_functor.h:149
virtual R VisitExpr_(const ConstantNode *op, Args... args)
Definition: expr_functor.h:137
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:114
A dynamic functor that dispatches on the first Expr argument. You can use this as a more powerful...
Definition: expr_functor.h:52
A mutator works in unnormalized form.
Definition: expr_functor.h:320
Expr VisitExpr_(const TupleNode *op) override
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const VarNode *op) override
Expr VisitExpr_(const GlobalVarNode *op) override
Expr VisitExpr_(const PrimValueNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block)
Mutate BindingBlock.
Expr VisitExpr_(const ShapeExprNode *op) override
virtual PrimExpr VisitPrimExpr(const PrimExpr &expr)
Used to visit the PrimExpr inside of expressions.
Expr VisitExpr_(const ExternFuncNode *op) override
Expr VisitExpr_(const DataTypeImmNode *op) override
bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef &struct_info)
Check whether VisitExprDepStructInfoField changes struct_info.
Definition: expr_functor.h:380
Expr VisitExpr_(const CallNode *op) override
Expr VisitExpr_(const OpNode *op) override
virtual StructInfo VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info that may recursively contain Expr/PrimExpr.
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const StringImmNode *op) override
Expr VisitExpr_(const TupleGetItemNode *op) override
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const FunctionNode *op) override
Expr VisitExpr_(const DataflowVarNode *op) override
A mutator works in normal form.
Definition: expr_functor.h:420
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
Expr VisitExpr_(const VarNode *op) override
virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode *block)
void ReEmitBinding(const VarBindingNode *binding, Expr new_value)
Try to re-emit the binding and bind it to new_value.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const FunctionNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
Var WithStructInfo(Var var, StructInfo struct_info)
Create a new var with specified struct_info if the original var's shape or type does not match with t...
virtual BindingBlock VisitBindingBlock_(const BindingBlockNode *block)
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block) override
Generic dispatcher for binding blocks.
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
virtual Var VisitVarDef(const Var &var)
Generic dispatcher for rewriting the var definition site.
ExprMutator(Optional< IRModule > mod=NullOpt)
Definition: expr_functor.h:424
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
virtual Var VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
Expr VisitExprPostOrder_(const T *op)
Post-order rewrite a node and normalize.
Definition: expr_functor.h:538
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
Optional< Expr > LookupBinding(const Var &var)
Look up the value bound to a variable.
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
Expr VisitExpr_(const DataflowVarNode *op) override
Expr VisitWithInnerScope(const Expr &body_expr)
Rewrite the expr with a new scope, used in the branches of If.
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
Expr VisitWithNewScope(const Expr &body_expr, Optional< Array< Var >> params=NullOpt)
Rewrite the expr with a new scope, used in a Function's body.
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
BlockBuilder builder_
Internal block builder to emit bindings during rewriting.
Definition: expr_functor.h:552
virtual Var VisitVarDef_(const VarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
std::unordered_map< Id, Var, ObjectPtrHash, ObjectPtrEqual > var_remap_
Remap a var to a new var at its use sites.
Definition: expr_functor.h:555
virtual void VisitBinding_(const VarBindingNode *binding)
A simple visitor wrapper around ExprFunctor. Recursively visit the content.
Definition: expr_functor.h:188
void VisitExpr_(const GlobalVarNode *op) override
void VisitExpr_(const StringImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
void VisitExpr_(const SeqExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const IfNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
void VisitExpr_(const TupleGetItemNode *op) override
void VisitExpr(const Expr &expr) override
Generic dispatcher for Expr.
virtual void VisitBindingBlock_(const DataflowBlockNode *block)
virtual void VisitSpan(const Span &span)
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
void VisitExpr_(const OpNode *op) override
virtual void VisitVarDef(const Var &var)
Generic dispatcher for visiting the var definition site.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
void VisitExpr_(const PrimValueNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
void VisitExpr_(const DataTypeImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
void VisitExpr_(const FunctionNode *op) override
void VisitExpr_(const VarNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
void VisitExpr_(const ExternFuncNode *op) override
virtual void VisitBindingBlock(const BindingBlock &block)
Generic dispatcher for binding blocks.
void VisitExpr_(const DataflowVarNode *op) override
virtual void VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
virtual void VisitPrimExpr(const PrimExpr &expr)
virtual void VisitVarDef_(const VarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
void VisitExpr_(const ShapeExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
void VisitExpr_(const TupleNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding)
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
void VisitExpr_(const ConstantNode *op) override
virtual void VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info that may recursively contain Expr/PrimExpr.
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
virtual void VisitBindingBlock_(const BindingBlockNode *block)
The extern function, which can represent packed function.
Definition: expr.h:1066
Structure information about function.
Definition: struct_info.h:303
A Relax function.
Definition: expr.h:951
Condition expression.
Definition: expr.h:879
Runtime-match the value to the struct info.
Definition: expr.h:701
PrimValue.
Definition: expr.h:535
A sequence of blocks followed by an expression.
Definition: expr.h:818
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:357
Represent a string literal constant.
Definition: expr.h:586
StructInfoMutator that mutates struct info.
Definition: struct_info_functor.h:140
Base type of all structure information.
Definition: expr.h:111
A struct info visitor.
Definition: struct_info_functor.h:121
Managed reference to StructInfoNode.
Definition: expr.h:130
Get index-th field out of a tuple.
Definition: expr.h:283
Tuple container.
Definition: expr.h:220
Definition: expr.h:736
The variable class for all Relax bindings.
Definition: expr.h:390
Definition: expr.h:423
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:520
bool defined() const
Definition: object.h:553
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:911
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:531
base class of all object containers.
Definition: object.h:172
std::string GetTypeKey() const
Definition: object.h:185
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Defines the Functor data structures.
void PostOrderVisit(const Expr &node, std::function< void(const Expr &)> fvisit)
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:55
#define RELAX_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:58
Functors and visitors for struct info.
TIR Function.