25 #ifndef TVM_RELAX_EXPR_FUNCTOR_H_
26 #define TVM_RELAX_EXPR_FUNCTOR_H_
38 #include <unordered_map>
55 template <
typename FType>
// EXPR_FUNCTOR_DEFAULT: expands to a brace-enclosed function body that forwards
// `op` and the perfectly-forwarded `args` to VisitExprDefault_.
// It relies on names `op`, `args` and pack type `Args` being in scope at the
// expansion site; presumably it is used as the default body of the
// VisitExpr_ overloads — TODO confirm against the full header.
59 #define EXPR_FUNCTOR_DEFAULT \
60 { return VisitExprDefault_(op, std::forward<Args>(args)...); }
62 #define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \
63 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
64 return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
67 #define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \
69 if (PY_FUNC != nullptr) \
75 #define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \
77 if (PY_FUNC != nullptr) { \
78 RET_TYPE ret = PY_FUNC(N); \
81 return DEFAULT_FUNC; \
85 #define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \
86 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
87 if (self->PY_FUNC != nullptr) \
90 self->VisitExpr_(static_cast<const OP*>(n.get())); \
93 #define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \
94 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
95 if (self->PY_FUNC != nullptr) { \
96 Expr expr = self->PY_FUNC(n); \
99 return self->VisitExpr_(static_cast<const OP*>(n.get())); \
103 #define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP) \
104 post_order_vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
105 return self->VisitExprPostOrder_(static_cast<const OP*>(n.get())); \
108 template <
typename R,
typename... Args>
// Call operator: identical to VisitExpr(n, args...). The extra arguments are
// forwarded unchanged via std::forward.
125 R
operator()(
const Expr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...); }
133 ICHECK(n.
defined()) <<
"Found null pointer node while traversing AST. The previous pass may "
134 "have generated invalid data.";
135 static FType vtable = InitVTable();
136 return vtable(n,
this, std::forward<Args>(args)...);
158 LOG(FATAL) <<
"Do not have a default for " << op->
GetTypeKey();
164 static FType InitVTable() {
// Constructor. `parent` is an ExprVisitor; presumably it is the visitor that
// owns this helper (compare the `{this}` initializer of
// default_struct_info_field_visitor_) — TODO confirm against the full header.
298 explicit DefaultStructInfoFieldVisitor(
ExprVisitor* parent);
// Final override of the struct-info visitor hook for Expr-valued fields.
// NOTE(review): presumably routes the field back into the parent ExprVisitor;
// verify in the implementation file.
301 void VisitStructInfoExprField(
const Expr& expr)
final;
// Final override of the struct-info visitor hook for PrimExpr-valued fields.
// NOTE(review): presumably routes the field back into the parent ExprVisitor;
// verify in the implementation file.
302 void VisitStructInfoExprField(
const PrimExpr& expr)
final;
// Helper visitor for fields nested inside struct info, constructed with the
// enclosing visitor (`this`) as its parent.
310 DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{
this};
// Final override for Expr-valued fields found inside struct info. Returns an
// Expr, presumably the (possibly rewritten) field that replaces the original —
// TODO confirm in the implementation file.
404 Expr VisitStructInfoExprField(
const Expr& expr)
final;
// Helper mutator for fields nested inside struct info, constructed with the
// enclosing mutator (`this`) as its parent.
413 DefaultStructInfoFieldMutator default_struct_info_field_mutator_{
this};
540 template <
typename T>
// Remaps a var to a new var at its use sites. Keyed by Id (presumably the
// original var's id) and mapping to the replacement Var.
558 std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual>
var_remap_;
The utility for constructing Relax binding blocks.
Global variable that lives in the top-level module.
Definition: expr.h:456
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Primitive Op (builtin intrinsics)
Definition: op.h:58
Reference to PrimExprNode.
Definition: expr.h:115
Managed reference to RelayExprNode.
Definition: expr.h:442
Definition: source_map.h:120
Definition: block_builder.h:264
static BlockBuilder Create(Optional< IRModule > ctx_mod)
Create a BlockBuilder.
Call corresponds to callable invocation. Corresponds to an operation in computational-graph terminology.
Definition: expr.h:138
Constant tensor.
Definition: expr.h:480
Represent a data type constant.
Definition: expr.h:628
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:437
Definition: expr_functor.h:109
virtual R VisitExpr_(const TupleNode *op, Args... args)
Definition: expr_functor.h:142
virtual R VisitExpr_(const DataTypeImmNode *op, Args... args)
Definition: expr_functor.h:156
virtual R VisitExpr(const Expr &n, Args... args)
The functor call.
Definition: expr_functor.h:132
virtual R VisitExpr_(const ShapeExprNode *op, Args... args)
Definition: expr_functor.h:145
virtual R VisitExpr_(const SeqExprNode *op, Args... args)
Definition: expr_functor.h:150
virtual R VisitExpr_(const FunctionNode *op, Args... args)
Definition: expr_functor.h:148
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:157
virtual R VisitExpr_(const OpNode *op, Args... args)
Definition: expr_functor.h:152
R result_type
the result type of this functor
Definition: expr_functor.h:116
virtual R VisitExpr_(const IfNode *op, Args... args)
Definition: expr_functor.h:151
virtual R VisitExpr_(const ExternFuncNode *op, Args... args)
Definition: expr_functor.h:146
virtual R VisitExpr_(const GlobalVarNode *op, Args... args)
Definition: expr_functor.h:147
virtual R VisitExpr_(const DataflowVarNode *op, Args... args)
Definition: expr_functor.h:144
virtual R VisitExpr_(const StringImmNode *op, Args... args)
Definition: expr_functor.h:155
R operator()(const Expr &n, Args... args)
Same as call.
Definition: expr_functor.h:125
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:143
virtual R VisitExpr_(const PrimValueNode *op, Args... args)
Definition: expr_functor.h:154
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:149
virtual R VisitExpr_(const TupleGetItemNode *op, Args... args)
Definition: expr_functor.h:153
virtual R VisitExpr_(const ConstantNode *op, Args... args)
Definition: expr_functor.h:141
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:118
A dynamic functor that dispatches on the first Expr argument. You can use this as a more powerfu...
Definition: expr_functor.h:56
A mutator works in unnormalized form.
Definition: expr_functor.h:323
Expr VisitExpr_(const TupleNode *op) override
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const VarNode *op) override
Expr VisitExpr_(const GlobalVarNode *op) override
Expr VisitExpr_(const PrimValueNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block)
Mutate BindingBlock.
Expr VisitExpr_(const ShapeExprNode *op) override
virtual PrimExpr VisitPrimExpr(const PrimExpr &expr)
Used to visit the PrimExpr inside of expressions.
Expr VisitExpr_(const ExternFuncNode *op) override
Expr VisitExpr_(const DataTypeImmNode *op) override
bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef &struct_info)
Check whether VisitExprDepStructInfoField changes struct_info.
Definition: expr_functor.h:383
Expr VisitExpr_(const CallNode *op) override
Expr VisitExpr_(const OpNode *op) override
virtual StructInfo VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info that may recursively contain Expr/PrimExpr.
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const StringImmNode *op) override
Expr VisitExpr_(const TupleGetItemNode *op) override
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const FunctionNode *op) override
Expr VisitExpr_(const DataflowVarNode *op) override
A mutator works in normal form.
Definition: expr_functor.h:423
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
Expr VisitExpr_(const VarNode *op) override
virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode *block)
void ReEmitBinding(const VarBindingNode *binding, Expr new_value)
Try to re-emit the binding, binding its variable to new_value.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const FunctionNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
Var WithStructInfo(Var var, StructInfo struct_info)
Create a new var with specified struct_info if the original var's shape or type does not match with t...
virtual BindingBlock VisitBindingBlock_(const BindingBlockNode *block)
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block) override
Generic dispatcher for binding blocks.
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
virtual Var VisitVarDef(const Var &var)
Generic dispatcher for rewriting the var definition site.
ExprMutator(Optional< IRModule > mod=NullOpt)
Definition: expr_functor.h:427
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
virtual Var VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
Expr VisitExprPostOrder_(const T *op)
Post-order rewrite a node and normalize.
Definition: expr_functor.h:541
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
Optional< Expr > LookupBinding(const Var &var)
Look up the value bound to a variable.
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
Expr VisitExpr_(const DataflowVarNode *op) override
Expr VisitWithInnerScope(const Expr &body_expr)
Rewrite the expr with a new scope, used in the branches of If.
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
Expr VisitWithNewScope(const Expr &body_expr, Optional< Array< Var >> params=NullOpt)
Rewrite the expr with a new scope, used in a Function's body.
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
BlockBuilder builder_
Internal block builder to emit bindings during rewriting.
Definition: expr_functor.h:555
virtual Var VisitVarDef_(const VarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
std::unordered_map< Id, Var, ObjectPtrHash, ObjectPtrEqual > var_remap_
Remaps a var to a new var at its use sites.
Definition: expr_functor.h:558
virtual void VisitBinding_(const VarBindingNode *binding)
A simple visitor wrapper around ExprFunctor. Recursively visit the content.
Definition: expr_functor.h:191
void VisitExpr_(const GlobalVarNode *op) override
void VisitExpr_(const StringImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
void VisitExpr_(const SeqExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const IfNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
void VisitExpr_(const TupleGetItemNode *op) override
void VisitExpr(const Expr &expr) override
Generic dispatcher for Expr.
virtual void VisitBindingBlock_(const DataflowBlockNode *block)
virtual void VisitSpan(const Span &span)
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
void VisitExpr_(const OpNode *op) override
virtual void VisitVarDef(const Var &var)
Generic dispatcher for visiting the var definition site.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
void VisitExpr_(const PrimValueNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
void VisitExpr_(const DataTypeImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
void VisitExpr_(const FunctionNode *op) override
void VisitExpr_(const VarNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
void VisitExpr_(const ExternFuncNode *op) override
virtual void VisitBindingBlock(const BindingBlock &block)
Generic dispatcher for binding blocks.
void VisitExpr_(const DataflowVarNode *op) override
virtual void VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
virtual void VisitPrimExpr(const PrimExpr &expr)
virtual void VisitVarDef_(const VarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
void VisitExpr_(const ShapeExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
void VisitExpr_(const TupleNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding)
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
void VisitExpr_(const ConstantNode *op) override
virtual void VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info that may recursively contain Expr/PrimExpr.
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
virtual void VisitBindingBlock_(const BindingBlockNode *block)
The extern function, which can represent packed function.
Definition: expr.h:1065
Structure information about function.
Definition: struct_info.h:303
A Relax function.
Definition: expr.h:950
Condition expression.
Definition: expr.h:878
Runtime-match the value to the struct info.
Definition: expr.h:700
PrimValue.
Definition: expr.h:534
A sequence of blocks followed by an expression.
Definition: expr.h:817
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:356
Represent a string literal constant.
Definition: expr.h:585
StructInfoMutator that mutates struct info.
Definition: struct_info_functor.h:139
Base type of all structure information.
Definition: expr.h:110
A struct info visitor.
Definition: struct_info_functor.h:120
Managed reference to StructInfoNode.
Definition: expr.h:129
Get index-th field out of a tuple.
Definition: expr.h:282
Tuple container.
Definition: expr.h:219
The variable class for all Relax bindings.
Definition: expr.h:389
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:530
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Defines the Functor data structures.
void PostOrderVisit(const Expr &node, std::function< void(const Expr &)> fvisit)
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:59
#define RELAX_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:62
Primitive operators (builtin intrinsics).
Functors and visitors for struct info.