25 #ifndef TVM_RELAY_EXPR_FUNCTOR_H_
26 #define TVM_RELAY_EXPR_FUNCTOR_H_
37 #include <unordered_map>
55 template <
typename FType>
// Function-body macro: expands to a braced body that forwards `op` and the
// variadic `args` to VisitExprDefault_ and returns its result.
// NOTE(review): presumably used as the default body of the virtual VisitExpr_
// overloads (which take `op` and `args`) -- confirm at the expansion sites.
59 #define EXPR_FUNCTOR_DEFAULT \
60 { return VisitExprDefault_(op, std::forward<Args>(args)...); }
62 #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
63 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
64 return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
67 template <
typename R,
typename... Args>
/*!
 * \brief Same as call: forwards to VisitExpr.
 * \param n The expression to dispatch on.
 * \param args Additional arguments, forwarded unchanged to VisitExpr.
 * \return Whatever VisitExpr(n, args...) returns.
 */
84 R
operator()(
const Expr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...); }
92 ICHECK(n.
defined()) <<
"Found null pointer node while traversing AST. The previous pass may "
93 "have generated invalid data.";
94 static FType vtable = InitVTable();
95 return vtable(n,
this, std::forward<Args>(args)...);
114 LOG(FATAL) <<
"Do not have a default for " << op->
GetTypeKey();
120 static FType InitVTable() {
221 std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>
memo_;
310 template <
typename T>
320 #define RELAY_EXPR_REWRITER_DISPATCH(OP) \
321 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \
322 return self->Rewrite_(static_cast<const OP*>(n.get()), post); \
325 #define EXPR_REWRITER_REWRITE_DEFAULT \
361 static FType vtable = InitVTable();
362 return vtable(pre,
this, post);
384 static FType InitVTable() {
452 template <
typename FCheckVisited,
typename FVisitLeaf,
typename FExpandExpr>
454 FExpandExpr fexpand_expr) {
455 std::deque<v_info>
stack;
456 auto fpush_to_stack = [&fcheck_visited, &
stack](
const Expr& expr) {
457 if (!fcheck_visited(expr)) {
462 fpush_to_stack(expr);
463 while (
stack.size() > 0) {
465 if (fcheck_visited(front->
node)) {
468 fvisit_leaf(front->
node);
473 for (
auto e : fexpand_expr(front->
node)) {
480 template <
typename FCheckVisited,
typename FVisitLeaf>
482 auto fexpand_expr = [](
const Expr& expr) {
483 std::vector<Expr> result;
485 if (op->op ==
Op::Get(
"call_lowered")) {
487 const auto* tuple_args = op->args[1].as<
TupleNode>();
489 <<
"Expected second arg to call_lowered to be a Tuple of input arguments.";
490 for (
auto it = tuple_args->fields.rbegin(); it != tuple_args->fields.rend(); ++it) {
491 result.push_back(*it);
493 result.push_back(op->args[0]);
495 for (
auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
496 result.push_back(*it);
499 result.push_back(op->op);
501 for (
auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
502 result.push_back(*it);
505 result.push_back(op->tuple);
513 std::function<
void(
const LetNode*)> post_visit);
ADT constructor. Constructors compare by pointer equality.
Definition: adt.h:47
Global variable that lives in the top-level module.
Definition: expr.h:456
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Primitive Op (builtin intrinsics).
Definition: op.h:58
static const Op & Get(const String &op_name)
Get an Op for a given operator name. Will raise an error if the op has not been registered.
Managed reference to RelayExprNode.
Definition: expr.h:442
Definition: source_map.h:120
Managed reference to TypeNode.
Definition: type.h:93
Call container.
Definition: expr.h:282
Constant tensor type.
Definition: expr.h:71
Definition: expr_functor.h:68
virtual R VisitExpr_(const TupleGetItemNode *op, Args... args)
Definition: expr_functor.h:107
virtual R VisitExpr_(const ConstantNode *op, Args... args)
Definition: expr_functor.h:98
virtual R VisitExpr_(const OpNode *op, Args... args)
Definition: expr_functor.h:106
R result_type
the result type of this functor
Definition: expr_functor.h:75
virtual R VisitExpr_(const GlobalVarNode *op, Args... args)
Definition: expr_functor.h:101
virtual R VisitExpr_(const RefReadNode *op, Args... args)
Definition: expr_functor.h:109
virtual R VisitExpr_(const ConstructorNode *op, Args... args)
Definition: expr_functor.h:111
R operator()(const Expr &n, Args... args)
Same as call.
Definition: expr_functor.h:84
virtual R VisitExpr_(const LetNode *op, Args... args)
Definition: expr_functor.h:104
virtual R VisitExpr_(const TupleNode *op, Args... args)
Definition: expr_functor.h:99
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:103
virtual R VisitExpr_(const RefWriteNode *op, Args... args)
Definition: expr_functor.h:110
virtual R VisitExpr_(const RefCreateNode *op, Args... args)
Definition: expr_functor.h:108
virtual R VisitExpr_(const IfNode *op, Args... args)
Definition: expr_functor.h:105
virtual R VisitExpr(const Expr &n, Args... args)
The functor call.
Definition: expr_functor.h:91
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:113
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:77
virtual R VisitExpr_(const MatchNode *op, Args... args)
Definition: expr_functor.h:112
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:100
virtual R VisitExpr_(const FunctionNode *op, Args... args)
Definition: expr_functor.h:102
A dynamic functor that dispatches on the first Expr argument. You can use this as a more powerful...
Definition: expr_functor.h:56
A wrapper around ExprFunctor which functionally updates the AST.
Definition: expr_functor.h:184
Expr VisitExpr_(const RefWriteNode *op) override
Expr VisitExpr_(const OpNode *op) override
Expr VisitExpr_(const TupleGetItemNode *op) override
Expr VisitExpr_(const RefReadNode *op) override
virtual Pattern VisitPattern(const Pattern &c)
Expr VisitExpr_(const TupleNode *op) override
std::unordered_map< Expr, Expr, ObjectPtrHash, ObjectPtrEqual > memo_
Internal map used for memoization.
Definition: expr_functor.h:221
Expr VisitExpr_(const IfNode *op) override
Expr VisitExpr_(const MatchNode *op) override
Expr VisitExpr_(const ConstantNode *op) override
virtual Type VisitType(const Type &t)
Used to visit the types inside of expressions.
Expr Mutate(const Expr &expr)
Mutate is alias for VisitExpr.
Definition: expr_functor.h:190
Expr VisitExpr_(const RefCreateNode *op) override
Expr VisitExpr_(const ConstructorNode *op) override
Expr VisitExpr_(const VarNode *op) override
Expr VisitExpr_(const FunctionNode *op) override
virtual Clause VisitClause(const Clause &c)
Expr VisitExpr_(const LetNode *op) override
Expr VisitExpr_(const GlobalVarNode *op) override
Expr VisitExpr(const Expr &expr) override
Expr VisitExpr_(const CallNode *call_node) override
A non-iterating Expression Rewriter.
Definition: expr_functor.h:338
virtual Expr Rewrite_(const MatchNode *pre, const Expr &post)
Definition: expr_functor.h:380
virtual Expr Rewrite_(const TupleNode *pre, const Expr &post)
Definition: expr_functor.h:368
virtual Expr Rewrite(const Expr &pre, const Expr &post)
The functor call.
Definition: expr_functor.h:359
virtual Expr Rewrite_(const RefReadNode *pre, const Expr &post)
Definition: expr_functor.h:377
virtual Expr Rewrite_(const LetNode *pre, const Expr &post)
Definition: expr_functor.h:371
virtual Expr Rewrite_(const IfNode *pre, const Expr &post)
Definition: expr_functor.h:372
virtual Expr Rewrite_(const TupleGetItemNode *pre, const Expr &post)
Definition: expr_functor.h:374
virtual Expr Rewrite_(const RefCreateNode *pre, const Expr &post)
Definition: expr_functor.h:376
virtual Expr Rewrite_(const RefWriteNode *pre, const Expr &post)
Definition: expr_functor.h:378
virtual Expr Rewrite_(const CallNode *pre, const Expr &post)
Definition: expr_functor.h:370
virtual Expr Rewrite_(const VarNode *pre, const Expr &post)
Definition: expr_functor.h:365
virtual Expr Rewrite_(const OpNode *pre, const Expr &post)
Definition: expr_functor.h:373
virtual Expr Rewrite_(const FunctionNode *pre, const Expr &post)
Definition: expr_functor.h:369
Expr operator()(const Expr &pre, const Expr &post)
Same as call.
Definition: expr_functor.h:352
virtual Expr Rewrite_(const ConstantNode *pre, const Expr &post)
Definition: expr_functor.h:367
virtual ~ExprRewriter()
virtual destructor
Definition: expr_functor.h:345
virtual Expr Rewrite_(const GlobalVarNode *pre, const Expr &post)
Definition: expr_functor.h:366
virtual Expr Rewrite_(const ConstructorNode *pre, const Expr &post)
Definition: expr_functor.h:379
A simple visitor wrapper around ExprFunctor. Recursively visit the content.
Definition: expr_functor.h:149
void VisitExpr_(const GlobalVarNode *op) override
void VisitExpr_(const ConstantNode *op) override
void VisitExpr_(const FunctionNode *op) override
std::unordered_map< const Object *, size_t > visit_counter_
Definition: expr_functor.h:174
void VisitExpr_(const IfNode *op) override
virtual void VisitSpan(const Span &span)
void VisitExpr_(const TupleNode *op) override
void VisitExpr_(const VarNode *op) override
void VisitExpr(const Expr &expr) override
void VisitExpr_(const RefCreateNode *op) override
void VisitExpr_(const TupleGetItemNode *op) override
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const RefReadNode *op) override
virtual void VisitPattern(const Pattern &c)
virtual void VisitClause(const Clause &c)
void VisitExpr_(const ConstructorNode *op) override
void VisitExpr_(const MatchNode *op) override
void VisitExpr_(const OpNode *op) override
void VisitExpr_(const RefWriteNode *op) override
virtual void VisitType(const Type &t)
void VisitExpr_(const LetNode *op) override
Relay Function container.
Definition: function.h:39
Container of If.
Definition: expr.h:491
A binding of a sub-network.
Definition: expr.h:404
Match container node.
Definition: adt.h:277
Non-recursive DFS Graph Traversal for Custom Rewriting Passes.
Definition: expr_functor.h:282
virtual Expr Rewrite_(const TupleGetItemNode *pre, const Expr &post)
Definition: expr_functor.h:303
Expr VisitExpr(const Expr &expr) final
virtual Expr Rewrite_(const TupleNode *pre, const Expr &post)
Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be able to re...
Definition: expr_functor.h:301
Expr Rewrite(const T *op)
Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a post node with changed inputs.
Definition: expr_functor.h:311
Expr VisitExpr_(const TupleGetItemNode *op) final
Definition: expr_functor.h:292
bool pre_
Definition: expr_functor.h:306
Expr VisitExpr_(const TupleNode *op) final
Definition: expr_functor.h:290
Expr VisitExpr_(const CallNode *call_node) final
Definition: expr_functor.h:291
virtual Expr DispatchVisitExpr(const Expr &expr)
MixedModeMutator(bool pre=false)
Definition: expr_functor.h:286
virtual void VisitLeaf(const Expr &expr)
virtual bool CheckVisited(const Expr &expr)
virtual Expr Rewrite_(const CallNode *pre, const Expr &post)
Definition: expr_functor.h:302
A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
Definition: expr_functor.h:233
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const VarNode *op) override
void VisitExpr_(const TupleGetItemNode *op) override
virtual bool CheckVisited(const Expr &expr)
A function to determine if an expression has already been visited or needs to be re-visited.
void VisitExpr_(const TupleNode *op) override
size_t visit_limit_
The max number of times to visit a node.
Definition: expr_functor.h:266
virtual void VisitLeaf(const Expr &expr)
A function to apply when reaching a leaf of the graph non-recursively.
MixedModeVisitor(int visit_limit=1)
The constructor of MixedModeVisitor.
void VisitExpr(const Expr &expr) final
VisitExpr is finalized to preserve call expansion of dataflow regions.
Pattern is the base type for an ADT match pattern in Relay.
Definition: adt.h:63
Tuple container.
Definition: expr.h:123
Container for Var.
Definition: expr.h:188
Base class of all object references.
Definition: object.h:519
bool defined() const
Definition: object.h:552
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
Base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
Defines the Functor data structures.
void ExpandANormalForm(const LetNode *op, std::function< void(const LetNode *)> pre_visit, std::function< void(const LetNode *)> post_visit)
tvm::RelayExpr Expr
Definition: expr.h:54
void PostOrderVisit(const Expr &node, std::function< void(const Expr &)> fvisit)
Recursively visit the IR in post-DFS order, applying fvisit to each node. Each node is guaranteed to be visited only once.
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf, FExpandExpr fexpand_expr)
A function to iteratively traverse dataflow regions of a graph.
Definition: expr_functor.h:453
Expr PostOrderRewrite(const Expr &expr, ExprRewriter *rewriter)
Non-recursive DFS Graph Traversal for Custom Rewriting Passes.
Tensor stack(const Array< Tensor > &inputs, int axis=0, std::string name="T_stack", std::string tag=kInjective)
Join a sequence of tensors along a new axis.
Definition: transform.h:532
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Algebraic data types for Relay.
Relay expression language.
#define RELAY_EXPR_REWRITER_DISPATCH(OP)
Definition: expr_functor.h:320
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:62
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:59
#define EXPR_REWRITER_REWRITE_DEFAULT
Definition: expr_functor.h:325
Primitive operators (builtin intrinsics).
A struct to keep info of traversed expr in ExpandDataflow function.
Definition: expr_functor.h:426
v_info(Expr node_)
Definition: expr_functor.h:427
Expr node
Definition: expr_functor.h:430
bool children_expanded
Definition: expr_functor.h:431
v_info(Expr node_, bool children_expanded_)
Definition: expr_functor.h:428