25 #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ 26 #define TVM_RELAY_EXPR_FUNCTOR_H_ 37 #include <unordered_map> 55 template <
typename FType>
59 #define EXPR_FUNCTOR_DEFAULT \ 60 { return VisitExprDefault_(op, std::forward<Args>(args)...); } 62 #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ 63 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \ 64 return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \ 67 template <
typename R,
typename... Args>
84 R
operator()(
const Expr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...); }
92 ICHECK(n.
defined()) <<
"Found null pointer node while traversing AST. The previous pass may " 93 "have generated invalid data.";
94 static FType vtable = InitVTable();
95 return vtable(n,
this, std::forward<Args>(args)...);
114 LOG(FATAL) <<
"Do not have a default for " << op->
GetTypeKey();
120 static FType InitVTable() {
221 std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>
memo_;
235 using ::tvm::relay::ExprFunctor<void(const Expr& n)>::VisitExpr_;
257 virtual void VisitLeaf(
const Expr& expr);
262 virtual bool CheckVisited(
const Expr& expr);
284 using ::tvm::relay::ExprFunctor<Expr(const Expr&)>::VisitExpr_;
289 virtual Expr DispatchVisitExpr(
const Expr& expr);
310 template <
typename T>
313 return Rewrite_(op, post);
316 virtual void VisitLeaf(
const Expr& expr);
317 virtual bool CheckVisited(
const Expr& expr);
320 #define RELAY_EXPR_REWRITER_DISPATCH(OP) \ 321 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \ 322 return self->Rewrite_(static_cast<const OP*>(n.get()), post); \ 325 #define EXPR_REWRITER_REWRITE_DEFAULT \ 361 static FType vtable = InitVTable();
362 return vtable(pre,
this, post);
384 static FType InitVTable() {
429 : node{node_}, children_expanded{children_expanded_} {};
431 bool children_expanded{
false};
452 template <
typename FCheckVisited,
typename FVisitLeaf,
typename FExpandExpr>
454 FExpandExpr fexpand_expr) {
455 std::deque<v_info>
stack;
456 auto fpush_to_stack = [&fcheck_visited, &
stack](
const Expr& expr) {
457 if (!fcheck_visited(expr)) {
458 stack.emplace_front(
v_info(expr));
462 fpush_to_stack(expr);
463 while (stack.size() > 0) {
464 v_info* front = &stack.front();
465 if (fcheck_visited(front->
node)) {
468 fvisit_leaf(front->
node);
473 for (
auto e : fexpand_expr(front->
node)) {
480 template <
typename FCheckVisited,
typename FVisitLeaf>
482 auto fexpand_expr = [](
const Expr& expr) {
483 std::vector<Expr> result;
485 if (op->op ==
Op::Get(
"call_lowered")) {
487 const auto* tuple_args = op->args[1].as<
TupleNode>();
489 <<
"Expected second arg to call_lowered to be a Tuple of input arguments.";
490 for (
auto it = tuple_args->fields.rbegin(); it != tuple_args->fields.rend(); ++it) {
491 result.push_back(*it);
493 result.push_back(op->args[0]);
495 for (
auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
496 result.push_back(*it);
499 result.push_back(op->op);
501 for (
auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
502 result.push_back(*it);
505 result.push_back(op->tuple);
513 std::function<
void(
const LetNode*)> post_visit);
517 #endif // TVM_RELAY_EXPR_FUNCTOR_H_ A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
virtual R VisitExpr_(const RefWriteNode *op, Args... args)
Definition: expr_functor.h:110
Match container node.
Definition: adt.h:277
A wrapper around ExprFunctor which functionally updates the AST.
Definition: expr_functor.h:184
virtual Expr Rewrite_(const TupleNode *pre, const Expr &post)
Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be able to re...
Definition: expr_functor.h:301
Expr VisitExpr_(const TupleGetItemNode *op) final
Definition: expr_functor.h:292
virtual void VisitClause(const Clause &c)
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:62
virtual R VisitExpr_(const LetNode *op, Args... args)
Definition: expr_functor.h:104
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:59
virtual R VisitExpr(const Expr &n, Args... args)
The functor call.
Definition: expr_functor.h:91
Call container.
Definition: expr.h:282
virtual Expr Rewrite_(const TupleNode *pre, const Expr &post)
Definition: expr_functor.h:368
virtual R VisitExpr_(const RefCreateNode *op, Args... args)
Definition: expr_functor.h:108
ADT constructor. Constructors compare by pointer equality.
Definition: adt.h:47
Relay expression language.
virtual R VisitExpr_(const TupleNode *op, Args... args)
Definition: expr_functor.h:99
virtual Expr Rewrite_(const MatchNode *pre, const Expr &post)
Definition: expr_functor.h:380
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Expr PostOrderRewrite(const Expr &expr, ExprRewriter *rewriter)
Non-recursive DFS Graph Traversal for Custom Rewriting Passes.
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:113
#define EXPR_REWRITER_REWRITE_DEFAULT
Definition: expr_functor.h:325
R result_type
the result type of this functor
Definition: expr_functor.h:75
virtual Expr Rewrite_(const CallNode *pre, const Expr &post)
Definition: expr_functor.h:302
virtual Expr Rewrite_(const IfNode *pre, const Expr &post)
Definition: expr_functor.h:372
virtual void VisitType(const Type &t)
virtual Expr Rewrite_(const TupleGetItemNode *pre, const Expr &post)
Definition: expr_functor.h:374
Constant tensor type.
Definition: expr.h:71
virtual Expr Rewrite_(const GlobalVarNode *pre, const Expr &post)
Definition: expr_functor.h:366
void PostOrderVisit(const Expr &node, std::function< void(const Expr &)> fvisit)
recursively visit the ir in post DFS order node, apply fvisit Each node is guaranteed to be visited o...
std::unordered_map< const Object *, size_t > visit_counter_
Definition: expr_functor.h:174
Primitive Op(builtin intrinsics)
Definition: op.h:58
size_t visit_limit_
The max number of times to visit a node.
Definition: expr_functor.h:266
Expr VisitExpr_(const CallNode *call_node) final
Definition: expr_functor.h:291
virtual R VisitExpr_(const IfNode *op, Args... args)
Definition: expr_functor.h:105
void VisitExpr(const Expr &expr) override
base class of all object containers.
Definition: object.h:167
bool children_expanded
Definition: expr_functor.h:431
Expr VisitExpr_(const TupleNode *op) final
Definition: expr_functor.h:290
Container for Var.
Definition: expr.h:188
virtual Expr Rewrite_(const ConstructorNode *pre, const Expr &post)
Definition: expr_functor.h:379
MixedModeMutator(bool pre=false)
Definition: expr_functor.h:286
v_info(Expr node_, bool children_expanded_)
Definition: expr_functor.h:428
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:100
virtual R VisitExpr_(const RefReadNode *op, Args... args)
Definition: expr_functor.h:109
virtual void VisitSpan(const Span &span)
virtual R VisitExpr_(const FunctionNode *op, Args... args)
Definition: expr_functor.h:102
Algebraic data types for Relay.
Expr VisitExpr_(const VarNode *op) override
Definition: source_map.h:120
virtual Expr Rewrite_(const RefWriteNode *pre, const Expr &post)
Definition: expr_functor.h:378
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf, FExpandExpr fexpand_expr)
A function to iteratively traverse dataflow regions of a graph.
Definition: expr_functor.h:453
A binding of a sub-network.
Definition: expr.h:404
void ExpandANormalForm(const LetNode *op, std::function< void(const LetNode *)> pre_visit, std::function< void(const LetNode *)> post_visit)
Definition: expr_functor.h:68
bool defined() const
Definition: object.h:544
Tensor stack(const Array< Tensor > &inputs, int axis=0, std::string name="T_stack", std::string tag=kInjective)
Join a sequence of tensors along a new axis.
Definition: transform.h:529
virtual R VisitExpr_(const ConstructorNode *op, Args... args)
Definition: expr_functor.h:111
Relay Function container.
Definition: function.h:39
virtual void VisitPattern(const Pattern &c)
Tuple container.
Definition: expr.h:123
A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
Definition: expr_functor.h:233
R operator()(const Expr &n, Args... args)
Same as call.
Definition: expr_functor.h:84
static const Op & Get(const String &op_name)
Get an Op for a given operator name. Will raise an error if the op has not been registered.
virtual Expr Rewrite_(const OpNode *pre, const Expr &post)
Definition: expr_functor.h:373
Managed reference to RelayExprNode.
Definition: expr.h:433
bool pre_
Definition: expr_functor.h:306
v_info(Expr node_)
Definition: expr_functor.h:427
A struct to keep info of traversed expr in ExpandDataflow function.
Definition: expr_functor.h:426
Expr node
Definition: expr_functor.h:430
virtual R VisitExpr_(const OpNode *op, Args... args)
Definition: expr_functor.h:106
Pattern is the base type for an ADT match pattern in Relay.
Definition: adt.h:63
virtual Expr Rewrite_(const RefReadNode *pre, const Expr &post)
Definition: expr_functor.h:377
Defines the Functor data structures.
virtual Expr Rewrite_(const LetNode *pre, const Expr &post)
Definition: expr_functor.h:371
Base class of all object reference.
Definition: object.h:511
std::string GetTypeKey() const
Definition: object.h:180
virtual R VisitExpr_(const TupleGetItemNode *op, Args... args)
Definition: expr_functor.h:107
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:103
Expr Rewrite(const T *op)
Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a post node with changed inputs...
Definition: expr_functor.h:311
virtual R VisitExpr_(const GlobalVarNode *op, Args... args)
Definition: expr_functor.h:101
virtual Expr Rewrite_(const FunctionNode *pre, const Expr &post)
Definition: expr_functor.h:369
A non-iterating Expression Rewriter.
Definition: expr_functor.h:338
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:77
#define RELAY_EXPR_REWRITER_DISPATCH(OP)
Definition: expr_functor.h:320
Non-recursive DFS Graph Traversal for Custom Rewriting Passes.
Definition: expr_functor.h:282
virtual Expr Rewrite_(const ConstantNode *pre, const Expr &post)
Definition: expr_functor.h:367
A dynamical functor that dispatches on in the first Expr argument. You can use this as a more powerfu...
Definition: expr_functor.h:56
virtual Expr Rewrite_(const RefCreateNode *pre, const Expr &post)
Definition: expr_functor.h:376
Expr operator()(const Expr &pre, const Expr &post)
Same as call.
Definition: expr_functor.h:352
A simple visitor wrapper around ExprFunctor. Recursively visit the content.
Definition: expr_functor.h:149
virtual Expr Rewrite_(const TupleGetItemNode *pre, const Expr &post)
Definition: expr_functor.h:303
Managed reference to TypeNode.
Definition: type.h:93
virtual Expr Rewrite_(const VarNode *pre, const Expr &post)
Definition: expr_functor.h:365
Expr Mutate(const Expr &expr)
Mutate is alias for VisitExpr.
Definition: expr_functor.h:190
std::unordered_map< Expr, Expr, ObjectPtrHash, ObjectPtrEqual > memo_
Internal map used for memoization.
Definition: expr_functor.h:221
Global variable that lives in the top-level module.
Definition: expr.h:447
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
virtual R VisitExpr_(const ConstantNode *op, Args... args)
Definition: expr_functor.h:98
virtual Expr Rewrite(const Expr &pre, const Expr &post)
The functor call.
Definition: expr_functor.h:359
void VisitExpr_(const VarNode *op) override
Primitive operators(builtin intrinsics).
virtual Expr Rewrite_(const CallNode *pre, const Expr &post)
Definition: expr_functor.h:370
virtual ~ExprRewriter()
virtual destructor
Definition: expr_functor.h:345
virtual R VisitExpr_(const MatchNode *op, Args... args)
Definition: expr_functor.h:112
container of If
Definition: expr.h:491