25 #ifndef TVM_RELAX_EXPR_FUNCTOR_H_
26 #define TVM_RELAX_EXPR_FUNCTOR_H_
35 #include <unordered_map>
51 template <
typename FType>
55 #define EXPR_FUNCTOR_DEFAULT \
57 return VisitExprDefault_(op, std::forward<Args>(args)...); \
60 #define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \
61 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
62 return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
65 #define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \
67 if (PY_FUNC != nullptr) \
73 #define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \
75 if (PY_FUNC != nullptr) { \
76 RET_TYPE ret = PY_FUNC(N).cast<RET_TYPE>(); \
79 return DEFAULT_FUNC; \
83 #define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \
84 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
85 if (self->PY_FUNC != nullptr) \
88 self->VisitExpr_(static_cast<const OP*>(n.get())); \
91 #define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \
92 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
93 if (self->PY_FUNC != nullptr) { \
94 Expr expr = self->PY_FUNC(n).cast<Expr>(); \
97 return self->VisitExpr_(static_cast<const OP*>(n.get())); \
101 #define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP) \
102 post_order_vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
103 return self->VisitExprPostOrder_(static_cast<const OP*>(n.get())); \
106 template <
typename R,
typename... Args>
123 R
operator()(
const Expr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...); }
131 TVM_FFI_ICHECK(n.defined())
132 <<
"Found null pointer node while traversing AST. The previous pass may "
133 "have generated invalid data.";
134 static FType vtable = InitVTable();
135 return vtable(n,
this, std::forward<Args>(args)...);
157 TVM_FFI_THROW(InternalError) <<
"Do not have a default for " << op->GetTypeKey();
163 static FType InitVTable() {
298 explicit DefaultStructInfoFieldVisitor(
ExprVisitor* parent);
301 void VisitStructInfoExprField(
const Expr& expr)
final;
302 void VisitStructInfoExprField(
const PrimExpr& expr)
final;
310 DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{
this};
404 Expr VisitStructInfoExprField(
const Expr& expr)
final;
413 DefaultStructInfoFieldMutator default_struct_info_field_mutator_{
this};
509 ffi::Optional<ffi::Array<Var>> params = std::nullopt);
541 template <
typename T>
559 std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual>
var_remap_;
The utility for constructing Relax binding blocks.
Global variable that lives in the top-level module.
Definition: expr.h:455
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:65
Primitive Op(builtin intrinsics)
Definition: op.h:59
Reference to PrimExprNode.
Definition: expr.h:126
Managed reference to RelaxExprNode.
Definition: expr.h:441
Definition: source_map.h:111
Definition: block_builder.h:264
static BlockBuilder Create(ffi::Optional< IRModule > ctx_mod)
Create a BlockBuilder.
Call corresponds to callable invocation. Corresponds to operation in computational graph terminology.
Definition: expr.h:141
Constant tensor.
Definition: expr.h:425
Represent a data type constant.
Definition: expr.h:537
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:396
Definition: expr_functor.h:107
virtual R VisitExpr_(const TupleNode *op, Args... args)
Definition: expr_functor.h:141
virtual R VisitExpr_(const DataTypeImmNode *op, Args... args)
Definition: expr_functor.h:155
virtual R VisitExpr(const Expr &n, Args... args)
The functor call.
Definition: expr_functor.h:130
virtual R VisitExpr_(const ShapeExprNode *op, Args... args)
Definition: expr_functor.h:144
virtual R VisitExpr_(const SeqExprNode *op, Args... args)
Definition: expr_functor.h:149
virtual R VisitExpr_(const FunctionNode *op, Args... args)
Definition: expr_functor.h:147
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:156
virtual R VisitExpr_(const OpNode *op, Args... args)
Definition: expr_functor.h:151
R result_type
the result type of this functor
Definition: expr_functor.h:114
virtual R VisitExpr_(const IfNode *op, Args... args)
Definition: expr_functor.h:150
virtual R VisitExpr_(const ExternFuncNode *op, Args... args)
Definition: expr_functor.h:145
virtual R VisitExpr_(const GlobalVarNode *op, Args... args)
Definition: expr_functor.h:146
virtual R VisitExpr_(const DataflowVarNode *op, Args... args)
Definition: expr_functor.h:143
virtual R VisitExpr_(const StringImmNode *op, Args... args)
Definition: expr_functor.h:154
R operator()(const Expr &n, Args... args)
Same as call.
Definition: expr_functor.h:123
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:142
virtual R VisitExpr_(const PrimValueNode *op, Args... args)
Definition: expr_functor.h:153
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:148
virtual R VisitExpr_(const TupleGetItemNode *op, Args... args)
Definition: expr_functor.h:152
virtual R VisitExpr_(const ConstantNode *op, Args... args)
Definition: expr_functor.h:140
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:116
A dynamical functor that dispatches on in the first Expr argument. You can use this as a more powerfu...
Definition: expr_functor.h:52
A mutator works in unnormalized form.
Definition: expr_functor.h:323
Expr VisitExpr_(const TupleNode *op) override
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const VarNode *op) override
Expr VisitExpr_(const GlobalVarNode *op) override
Expr VisitExpr_(const PrimValueNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block)
Mutate BindingBlock.
Expr VisitExpr_(const ShapeExprNode *op) override
virtual PrimExpr VisitPrimExpr(const PrimExpr &expr)
Used to visit the PrimExpr inside of expressions.
Expr VisitExpr_(const ExternFuncNode *op) override
Expr VisitExpr_(const DataTypeImmNode *op) override
bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef &struct_info)
Check whether VisitExprDepStructInfoField change struct_info.
Definition: expr_functor.h:383
Expr VisitExpr_(const CallNode *op) override
Expr VisitExpr_(const OpNode *op) override
virtual StructInfo VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info that may recursively contain Expr/PrimExpr.
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const StringImmNode *op) override
Expr VisitExpr_(const TupleGetItemNode *op) override
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const FunctionNode *op) override
Expr VisitExpr_(const DataflowVarNode *op) override
A mutator works in normal form.
Definition: expr_functor.h:423
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
Expr VisitExpr_(const VarNode *op) override
virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode *block)
void ReEmitBinding(const VarBindingNode *binding, Expr new_value)
Try to remit binding and bind it to a new_value.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
Expr VisitExpr_(const ConstantNode *op) override
Expr VisitExpr_(const FunctionNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
Var WithStructInfo(Var var, StructInfo struct_info)
Create a new var with specified struct_info if the original var's shape or type does not match with t...
Expr VisitWithNewScope(const Expr &body_expr, ffi::Optional< ffi::Array< Var >> params=std::nullopt)
Rewrite the expr with a new scope, used in a Function's body.
virtual BindingBlock VisitBindingBlock_(const BindingBlockNode *block)
ffi::Optional< Expr > LookupBinding(const Var &var)
Look up the value bound to a variable.
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const SeqExprNode *op) override
virtual BindingBlock VisitBindingBlock(const BindingBlock &block) override
Generic dispatcher for binding blocks.
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
virtual Var VisitVarDef(const Var &var)
Generic dispatcher for rewriting the var definition site.
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
virtual Var VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
Expr VisitExprPostOrder_(const T *op)
Post-order rewrite a node and normalize.
Definition: expr_functor.h:542
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
Expr VisitExpr_(const DataflowVarNode *op) override
Expr VisitWithInnerScope(const Expr &body_expr)
Rewrite the expr with a new scope, used in the branches of If.
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
BlockBuilder builder_
Internal block builder to emit bindings during rewriting.
Definition: expr_functor.h:556
virtual Var VisitVarDef_(const VarNode *var)
ExprMutator(ffi::Optional< IRModule > mod=std::nullopt)
Definition: expr_functor.h:427
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
std::unordered_map< Id, Var, ObjectPtrHash, ObjectPtrEqual > var_remap_
Remap a var to a new var in use-site.
Definition: expr_functor.h:559
virtual void VisitBinding_(const VarBindingNode *binding)
A simple visitor wrapper around ExprFunctor. Recursively visit the content.
Definition: expr_functor.h:191
void VisitExpr_(const GlobalVarNode *op) override
void VisitExpr_(const StringImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const OpNode *val)
void VisitExpr_(const SeqExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleNode *val)
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const IfNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const GlobalVarNode *val)
void VisitExpr_(const TupleGetItemNode *op) override
void VisitExpr(const Expr &expr) override
Generic dispatcher for Expr.
virtual void VisitBindingBlock_(const DataflowBlockNode *block)
virtual void VisitSpan(const Span &span)
virtual void VisitBinding_(const VarBindingNode *binding, const StringImmNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ShapeExprNode *val)
void VisitExpr_(const OpNode *op) override
virtual void VisitVarDef(const Var &var)
Generic dispatcher for visiting the var definition site.
virtual void VisitBinding_(const VarBindingNode *binding, const PrimValueNode *val)
void VisitExpr_(const PrimValueNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const IfNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const CallNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const DataflowVarNode *val)
virtual void VisitBinding_(const MatchCastNode *binding)
void VisitExpr_(const DataTypeImmNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const TupleGetItemNode *val)
void VisitExpr_(const FunctionNode *op) override
void VisitExpr_(const VarNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const VarNode *val)
virtual void VisitBinding_(const VarBindingNode *binding, const ConstantNode *val)
void VisitExpr_(const ExternFuncNode *op) override
virtual void VisitBindingBlock(const BindingBlock &block)
Generic dispatcher for binding blocks.
void VisitExpr_(const DataflowVarNode *op) override
virtual void VisitVarDef_(const DataflowVarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const DataTypeImmNode *val)
virtual void VisitPrimExpr(const PrimExpr &expr)
virtual void VisitVarDef_(const VarNode *var)
virtual void VisitBinding_(const VarBindingNode *binding, const SeqExprNode *val)
void VisitExpr_(const ShapeExprNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding, const ExternFuncNode *val)
void VisitExpr_(const TupleNode *op) override
virtual void VisitBinding_(const VarBindingNode *binding)
virtual void VisitBinding(const Binding &binding)
Generic dispatcher for bindings.
void VisitExpr_(const ConstantNode *op) override
virtual void VisitExprDepStructInfoField(const StructInfo &struct_info)
Visit struct_info may recursively contain Expr/PrimExpr.
virtual void VisitBinding_(const VarBindingNode *binding, const FunctionNode *val)
virtual void VisitBindingBlock_(const BindingBlockNode *block)
The extern function, which can represent packed function.
Definition: expr.h:903
Structure information about function.
Definition: struct_info.h:263
A Relax function.
Definition: expr.h:808
Condition expression.
Definition: expr.h:751
Runtime-match the value to the struct info.
Definition: expr.h:605
PrimValue.
Definition: expr.h:465
A sequence of blocks followed by an expression.
Definition: expr.h:706
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:324
Represent a string literal constant.
Definition: expr.h:505
StructInfoMutator that mutates struct info.
Definition: struct_info_functor.h:142
Base type of all structure information.
Definition: expr.h:108
A struct info visitor.
Definition: struct_info_functor.h:123
Managed reference to StructInfoNode.
Definition: expr.h:132
Get index-th field out of a tuple.
Definition: expr.h:263
Tuple container.
Definition: expr.h:210
The variable class for all Relax bindings.
Definition: expr.h:344
Defines the Functor data structures.
void PostOrderVisit(const Expr &node, std::function< void(const Expr &)> fvisit)
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:55
#define RELAX_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:60
Functors and visitors for struct info.