tvm
pattern_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
26 #define TVM_RELAY_PATTERN_FUNCTOR_H_
27 
28 #include <tvm/node/functor.h>
29 #include <tvm/relay/error.h>
30 
31 #include <string>
32 #include <unordered_map>
33 #include <utility>
34 
35 #include "./adt.h"
36 #include "./expr.h"
37 #include "./op.h"
38 
39 namespace tvm {
40 namespace relay {
41 
/*!
 * \brief A dynamic functor on ADT patterns that dispatches on the runtime type
 *  of its first argument. It can be used as a more flexible visitor: any extra
 *  arguments given after the pattern are passed through to every VisitPattern_
 *  overload.
 *
 * \tparam FType Function signature. Only the partial specialization for
 *  signatures of the form R(const Pattern&, Args...) is defined; this primary
 *  template is intentionally left undefined.
 */
template <typename FType>
class PatternFunctor;
// Default body for the VisitPattern_ overloads that subclasses may override:
// hands the node pointer `op` and the extra arguments `args` (of pack type
// `Args`) to VisitPatternDefault_. All three names must be in scope wherever
// the macro is expanded.
#define PATTERN_FUNCTOR_DEFAULT                                   \
  {                                                               \
    return VisitPatternDefault_(op, std::forward<Args>(args)...); \
  }
59 
// Registers a handler for pattern node type OP in the local `vtable`. The
// handler downcasts the type-erased ObjectRef to `const OP*` and calls the
// matching VisitPattern_ overload on the functor, forwarding extra arguments.
// `vtable`, `TSelf` and the pack `Args` must be in scope at the expansion site.
#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP)                                           \
  vtable.template set_dispatch<OP>(                                                  \
      [](const ObjectRef& node, TSelf* functor, Args... call_args) {                \
        const OP* typed_node = static_cast<const OP*>(node.get());                   \
        return functor->VisitPattern_(typed_node, std::forward<Args>(call_args)...); \
      });
64 
65 template <typename R, typename... Args>
66 class PatternFunctor<R(const Pattern& n, Args...)> {
67  private:
68  using TSelf = PatternFunctor<R(const Pattern& n, Args...)>;
69  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
70 
71  public:
73  using result_type = R;
75  virtual ~PatternFunctor() {}
82  R operator()(const Pattern& n, Args... args) {
83  return VisitPattern(n, std::forward<Args>(args)...);
84  }
91  virtual R VisitPattern(const Pattern& n, Args... args) {
92  ICHECK(n.defined());
93  static FType vtable = InitVTable();
94  return vtable(n, this, std::forward<Args>(args)...);
95  }
96  // Functions that can be overriden by subclass
97  virtual R VisitPattern_(const PatternWildcardNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
98  virtual R VisitPattern_(const PatternVarNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
99  virtual R VisitPattern_(const PatternConstructorNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
100  virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
101  virtual R VisitPatternDefault_(const Object* op, Args...) {
102  LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
103  throw;
104  }
105 
106  private:
107  // initialize the vtable.
108  static FType InitVTable() {
109  FType vtable;
110  // Set dispatch
115  return vtable;
116  }
117 };
118 
125 class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n)> {
126  public:
127  void VisitPattern_(const PatternWildcardNode* op) override;
128  void VisitPattern_(const PatternVarNode* op) override;
129  void VisitPattern_(const PatternConstructorNode* op) override;
130  void VisitPattern_(const PatternTupleNode* op) override;
131  virtual void VisitType(const Type& t);
132  virtual void VisitVar(const Var& v);
133  virtual void VisitConstructor(const Constructor& c);
134 };
135 
141 class PatternMutator : public ::tvm::relay::PatternFunctor<Pattern(const Pattern&)> {
142  public:
143  Pattern Mutate(const Pattern& pat);
145  Pattern VisitPattern_(const PatternVarNode* op) override;
154  virtual Type VisitType(const Type& t);
156  virtual Var VisitVar(const Var& v);
159 
160  private:
161  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_map_;
162 };
163 
164 } // namespace relay
165 } // namespace tvm
166 #endif // TVM_RELAY_PATTERN_FUNCTOR_H_
Managed reference to ConstructorNode.
Definition: adt.h:88
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Managed reference to TypeNode.
Definition: type.h:93
PatternVar container node.
Definition: adt.h:150
virtual R VisitPattern_(const PatternVarNode *op, Args... args)
Definition: pattern_functor.h:98
virtual ~PatternFunctor()
virtual destructor
Definition: pattern_functor.h:75
R operator()(const Pattern &n, Args... args)
Same as call.
Definition: pattern_functor.h:82
virtual R VisitPattern_(const PatternWildcardNode *op, Args... args)
Definition: pattern_functor.h:97
virtual R VisitPattern_(const PatternTupleNode *op, Args... args)
Definition: pattern_functor.h:100
R result_type
the result type of this functor
Definition: pattern_functor.h:73
virtual R VisitPattern(const Pattern &n, Args... args)
The functor call.
Definition: pattern_functor.h:91
virtual R VisitPatternDefault_(const Object *op, Args...)
Definition: pattern_functor.h:101
virtual R VisitPattern_(const PatternConstructorNode *op, Args... args)
Definition: pattern_functor.h:99
A dynamical functor on ADT patterns that dispatches on its first argument. You can use this as a more...
Definition: pattern_functor.h:54
A wrapper around PatternFunctor which functionally updates patterns.
Definition: pattern_functor.h:141
virtual Constructor VisitConstructor(const Constructor &c)
Used to visit the constructors inside of patterns.
virtual Type VisitType(const Type &t)
Used to visit the types inside of patterns.
Pattern VisitPattern_(const PatternVarNode *op) override
Pattern VisitPattern_(const PatternConstructorNode *op) override
virtual Var VisitVar(const Var &v)
Used to visit the vars inside of patterns.
Pattern VisitPattern_(const PatternWildcardNode *op) override
Pattern VisitPattern_(const PatternTupleNode *op) override
Pattern Mutate(const Pattern &pat)
PatternVar container node.
Definition: adt.h:191
PatternVar container node.
Definition: adt.h:116
A simple visitor wrapper around PatternFunctor.
Definition: pattern_functor.h:125
void VisitPattern_(const PatternTupleNode *op) override
void VisitPattern_(const PatternConstructorNode *op) override
virtual void VisitVar(const Var &v)
void VisitPattern_(const PatternWildcardNode *op) override
virtual void VisitType(const Type &t)
virtual void VisitConstructor(const Constructor &c)
void VisitPattern_(const PatternVarNode *op) override
PatternWildcard container node.
Definition: adt.h:74
Pattern is the base type for an ADT match pattern in Relay.
Definition: adt.h:63
Definition: expr.h:234
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
Defines the Functor data structures.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP)
Definition: pattern_functor.h:60
#define PATTERN_FUNCTOR_DEFAULT
Definition: pattern_functor.h:57
Runtime ADT container types.
TIR expressions.
Common operators defined for Expr.