tvm
expr_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_TIR_EXPR_FUNCTOR_H_
26 #define TVM_TIR_EXPR_FUNCTOR_H_
27 
28 #include <tvm/node/functor.h>
29 #include <tvm/tir/expr.h>
30 
31 #include <utility>
32 
33 namespace tvm {
34 namespace tir {
35 
/*!
 * \brief A dynamic functor that dispatches on the first PrimExpr argument.
 *
 * Only the partial specialization for function signatures of the form
 * R(const PrimExpr&, Args...) is defined below; this primary template is
 * declared here (and not defined) so that the specialization is well-formed.
 * \tparam FType The function signature of the functor.
 */
template <typename FType>
class ExprFunctor;

// Shared body for the overridable VisitExpr_ hooks: hand the node `op` and any
// extra arguments to VisitExprDefault_. Subclasses override the hooks they need.
#define EXPR_FUNCTOR_DEFAULT                                   \
  {                                                            \
    return VisitExprDefault_(op, std::forward<Args>(args)...); \
  }
80 
// Register, in the enclosing local `vtable`, a handler for node type OP. The
// handler casts the type-erased node back to `const OP*` and forwards it, plus
// any extra arguments, to the matching VisitExpr_ overload of the functor.
#define IR_EXPR_FUNCTOR_DISPATCH(OP)                                         \
  vtable.template set_dispatch<OP>(                                          \
      [](const ObjectRef& node, TSelf* functor, Args... args) {              \
        const OP* typed_node = static_cast<const OP*>(node.get());           \
        return functor->VisitExpr_(typed_node, std::forward<Args>(args)...); \
      });
85 
86 template <typename R, typename... Args>
87 class ExprFunctor<R(const PrimExpr& n, Args...)> {
88  private:
89  using TSelf = ExprFunctor<R(const PrimExpr& n, Args...)>;
90  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
91 
92  public:
94  using result_type = R;
96  virtual ~ExprFunctor() {}
103  R operator()(const PrimExpr& n, Args... args) {
104  return VisitExpr(n, std::forward<Args>(args)...);
105  }
112  virtual R VisitExpr(const PrimExpr& n, Args... args) {
113  static FType vtable = InitVTable();
114  return vtable(n, this, std::forward<Args>(args)...);
115  }
116  // Functions that can be overriden by subclass
117  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
118  virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
119  return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
120  }
121  virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
122  virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
123  virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
124  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
125  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
126  virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
127  virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
128  virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
129  virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
130  virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
131  virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
132  virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
133  virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
134  virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
135  virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
136  virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
137  virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
138  virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
139  virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
140  virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
141  virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
142  virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
143  virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
144  virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
145  virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
146  virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
147  virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
148  virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
149  virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
150  virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
151  virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
152  virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
153  virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
154  virtual R VisitExprDefault_(const Object* op, Args...) {
155  LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
156  return R();
157  }
158 
159  private:
160  // initialize the vtable.
161  static FType InitVTable() {
162  FType vtable;
163  // Set dispatch
199  return vtable;
200  }
201 };
202 
203 #undef IR_EXPR_FUNCTOR_DISPATCH
204 #undef EXPR_FUNCTOR_DEFAULT
205 
209 class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
210  public:
211  using ExprFunctor::operator();
212 
213  protected:
214  using ExprFunctor::VisitExpr;
215  // list of functions to override.
216  void VisitExpr_(const VarNode* op) override;
217  void VisitExpr_(const SizeVarNode* op) override;
218  void VisitExpr_(const LoadNode* op) override;
219  void VisitExpr_(const BufferLoadNode* op) override;
220  void VisitExpr_(const ProducerLoadNode* op) override;
221  void VisitExpr_(const LetNode* op) override;
222  void VisitExpr_(const CallNode* op) override;
223  void VisitExpr_(const AddNode* op) override;
224  void VisitExpr_(const SubNode* op) override;
225  void VisitExpr_(const MulNode* op) override;
226  void VisitExpr_(const DivNode* op) override;
227  void VisitExpr_(const ModNode* op) override;
228  void VisitExpr_(const FloorDivNode* op) override;
229  void VisitExpr_(const FloorModNode* op) override;
230  void VisitExpr_(const MinNode* op) override;
231  void VisitExpr_(const MaxNode* op) override;
232  void VisitExpr_(const EQNode* op) override;
233  void VisitExpr_(const NENode* op) override;
234  void VisitExpr_(const LTNode* op) override;
235  void VisitExpr_(const LENode* op) override;
236  void VisitExpr_(const GTNode* op) override;
237  void VisitExpr_(const GENode* op) override;
238  void VisitExpr_(const AndNode* op) override;
239  void VisitExpr_(const OrNode* op) override;
240  void VisitExpr_(const ReduceNode* op) override;
241  void VisitExpr_(const CastNode* op) override;
242  void VisitExpr_(const NotNode* op) override;
243  void VisitExpr_(const SelectNode* op) override;
244  void VisitExpr_(const RampNode* op) override;
245  void VisitExpr_(const BroadcastNode* op) override;
246  void VisitExpr_(const ShuffleNode* op) override;
247  void VisitExpr_(const IntImmNode* op) override;
248  void VisitExpr_(const FloatImmNode* op) override;
249  void VisitExpr_(const StringImmNode* op) override;
250  void VisitExpr_(const AnyNode* op) override;
251 };
252 
256 class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
257  public:
258  using ExprFunctor::operator();
259 
260  protected:
261  using ExprFunctor::VisitExpr;
262  // list of functions to override.
263  PrimExpr VisitExpr_(const VarNode* op) override;
264  PrimExpr VisitExpr_(const SizeVarNode* op) override;
265  PrimExpr VisitExpr_(const LoadNode* op) override;
266  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
267  PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
268  PrimExpr VisitExpr_(const LetNode* op) override;
269  PrimExpr VisitExpr_(const CallNode* op) override;
270  PrimExpr VisitExpr_(const AddNode* op) override;
271  PrimExpr VisitExpr_(const SubNode* op) override;
272  PrimExpr VisitExpr_(const MulNode* op) override;
273  PrimExpr VisitExpr_(const DivNode* op) override;
274  PrimExpr VisitExpr_(const ModNode* op) override;
275  PrimExpr VisitExpr_(const FloorDivNode* op) override;
276  PrimExpr VisitExpr_(const FloorModNode* op) override;
277  PrimExpr VisitExpr_(const MinNode* op) override;
278  PrimExpr VisitExpr_(const MaxNode* op) override;
279  PrimExpr VisitExpr_(const EQNode* op) override;
280  PrimExpr VisitExpr_(const NENode* op) override;
281  PrimExpr VisitExpr_(const LTNode* op) override;
282  PrimExpr VisitExpr_(const LENode* op) override;
283  PrimExpr VisitExpr_(const GTNode* op) override;
284  PrimExpr VisitExpr_(const GENode* op) override;
285  PrimExpr VisitExpr_(const AndNode* op) override;
286  PrimExpr VisitExpr_(const OrNode* op) override;
287  PrimExpr VisitExpr_(const ReduceNode* op) override;
288  PrimExpr VisitExpr_(const CastNode* op) override;
289  PrimExpr VisitExpr_(const NotNode* op) override;
290  PrimExpr VisitExpr_(const SelectNode* op) override;
291  PrimExpr VisitExpr_(const RampNode* op) override;
292  PrimExpr VisitExpr_(const BroadcastNode* op) override;
293  PrimExpr VisitExpr_(const ShuffleNode* op) override;
294  PrimExpr VisitExpr_(const IntImmNode* op) override;
295  PrimExpr VisitExpr_(const FloatImmNode* op) override;
296  PrimExpr VisitExpr_(const StringImmNode* op) override;
297  PrimExpr VisitExpr_(const AnyNode* op) override;
298 };
299 
300 } // namespace tir
301 } // namespace tvm
302 #endif // TVM_TIR_EXPR_FUNCTOR_H_
virtual R VisitExpr_(const MinNode *op, Args... args)
Definition: expr_functor.h:133
Let binding. Bind var to value then evaluate body.
Definition: expr.h:864
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
virtual R VisitExpr_(const IntImmNode *op, Args... args)
Definition: expr_functor.h:150
virtual R VisitExpr_(const OrNode *op, Args... args)
Definition: expr_functor.h:142
virtual R VisitExpr_(const GENode *op, Args... args)
Definition: expr_functor.h:140
virtual R VisitExpr_(const BroadcastNode *op, Args... args)
Definition: expr_functor.h:148
virtual R VisitExpr_(const NotNode *op, Args... args)
Definition: expr_functor.h:145
virtual R VisitExpr_(const ShuffleNode *op, Args... args)
Definition: expr_functor.h:149
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
ExprVisitor.
Definition: expr_functor.h:209
virtual R VisitExpr_(const FloorModNode *op, Args... args)
Definition: expr_functor.h:132
virtual R VisitExpr_(const ReduceNode *op, Args... args)
Definition: expr_functor.h:143
String constants, only used in asserts.
Definition: expr.h:53
Constant floating point literals in the program.
Definition: expr.h:535
virtual R VisitExpr_(const LetNode *op, Args... args)
Definition: expr_functor.h:124
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:96
virtual R VisitExpr_(const FloatImmNode *op, Args... args)
Definition: expr_functor.h:151
A variable node in the IR.
Definition: var.h:47
a * b
Definition: expr.h:187
base class of all object containers.
Definition: object.h:167
Any shape.
Definition: expr.h:1130
Shuffle instruction: vec = concat(vectors); result = (vec[indices[0]], vec[indices[1]], ...).
Definition: expr.h:958
A variable node representing a tensor index size, whose value must be non-negative.
Definition: var.h:137
virtual R VisitExpr_(const LENode *op, Args... args)
Definition: expr_functor.h:138
virtual R VisitExpr_(const AddNode *op, Args... args)
Definition: expr_functor.h:126
a + b
Definition: expr.h:155
Constant integer literals in the program.
Definition: expr.h:489
virtual R VisitExpr_(const SubNode *op, Args... args)
Definition: expr_functor.h:127
a || b
Definition: expr.h:472
virtual R VisitExpr_(const CastNode *op, Args... args)
Definition: expr_functor.h:144
virtual R VisitExpr_(const ModNode *op, Args... args)
Definition: expr_functor.h:130
TIR expressions.
virtual R VisitExpr_(const LoadNode *op, Args... args)
Definition: expr_functor.h:123
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:117
virtual R VisitExpr_(const DivNode *op, Args... args)
Definition: expr_functor.h:129
virtual R VisitExpr_(const LTNode *op, Args... args)
Definition: expr_functor.h:137
#define IR_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:81
Create a vector where all the elements are value.
Definition: expr.h:823
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:125
a > b
Definition: expr.h:401
Cast value from one data type to another.
Definition: expr.h:88
virtual R VisitExpr_(const EQNode *op, Args... args)
Definition: expr_functor.h:135
virtual R VisitExpr_(const SelectNode *op, Args... args)
Definition: expr_functor.h:146
virtual R VisitExpr_(const SizeVarNode *op, Args... args)
Definition: expr_functor.h:118
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:154
virtual R VisitExpr_(const AnyNode *op, Args... args)
Definition: expr_functor.h:153
Load the value from buffer_var.
Definition: expr.h:726
Defines the Functor data structures.
virtual R VisitExpr_(const GTNode *op, Args... args)
Definition: expr_functor.h:139
Base class of all object references.
Definition: object.h:511
!a
Definition: expr.h:511
max(a, b)
Definition: expr.h:289
std::string GetTypeKey() const
Definition: object.h:180
R result_type
the result type of this functor
Definition: expr_functor.h:94
virtual R VisitExpr_(const MulNode *op, Args... args)
Definition: expr_functor.h:128
Construct a vector with lanes elements where its i-th element equals base + i * stride. This is useful to construct an index for a contiguous vector load.
Definition: expr.h:779
virtual R VisitExpr_(const AndNode *op, Args... args)
Definition: expr_functor.h:141
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:78
virtual R VisitExpr_(const RampNode *op, Args... args)
Definition: expr_functor.h:147
min(a, b)
Definition: expr.h:273
The remainder of the floordiv.
Definition: expr.h:257
a == b
Definition: expr.h:337
virtual R VisitExpr(const PrimExpr &n, Args... args)
The functor call.
Definition: expr_functor.h:112
a && b
Definition: expr.h:433
virtual R VisitExpr_(const StringImmNode *op, Args... args)
Definition: expr_functor.h:152
a < b
Definition: expr.h:369
Load a value from a high-dimensional buffer.
Definition: expr.h:606
virtual R VisitExpr_(const ProducerLoadNode *op, Args... args)
Definition: expr_functor.h:122
a % b in the C semantics.
Definition: expr.h:225
Reference to PrimExprNode.
Definition: expr.h:112
virtual R VisitExpr_(const FloorDivNode *op, Args... args)
Definition: expr_functor.h:131
Floor division, floor(a/b)
Definition: expr.h:241
a - b
Definition: expr.h:171
A dynamic functor that dispatches on the first Expr argument. You can use this as a more powerful visitor...
Definition: expr_functor.h:75
R operator()(const PrimExpr &n, Args... args)
Same as call.
Definition: expr_functor.h:103
Call node.
Definition: expr.h:910
return true_value if condition is true, otherwise return false_value.
Definition: expr.h:552
Reduction operator.
Definition: expr.h:1066
a / b in the C semantics.
Definition: expr.h:206
a <= b
Definition: expr.h:385
Load value from the result produced by the producer.
Definition: expr.h:671
virtual R VisitExpr_(const MaxNode *op, Args... args)
Definition: expr_functor.h:134
virtual R VisitExpr_(const NENode *op, Args... args)
Definition: expr_functor.h:136
a != b
Definition: expr.h:353
virtual R VisitExpr_(const BufferLoadNode *op, Args... args)
Definition: expr_functor.h:121
ExprMutator that mutates expressions.
Definition: expr_functor.h:256
a >= b
Definition: expr.h:417