25 #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ 26 #define TVM_RELAY_EXPR_FUNCTOR_H_ 37 #include <unordered_map> 55 template <
typename FType>
59 #define EXPR_FUNCTOR_DEFAULT \ 60 { return VisitExprDefault_(op, std::forward<Args>(args)...); } 62 #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ 63 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \ 64 return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \ 67 template <
typename R,
typename... Args>
84 R
operator()(
const Expr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...); }
92 ICHECK(n.
defined()) <<
"Found null pointer node while traversing AST. The previous pass may " 93 "have generated invalid data.";
94 static FType vtable = InitVTable();
95 return vtable(n,
this, std::forward<Args>(args)...);
114 LOG(FATAL) <<
"Do not have a default for " << op->
GetTypeKey();
120 static FType InitVTable() {
221 std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>
memo_;
253 virtual void VisitLeaf(
const Expr& expr);
258 virtual bool CheckVisited(
const Expr& expr);
283 virtual Expr DispatchVisitExpr(
const Expr& expr);
304 template <
typename T>
307 return Rewrite_(op, post);
310 virtual void VisitLeaf(
const Expr& expr);
311 virtual bool CheckVisited(
const Expr& expr);
314 #define RELAY_EXPR_REWRITER_DISPATCH(OP) \ 315 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \ 316 return self->Rewrite_(static_cast<const OP*>(n.get()), post); \ 319 #define EXPR_REWRITER_REWRITE_DEFAULT \ 355 static FType vtable = InitVTable();
356 return vtable(pre,
this, post);
378 static FType InitVTable() {
423 : node{node_}, children_expanded{children_expanded_} {};
425 bool children_expanded{
false};
446 template <
typename FCheckVisited,
typename FVisitLeaf,
typename FExpandExpr>
448 FExpandExpr fexpand_expr) {
449 std::deque<v_info>
stack;
450 auto fpush_to_stack = [&fcheck_visited, &
stack](
const Expr& expr) {
451 if (!fcheck_visited(expr)) {
452 stack.emplace_front(
v_info(expr));
456 fpush_to_stack(expr);
457 while (stack.size() > 0) {
458 v_info* front = &stack.front();
459 if (fcheck_visited(front->
node)) {
462 fvisit_leaf(front->
node);
467 for (
auto e : fexpand_expr(front->
node)) {
474 template <
typename FCheckVisited,
typename FVisitLeaf>
476 auto fexpand_expr = [](
const Expr& expr) {
477 std::vector<Expr> result;
479 for (
auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
480 result.push_back(*it);
482 result.push_back(op->op);
484 for (
auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
485 result.push_back(*it);
488 result.push_back(op->tuple);
496 std::function<
void(
const LetNode*)> post_visit);
500 #endif // TVM_RELAY_EXPR_FUNCTOR_H_ A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
virtual R VisitExpr_(const RefWriteNode *op, Args... args)
Definition: expr_functor.h:110
Match container node.
Definition: adt.h:268
A wrapper around ExprFunctor which functionally updates the AST.
Definition: expr_functor.h:184
virtual Expr Rewrite_(const TupleNode *pre, const Expr &post)
Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be able to re...
Definition: expr_functor.h:295
Expr VisitExpr_(const TupleGetItemNode *op) final
Definition: expr_functor.h:286
virtual void VisitClause(const Clause &c)
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:62
virtual R VisitExpr_(const LetNode *op, Args... args)
Definition: expr_functor.h:104
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:59
virtual R VisitExpr(const Expr &n, Args... args)
The functor call.
Definition: expr_functor.h:91
Call container.
Definition: expr.h:229
virtual Expr Rewrite_(const TupleNode *pre, const Expr &post)
Definition: expr_functor.h:362
virtual R VisitExpr_(const RefCreateNode *op, Args... args)
Definition: expr_functor.h:108
ADT constructor. Constructors compare by pointer equality.
Definition: adt.h:47
Relay expression language.
virtual R VisitExpr_(const TupleNode *op, Args... args)
Definition: expr_functor.h:99
virtual Expr Rewrite_(const MatchNode *pre, const Expr &post)
Definition: expr_functor.h:374
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Expr PostOrderRewrite(const Expr &expr, ExprRewriter *rewriter)
Non-recursive DFS Graph Traversal for Custom Rewriting Passes.
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:113
#define EXPR_REWRITER_REWRITE_DEFAULT
Definition: expr_functor.h:319
R result_type
the result type of this functor
Definition: expr_functor.h:75
virtual Expr Rewrite_(const CallNode *pre, const Expr &post)
Definition: expr_functor.h:296
virtual Expr Rewrite_(const IfNode *pre, const Expr &post)
Definition: expr_functor.h:366
virtual void VisitType(const Type &t)
virtual Expr Rewrite_(const TupleGetItemNode *pre, const Expr &post)
Definition: expr_functor.h:368
Constant tensor type.
Definition: expr.h:61
virtual Expr Rewrite_(const GlobalVarNode *pre, const Expr &post)
Definition: expr_functor.h:360
void PostOrderVisit(const Expr &node, std::function< void(const Expr &)> fvisit)
recursively visit the ir in post DFS order node, apply fvisit Each node is guaranteed to be visited o...
std::unordered_map< const Object *, size_t > visit_counter_
Definition: expr_functor.h:174
Primitive Op(builtin intrinsics)
Definition: op.h:58
size_t visit_limit_
The max number of times to visit a node.
Definition: expr_functor.h:262
Expr VisitExpr_(const CallNode *call_node) final
Definition: expr_functor.h:285
virtual R VisitExpr_(const IfNode *op, Args... args)
Definition: expr_functor.h:105
void VisitExpr(const Expr &expr) override
base class of all object containers.
Definition: object.h:165
bool children_expanded
Definition: expr_functor.h:425
Expr VisitExpr_(const TupleNode *op) final
Definition: expr_functor.h:284
Container for Var.
Definition: expr.h:157
virtual Expr Rewrite_(const ConstructorNode *pre, const Expr &post)
Definition: expr_functor.h:373
MixedModeMutator(bool pre=false)
Definition: expr_functor.h:280
v_info(Expr node_, bool children_expanded_)
Definition: expr_functor.h:422
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:100
virtual R VisitExpr_(const RefReadNode *op, Args... args)
Definition: expr_functor.h:109
Utilities for error tracking and reporting.
virtual void VisitSpan(const Span &span)
virtual R VisitExpr_(const FunctionNode *op, Args... args)
Definition: expr_functor.h:102
Algebraic data types for Relay.
Expr VisitExpr_(const VarNode *op) override
virtual Expr Rewrite_(const RefWriteNode *pre, const Expr &post)
Definition: expr_functor.h:372
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf, FExpandExpr fexpand_expr)
A function to iteratively traverse dataflow regions of a graph.
Definition: expr_functor.h:447
A binding of a sub-network.
Definition: expr.h:335
void ExpandANormalForm(const LetNode *op, std::function< void(const LetNode *)> pre_visit, std::function< void(const LetNode *)> post_visit)
Definition: expr_functor.h:68
bool defined() const
Definition: object.h:537
Tensor stack(const Array< Tensor > &inputs, int axis=0, std::string name="T_stack", std::string tag=kInjective)
Join a sequence of tensors along a new axis.
Definition: transform.h:443
virtual R VisitExpr_(const ConstructorNode *op, Args... args)
Definition: expr_functor.h:111
Relay Function container.
Definition: function.h:39
virtual void VisitPattern(const Pattern &c)
Tuple container.
Definition: expr.h:103
A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
Definition: expr_functor.h:233
R operator()(const Expr &n, Args... args)
Same as call.
Definition: expr_functor.h:84
virtual Expr Rewrite_(const OpNode *pre, const Expr &post)
Definition: expr_functor.h:367
Managed reference to RelayExprNode.
Definition: expr.h:177
bool pre_
Definition: expr_functor.h:300
v_info(Expr node_)
Definition: expr_functor.h:421
A struct to keep info of traversed expr in ExpandDataflow function.
Definition: expr_functor.h:420
Expr node
Definition: expr_functor.h:424
virtual R VisitExpr_(const OpNode *op, Args... args)
Definition: expr_functor.h:106
Pattern is the base type for an ADT match pattern in Relay.
Definition: adt.h:63
virtual Expr Rewrite_(const RefReadNode *pre, const Expr &post)
Definition: expr_functor.h:371
Defines the Functor data structures.
virtual Expr Rewrite_(const LetNode *pre, const Expr &post)
Definition: expr_functor.h:365
Base class of all object reference.
Definition: object.h:504
std::string GetTypeKey() const
Definition: object.h:178
virtual R VisitExpr_(const TupleGetItemNode *op, Args... args)
Definition: expr_functor.h:107
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:103
Expr Rewrite(const T *op)
Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a post node with changed inputs...
Definition: expr_functor.h:305
virtual R VisitExpr_(const GlobalVarNode *op, Args... args)
Definition: expr_functor.h:101
virtual Expr Rewrite_(const FunctionNode *pre, const Expr &post)
Definition: expr_functor.h:363
A non-iterating Expression Rewriter.
Definition: expr_functor.h:332
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:77
#define RELAY_EXPR_REWRITER_DISPATCH(OP)
Definition: expr_functor.h:314
Non-recursive DFS Graph Traversal for Custom Rewriting Passes.
Definition: expr_functor.h:278
virtual Expr Rewrite_(const ConstantNode *pre, const Expr &post)
Definition: expr_functor.h:361
A dynamical functor that dispatches on in the first Expr argument. You can use this as a more powerfu...
Definition: expr_functor.h:56
virtual Expr Rewrite_(const RefCreateNode *pre, const Expr &post)
Definition: expr_functor.h:370
Expr operator()(const Expr &pre, const Expr &post)
Same as call.
Definition: expr_functor.h:346
A simple visitor wrapper around ExprFunctor. Recursively visit the content.
Definition: expr_functor.h:149
virtual Expr Rewrite_(const TupleGetItemNode *pre, const Expr &post)
Definition: expr_functor.h:297
Managed reference to TypeNode.
Definition: type.h:93
virtual Expr Rewrite_(const VarNode *pre, const Expr &post)
Definition: expr_functor.h:359
Expr Mutate(const Expr &expr)
Mutate is alias for VisitExpr.
Definition: expr_functor.h:190
std::unordered_map< Expr, Expr, ObjectPtrHash, ObjectPtrEqual > memo_
Internal map used for memoization.
Definition: expr_functor.h:221
Global variable that lives in the top-level module.
Definition: expr.h:191
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:858
virtual R VisitExpr_(const ConstantNode *op, Args... args)
Definition: expr_functor.h:98
virtual Expr Rewrite(const Expr &pre, const Expr &post)
The functor call.
Definition: expr_functor.h:353
void VisitExpr_(const VarNode *op) override
Primitive operators(builtin intrinsics).
virtual Expr Rewrite_(const CallNode *pre, const Expr &post)
Definition: expr_functor.h:364
virtual ~ExprRewriter()
virtual destructor
Definition: expr_functor.h:339
virtual R VisitExpr_(const MatchNode *op, Args... args)
Definition: expr_functor.h:112
container of If
Definition: expr.h:396