tvm
dataflow_pattern_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_
25 #define TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_
26 
28 
29 #include <unordered_set>
30 #include <utility>
31 
32 namespace tvm {
33 namespace relax {
34 
/*!
 * \brief A dynamic functor that dispatches on the first DFPattern argument.
 *
 * Only a partial specialization for function signatures of the form
 * R(const DFPattern&, Args...) is provided; this primary template is just a
 * declaration that the specialization refines.
 *
 * \tparam FType The function signature of the functor.
 */
template <typename FType>
class DFPatternFunctor;
// Default body for the VisitDFPattern_ overloads that subclasses may override:
// forwards the node pointer `op` and any extra arguments to
// VisitDFPatternDefault_.
#define DFPATTERN_FUNCTOR_DEFAULT \
  { return VisitDFPatternDefault_(op, std::forward<Args>(args)...); }
50 
// Registers in `vtable` a handler for node type OP. The handler downcasts the
// incoming ObjectRef to `const OP*` and calls the matching VisitDFPattern_
// overload on `self`. Expects `vtable`, `TSelf` and `Args` to be in scope at
// the expansion site.
#define RELAX_DFPATTERN_FUNCTOR_DISPATCH(OP)                                              \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* self, Args... args) { \
    const OP* typed_node = static_cast<const OP*>(node.get());                            \
    return self->VisitDFPattern_(typed_node, std::forward<Args>(args)...);                \
  });
55 
// Partial specialization of DFPatternFunctor for signatures
// R(const DFPattern&, Args...): dispatches on the runtime node type of the
// first argument through a tvm::NodeFunctor vtable.
56 template <typename R, typename... Args>
57 class DFPatternFunctor<R(const DFPattern& n, Args...)> {
58  private:
  // TSelf names this specialization. FType is the vtable type: a NodeFunctor
  // that dispatches on its first (ObjectRef) argument and is handed `this`.
59  using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;
60  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
61 
62  public:
  // Virtual destructor.
64  virtual ~DFPatternFunctor() {}
  // Same as call: forwards to VisitDFPattern.
71  R operator()(const DFPattern& n, Args... args) {
72  return VisitDFPattern(n, std::forward<Args>(args)...);
73  }
  // The functor call. Requires `n` to be defined, then dispatches through a
  // function-local static vtable built by InitVTable() on the first call
  // (C++11 guarantees that initializer runs once, even under concurrent calls).
80  virtual R VisitDFPattern(const DFPattern& n, Args... args) {
81  TVM_FFI_ICHECK(n.defined());
82  static FType vtable = InitVTable();
83  return vtable(n, this, std::forward<Args>(args)...);
84  }
85  // Functions that can be overridden by subclass; each defaults to
  // VisitDFPatternDefault_ through DFPATTERN_FUNCTOR_DEFAULT.
  // NOTE(review): this extracted listing is missing some overload
  // declarations (listing lines 91-92, 94, 96, 99, 101, 104, 106-107, 109-110)
  // and therefore the first line of each wrapped declaration below.
86  virtual R VisitDFPattern_(const OrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
87  virtual R VisitDFPattern_(const AndPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
88  virtual R VisitDFPattern_(const NotPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
89  virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
90  virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
93  virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
95  virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
97  Args... args) DFPATTERN_FUNCTOR_DEFAULT;
98  virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
100  Args... args) DFPATTERN_FUNCTOR_DEFAULT;
102  virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
103 
105  Args... args) DFPATTERN_FUNCTOR_DEFAULT;
108  Args... args) DFPATTERN_FUNCTOR_DEFAULT;
111  Args... args) DFPATTERN_FUNCTOR_DEFAULT;
112 
  // Fallback used by every overload a subclass does not override: reports an
  // InternalError naming the node's type key.
113  virtual R VisitDFPatternDefault_(const Object* op, Args...) {
114  TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey();
  // NOTE(review): presumably unreachable because TVM_FFI_THROW throws. A bare
  // `throw;` with no active exception calls std::terminate, so control never
  // falls off the end of this value-returning function.
115  throw;
116  }
117 
118  private:
119  // Build the dispatch vtable; called once by VisitDFPattern to initialize its static table.
120  static FType InitVTable() {
121  FType vtable;
122  // Set dispatch
  // NOTE(review): listing lines 123-142 are elided here; presumably one
  // RELAX_DFPATTERN_FUNCTOR_DISPATCH(...) registration per pattern node type.
143  vtable.Finalize();
144  return vtable;
145  }
146 };
147 
154 class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
155  public:
156  void VisitDFPattern(const DFPattern& pattern) override;
157  void VisitDFPattern_(const OrPatternNode* op) override;
158  void VisitDFPattern_(const AndPatternNode* op) override;
159  void VisitDFPattern_(const NotPatternNode* op) override;
160  void VisitDFPattern_(const AttrPatternNode* op) override;
161  void VisitDFPattern_(const CallPatternNode* op) override;
162  void VisitDFPattern_(const ConstantPatternNode* op) override;
163  void VisitDFPattern_(const DataTypePatternNode* op) override;
164  void VisitDFPattern_(const ExprPatternNode* op) override;
165  void VisitDFPattern_(const FunctionPatternNode* op) override;
166  void VisitDFPattern_(const ShapePatternNode* op) override;
167  void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
168  void VisitDFPattern_(const TuplePatternNode* op) override;
169  void VisitDFPattern_(const StructInfoPatternNode* op) override;
170  void VisitDFPattern_(const WildcardPatternNode* op) override;
171  void VisitDFPattern_(const VarPatternNode* op) override;
172 
173  void VisitDFPattern_(const DataflowVarPatternNode* op) override;
174  void VisitDFPattern_(const GlobalVarPatternNode* op) override;
175  void VisitDFPattern_(const ExternFuncPatternNode* op) override;
176  void VisitDFPattern_(const PrimArrPatternNode* op) override;
177  void VisitDFPattern_(const UnorderedTuplePatternNode* op) override;
178 
179  protected:
180  // set of already-visited nodes
181  std::unordered_set<const Object*> visited_;
182 };
183 
184 } // namespace relax
185 } // namespace tvm
186 #endif // TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:65
Match a conjunction of other patterns.
Definition: dataflow_pattern.h:659
A pattern asserting that a root pattern has certain attributes.
Definition: dataflow_pattern.h:888
A pattern to match a callable node in Relax.
Definition: dataflow_pattern.h:469
A Pattern to Match a Relax Constant.
Definition: dataflow_pattern.h:446
Definition: dataflow_pattern_functor.h:57
R operator()(const DFPattern &n, Args... args)
Same as call.
Definition: dataflow_pattern_functor.h:71
virtual R VisitDFPattern_(const ShapePatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:95
virtual R VisitDFPattern_(const PrimArrPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:109
virtual R VisitDFPattern_(const ConstantPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:91
virtual R VisitDFPattern_(const GlobalVarPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:106
virtual R VisitDFPattern_(const OrPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:86
virtual R VisitDFPattern_(const ExprPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:93
virtual R VisitDFPattern_(const UnorderedTuplePatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:110
virtual R VisitDFPattern_(const StructInfoPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:99
virtual R VisitDFPattern_(const ExternFuncPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:107
virtual R VisitDFPattern_(const VarPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:102
virtual R VisitDFPattern_(const NotPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:88
virtual R VisitDFPattern_(const DataflowVarPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:104
virtual R VisitDFPattern_(const WildcardPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:101
virtual R VisitDFPattern_(const TupleGetItemPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:96
virtual R VisitDFPatternDefault_(const Object *op, Args...)
Definition: dataflow_pattern_functor.h:113
virtual R VisitDFPattern_(const AttrPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:89
virtual R VisitDFPattern_(const AndPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:87
virtual R VisitDFPattern_(const CallPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:90
virtual R VisitDFPattern(const DFPattern &n, Args... args)
The functor call.
Definition: dataflow_pattern_functor.h:80
virtual R VisitDFPattern_(const FunctionPatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:94
virtual R VisitDFPattern_(const DataTypePatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:92
virtual ~DFPatternFunctor()
virtual destructor
Definition: dataflow_pattern_functor.h:64
virtual R VisitDFPattern_(const TuplePatternNode *op, Args... args)
Definition: dataflow_pattern_functor.h:98
A dynamic functor that dispatches on the first DFPattern argument.
Definition: dataflow_pattern_functor.h:43
A simple visitor wrapper around DFPatternFunctor. Recursively visit the content.
Definition: dataflow_pattern_functor.h:154
void VisitDFPattern_(const NotPatternNode *op) override
void VisitDFPattern_(const ShapePatternNode *op) override
void VisitDFPattern_(const FunctionPatternNode *op) override
void VisitDFPattern_(const OrPatternNode *op) override
void VisitDFPattern_(const ExternFuncPatternNode *op) override
void VisitDFPattern_(const AndPatternNode *op) override
void VisitDFPattern(const DFPattern &pattern) override
void VisitDFPattern_(const TuplePatternNode *op) override
void VisitDFPattern_(const VarPatternNode *op) override
void VisitDFPattern_(const AttrPatternNode *op) override
void VisitDFPattern_(const ExprPatternNode *op) override
void VisitDFPattern_(const ConstantPatternNode *op) override
void VisitDFPattern_(const UnorderedTuplePatternNode *op) override
void VisitDFPattern_(const TupleGetItemPatternNode *op) override
void VisitDFPattern_(const DataTypePatternNode *op) override
void VisitDFPattern_(const WildcardPatternNode *op) override
void VisitDFPattern_(const GlobalVarPatternNode *op) override
void VisitDFPattern_(const PrimArrPatternNode *op) override
void VisitDFPattern_(const DataflowVarPatternNode *op) override
void VisitDFPattern_(const CallPatternNode *op) override
void VisitDFPattern_(const StructInfoPatternNode *op) override
std::unordered_set< const Object * > visited_
Definition: dataflow_pattern_functor.h:181
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:101
A pattern asserting that a root pattern has a certain data type.
Definition: dataflow_pattern.h:859
A Pattern to Match a Relax Dataflow Variable.
Definition: dataflow_pattern.h:401
Pattern for Relax Expression.
Definition: dataflow_pattern.h:342
A pattern of external function.
Definition: dataflow_pattern.h:917
A pattern to match a Relax Function.
Definition: dataflow_pattern.h:534
A Pattern to Match a Relax Global Variable.
Definition: dataflow_pattern.h:426
Pattern for rejecting a certain pattern.
Definition: dataflow_pattern.h:715
Match a disjunction of other patterns.
Definition: dataflow_pattern.h:687
A pattern to match an array of PrimExpr.
Definition: dataflow_pattern.h:508
A pattern asserting that a root pattern has a certain shape.
Definition: dataflow_pattern.h:799
Pattern for matching a certain struct info.
Definition: dataflow_pattern.h:774
A pattern to match indexing of the n-th element of a tuple.
Definition: dataflow_pattern.h:629
Pattern to match a tuple of ordered expressions.
Definition: dataflow_pattern.h:575
A pattern to match multiple expressions in any order.
Definition: dataflow_pattern.h:600
A Pattern to Match a Relax Variable.
Definition: dataflow_pattern.h:368
Wildcard Pattern is a pattern that can match anything.
Definition: dataflow_pattern.h:740
A pattern language for matching dataflow properties.
#define RELAX_DFPATTERN_FUNCTOR_DISPATCH(OP)
Definition: dataflow_pattern_functor.h:51
#define DFPATTERN_FUNCTOR_DEFAULT
Definition: dataflow_pattern_functor.h:46
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37