tvm
expr_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_TIR_EXPR_FUNCTOR_H_
26 #define TVM_TIR_EXPR_FUNCTOR_H_
27 
28 #include <tvm/node/functor.h>
29 #include <tvm/tir/expr.h>
30 
31 #include <utility>
32 
33 namespace tvm {
34 namespace tir {
35 
74 template <typename FType>
76 
77 // Helper macros for the functions to be overridden.
//
// EXPR_FUNCTOR_DEFAULT expands to a complete function body that forwards
// `op` and the extra arguments to VisitExprDefault_. It relies on `op`,
// `args` and the parameter pack `Args` being in scope where it is expanded.
78 #define EXPR_FUNCTOR_DEFAULT \
79  { return VisitExprDefault_(op, std::forward<Args>(args)...); }
80 
// IR_EXPR_FUNCTOR_DISPATCH(OP) registers a handler for node type OP on a
// variable named `vtable`. The handler downcasts the incoming ObjectRef's
// underlying object to `const OP*` with static_cast and calls the matching
// self->VisitExpr_ overload, forwarding the extra arguments. It relies on
// `vtable`, `TSelf` and `Args` being in scope where it is expanded.
81 #define IR_EXPR_FUNCTOR_DISPATCH(OP) \
82  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
83  return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
84  });
85 
// Partial specialization of ExprFunctor for the signature
// R(const PrimExpr& n, Args...).
//
// A call is routed through VisitExpr, which looks up a handler in a dispatch
// table (FType) and ends up in one of the VisitExpr_ overloads below.
// Subclasses customize behaviour by overriding those overloads; unless stated
// otherwise each overload falls back to VisitExprDefault_.
//
// Template parameters:
//   R    - result type returned by every visit function.
//   Args - extra arguments passed along unchanged to each visit function.
86 template <typename R, typename... Args>
87 class ExprFunctor<R(const PrimExpr& n, Args...)> {
88  private:
    // Shorthand for this specialization, and the dispatch-table type that
    // maps an ObjectRef to a handler taking (node, self, extra args...).
89  using TSelf = ExprFunctor<R(const PrimExpr& n, Args...)>;
90  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
91 
92  public:
    // The result type of this functor.
94  using result_type = R;
    // Virtual destructor.
96  virtual ~ExprFunctor() {}
    // Call operator; same as calling VisitExpr(n, args...).
103  R operator()(const PrimExpr& n, Args... args) {
104  return VisitExpr(n, std::forward<Args>(args)...);
105  }
    // The functor call: dispatches `n` through the table to a VisitExpr_
    // overload. The table is a function-local static built by InitVTable(),
    // so it is constructed once and reused by later calls.
112  virtual R VisitExpr(const PrimExpr& n, Args... args) {
113  static FType vtable = InitVTable();
114  return vtable(n, this, std::forward<Args>(args)...);
115  }
116  // Functions that can be overridden by subclass
117  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
    // SizeVarNode does not use the default body: it forwards to the VarNode
    // overload, so overriding the VarNode overload also covers SizeVarNode
    // unless this overload is overridden as well.
118  virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
119  return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
120  }
121  virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
122  virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
123  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
124  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
125  virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
126  virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
127  virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
128  virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
129  virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
130  virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
131  virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
132  virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
133  virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
134  virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
135  virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
136  virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
137  virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
138  virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
139  virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
140  virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
141  virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
142  virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
143  virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
144  virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
145  virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
146  virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
147  virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
148  virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
149  virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
150  virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
151  virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
152  virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
    // Fallback reached by every overload that uses EXPR_FUNCTOR_DEFAULT.
    // Logs a fatal message naming the node's type key; the extra arguments
    // are ignored.
    // NOTE(review): there is no return statement after LOG(FATAL), so this
    // presumably relies on LOG(FATAL) not returning - confirm.
153  virtual R VisitExprDefault_(const Object* op, Args...) {
154  LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
155  }
156 
157  private:
158  // initialize the vtable.
    // Builds and returns the dispatch table used by VisitExpr.
    // NOTE(review): this listing jumps from line 161 to line 196, so the
    // statements that populate the table are not visible here; presumably
    // they are IR_EXPR_FUNCTOR_DISPATCH(...) registrations - confirm against
    // the real header.
159  static FType InitVTable() {
160  FType vtable;
161  // Set dispatch
196  return vtable;
197  }
198 };
199 
200 #undef IR_EXPR_FUNCTOR_DISPATCH
201 #undef EXPR_FUNCTOR_DEFAULT
202 
// ExprVisitor: an ExprFunctor over PrimExpr with a void result and no extra
// arguments. It overrides the VisitExpr_ overload for every node type listed
// below. Only declarations appear here; the override bodies are defined
// outside this listing.
206 class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
207  public:
    // Re-export the base call operator under public access.
208  using ExprFunctor::operator();
209 
210  protected:
    // Re-declare VisitExpr with protected access in this class: subclasses
    // can call it, external callers go through operator().
211  using ExprFunctor::VisitExpr;
212  // list of functions to override.
213  void VisitExpr_(const VarNode* op) override;
214  void VisitExpr_(const SizeVarNode* op) override;
215  void VisitExpr_(const BufferLoadNode* op) override;
216  void VisitExpr_(const ProducerLoadNode* op) override;
217  void VisitExpr_(const LetNode* op) override;
218  void VisitExpr_(const CallNode* op) override;
219  void VisitExpr_(const AddNode* op) override;
220  void VisitExpr_(const SubNode* op) override;
221  void VisitExpr_(const MulNode* op) override;
222  void VisitExpr_(const DivNode* op) override;
223  void VisitExpr_(const ModNode* op) override;
224  void VisitExpr_(const FloorDivNode* op) override;
225  void VisitExpr_(const FloorModNode* op) override;
226  void VisitExpr_(const MinNode* op) override;
227  void VisitExpr_(const MaxNode* op) override;
228  void VisitExpr_(const EQNode* op) override;
229  void VisitExpr_(const NENode* op) override;
230  void VisitExpr_(const LTNode* op) override;
231  void VisitExpr_(const LENode* op) override;
232  void VisitExpr_(const GTNode* op) override;
233  void VisitExpr_(const GENode* op) override;
234  void VisitExpr_(const AndNode* op) override;
235  void VisitExpr_(const OrNode* op) override;
236  void VisitExpr_(const ReduceNode* op) override;
237  void VisitExpr_(const CastNode* op) override;
238  void VisitExpr_(const NotNode* op) override;
239  void VisitExpr_(const SelectNode* op) override;
240  void VisitExpr_(const RampNode* op) override;
241  void VisitExpr_(const BroadcastNode* op) override;
242  void VisitExpr_(const ShuffleNode* op) override;
243  void VisitExpr_(const IntImmNode* op) override;
244  void VisitExpr_(const FloatImmNode* op) override;
245  void VisitExpr_(const StringImmNode* op) override;
246  void VisitExpr_(const AnyNode* op) override;
247 };
248 
// ExprMutator that mutates expressions: an ExprFunctor over PrimExpr whose
// visit functions return a PrimExpr and take no extra arguments. Only
// declarations appear here; the override bodies are defined outside this
// listing.
//
// The base is inherited with `protected` access, so the only base member
// reachable by external callers is the call operator re-exported below.
252 class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
253  public:
    // Re-export the base call operator under public access.
254  using ExprFunctor::operator();
255 
256  protected:
    // Keep VisitExpr available to subclasses under protected access.
257  using ExprFunctor::VisitExpr;
258  // list of functions to override.
259  PrimExpr VisitExpr_(const VarNode* op) override;
260  PrimExpr VisitExpr_(const SizeVarNode* op) override;
261  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
262  PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
263  PrimExpr VisitExpr_(const LetNode* op) override;
264  PrimExpr VisitExpr_(const CallNode* op) override;
265  PrimExpr VisitExpr_(const AddNode* op) override;
266  PrimExpr VisitExpr_(const SubNode* op) override;
267  PrimExpr VisitExpr_(const MulNode* op) override;
268  PrimExpr VisitExpr_(const DivNode* op) override;
269  PrimExpr VisitExpr_(const ModNode* op) override;
270  PrimExpr VisitExpr_(const FloorDivNode* op) override;
271  PrimExpr VisitExpr_(const FloorModNode* op) override;
272  PrimExpr VisitExpr_(const MinNode* op) override;
273  PrimExpr VisitExpr_(const MaxNode* op) override;
274  PrimExpr VisitExpr_(const EQNode* op) override;
275  PrimExpr VisitExpr_(const NENode* op) override;
276  PrimExpr VisitExpr_(const LTNode* op) override;
277  PrimExpr VisitExpr_(const LENode* op) override;
278  PrimExpr VisitExpr_(const GTNode* op) override;
279  PrimExpr VisitExpr_(const GENode* op) override;
280  PrimExpr VisitExpr_(const AndNode* op) override;
281  PrimExpr VisitExpr_(const OrNode* op) override;
282  PrimExpr VisitExpr_(const ReduceNode* op) override;
283  PrimExpr VisitExpr_(const CastNode* op) override;
284  PrimExpr VisitExpr_(const NotNode* op) override;
285  PrimExpr VisitExpr_(const SelectNode* op) override;
286  PrimExpr VisitExpr_(const RampNode* op) override;
287  PrimExpr VisitExpr_(const BroadcastNode* op) override;
288  PrimExpr VisitExpr_(const ShuffleNode* op) override;
289  PrimExpr VisitExpr_(const IntImmNode* op) override;
290  PrimExpr VisitExpr_(const FloatImmNode* op) override;
291  PrimExpr VisitExpr_(const StringImmNode* op) override;
292  PrimExpr VisitExpr_(const AnyNode* op) override;
293 };
294 
295 } // namespace tir
296 } // namespace tvm
297 #endif // TVM_TIR_EXPR_FUNCTOR_H_
Constant floating point literals in the program.
Definition: expr.h:548
Constant integer literals in the program.
Definition: expr.h:501
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Reference to PrimExprNode.
Definition: expr.h:115
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
a + b
Definition: expr.h:157
a && b
Definition: expr.h:450
Any shape.
Definition: expr.h:1104
Create a vector where all the elements are value.
Definition: expr.h:792
Load value from the high dimension buffer.
Definition: expr.h:627
Call node.
Definition: expr.h:881
Cast value from one data type to another.
Definition: expr.h:89
a / b in the C semantics.
Definition: expr.h:211
a == b
Definition: expr.h:348
virtual R VisitExpr_(const CastNode *op, Args... args)
Definition: expr_functor.h:143
virtual R VisitExpr_(const BroadcastNode *op, Args... args)
Definition: expr_functor.h:147
virtual R VisitExpr_(const OrNode *op, Args... args)
Definition: expr_functor.h:141
virtual R VisitExpr_(const SelectNode *op, Args... args)
Definition: expr_functor.h:145
virtual R VisitExpr_(const VarNode *op, Args... args)
Definition: expr_functor.h:117
virtual R VisitExpr_(const GENode *op, Args... args)
Definition: expr_functor.h:139
virtual R VisitExpr_(const ProducerLoadNode *op, Args... args)
Definition: expr_functor.h:122
virtual R VisitExpr_(const MinNode *op, Args... args)
Definition: expr_functor.h:132
virtual R VisitExpr_(const AndNode *op, Args... args)
Definition: expr_functor.h:140
virtual R VisitExpr_(const LTNode *op, Args... args)
Definition: expr_functor.h:136
virtual R VisitExpr_(const LetNode *op, Args... args)
Definition: expr_functor.h:123
virtual R VisitExpr_(const RampNode *op, Args... args)
Definition: expr_functor.h:146
virtual ~ExprFunctor()
virtual destructor
Definition: expr_functor.h:96
virtual R VisitExpr_(const AnyNode *op, Args... args)
Definition: expr_functor.h:152
virtual R VisitExpr_(const EQNode *op, Args... args)
Definition: expr_functor.h:134
virtual R VisitExpr_(const MulNode *op, Args... args)
Definition: expr_functor.h:127
virtual R VisitExpr_(const ReduceNode *op, Args... args)
Definition: expr_functor.h:142
virtual R VisitExpr_(const DivNode *op, Args... args)
Definition: expr_functor.h:128
virtual R VisitExpr_(const BufferLoadNode *op, Args... args)
Definition: expr_functor.h:121
virtual R VisitExpr_(const IntImmNode *op, Args... args)
Definition: expr_functor.h:149
virtual R VisitExpr_(const CallNode *op, Args... args)
Definition: expr_functor.h:124
R operator()(const PrimExpr &n, Args... args)
Same as call.
Definition: expr_functor.h:103
virtual R VisitExpr_(const AddNode *op, Args... args)
Definition: expr_functor.h:125
virtual R VisitExpr_(const SubNode *op, Args... args)
Definition: expr_functor.h:126
virtual R VisitExpr_(const FloorDivNode *op, Args... args)
Definition: expr_functor.h:130
virtual R VisitExpr_(const GTNode *op, Args... args)
Definition: expr_functor.h:138
virtual R VisitExpr_(const LENode *op, Args... args)
Definition: expr_functor.h:137
R result_type
the result type of this functor
Definition: expr_functor.h:94
virtual R VisitExpr_(const ShuffleNode *op, Args... args)
Definition: expr_functor.h:148
virtual R VisitExpr_(const FloorModNode *op, Args... args)
Definition: expr_functor.h:131
virtual R VisitExprDefault_(const Object *op, Args...)
Definition: expr_functor.h:153
virtual R VisitExpr_(const NENode *op, Args... args)
Definition: expr_functor.h:135
virtual R VisitExpr_(const FloatImmNode *op, Args... args)
Definition: expr_functor.h:150
virtual R VisitExpr_(const ModNode *op, Args... args)
Definition: expr_functor.h:129
virtual R VisitExpr_(const NotNode *op, Args... args)
Definition: expr_functor.h:144
virtual R VisitExpr_(const SizeVarNode *op, Args... args)
Definition: expr_functor.h:118
virtual R VisitExpr_(const MaxNode *op, Args... args)
Definition: expr_functor.h:133
virtual R VisitExpr_(const StringImmNode *op, Args... args)
Definition: expr_functor.h:151
virtual R VisitExpr(const PrimExpr &n, Args... args)
The functor call.
Definition: expr_functor.h:112
A dynamic functor that dispatches on the first Expr argument. You can use this as a more powerfu...
Definition: expr_functor.h:75
ExprMutator that mutates expressions.
Definition: expr_functor.h:252
PrimExpr VisitExpr_(const GTNode *op) override
PrimExpr VisitExpr_(const FloorDivNode *op) override
PrimExpr VisitExpr_(const LENode *op) override
PrimExpr VisitExpr_(const CastNode *op) override
PrimExpr VisitExpr_(const BroadcastNode *op) override
PrimExpr VisitExpr_(const SelectNode *op) override
PrimExpr VisitExpr_(const MinNode *op) override
PrimExpr VisitExpr_(const IntImmNode *op) override
PrimExpr VisitExpr_(const RampNode *op) override
PrimExpr VisitExpr_(const ReduceNode *op) override
PrimExpr VisitExpr_(const OrNode *op) override
PrimExpr VisitExpr_(const FloorModNode *op) override
PrimExpr VisitExpr_(const DivNode *op) override
PrimExpr VisitExpr_(const AddNode *op) override
PrimExpr VisitExpr_(const ProducerLoadNode *op) override
PrimExpr VisitExpr_(const CallNode *op) override
PrimExpr VisitExpr_(const ShuffleNode *op) override
PrimExpr VisitExpr_(const SubNode *op) override
PrimExpr VisitExpr_(const FloatImmNode *op) override
PrimExpr VisitExpr_(const EQNode *op) override
PrimExpr VisitExpr_(const NENode *op) override
PrimExpr VisitExpr_(const GENode *op) override
PrimExpr VisitExpr_(const MulNode *op) override
PrimExpr VisitExpr_(const LTNode *op) override
PrimExpr VisitExpr_(const StringImmNode *op) override
PrimExpr VisitExpr_(const ModNode *op) override
PrimExpr VisitExpr_(const AndNode *op) override
PrimExpr VisitExpr_(const BufferLoadNode *op) override
PrimExpr VisitExpr_(const SizeVarNode *op) override
PrimExpr VisitExpr_(const LetNode *op) override
PrimExpr VisitExpr_(const MaxNode *op) override
PrimExpr VisitExpr_(const VarNode *op) override
PrimExpr VisitExpr_(const AnyNode *op) override
PrimExpr VisitExpr_(const NotNode *op) override
ExprVisitor.
Definition: expr_functor.h:206
void VisitExpr_(const FloorModNode *op) override
void VisitExpr_(const LENode *op) override
void VisitExpr_(const ProducerLoadNode *op) override
void VisitExpr_(const NENode *op) override
void VisitExpr_(const OrNode *op) override
void VisitExpr_(const LetNode *op) override
void VisitExpr_(const CastNode *op) override
void VisitExpr_(const MulNode *op) override
void VisitExpr_(const RampNode *op) override
void VisitExpr_(const IntImmNode *op) override
void VisitExpr_(const ModNode *op) override
void VisitExpr_(const DivNode *op) override
void VisitExpr_(const SelectNode *op) override
void VisitExpr_(const SubNode *op) override
void VisitExpr_(const StringImmNode *op) override
void VisitExpr_(const FloorDivNode *op) override
void VisitExpr_(const SizeVarNode *op) override
void VisitExpr_(const AnyNode *op) override
void VisitExpr_(const GTNode *op) override
void VisitExpr_(const MaxNode *op) override
void VisitExpr_(const NotNode *op) override
void VisitExpr_(const BroadcastNode *op) override
void VisitExpr_(const FloatImmNode *op) override
void VisitExpr_(const EQNode *op) override
void VisitExpr_(const AndNode *op) override
void VisitExpr_(const GENode *op) override
void VisitExpr_(const VarNode *op) override
void VisitExpr_(const AddNode *op) override
void VisitExpr_(const CallNode *op) override
void VisitExpr_(const ReduceNode *op) override
void VisitExpr_(const BufferLoadNode *op) override
void VisitExpr_(const ShuffleNode *op) override
void VisitExpr_(const LTNode *op) override
void VisitExpr_(const MinNode *op) override
Floor division, floor(a/b)
Definition: expr.h:248
The remainder of the floordiv.
Definition: expr.h:265
a >= b
Definition: expr.h:433
a > b
Definition: expr.h:416
a < b
Definition: expr.h:382
Let binding. Bind var to value then evaluate body.
Definition: expr.h:834
max(a, b)
Definition: expr.h:299
min(a, b)
Definition: expr.h:282
a % b in the C semantics.
Definition: expr.h:231
a * b
Definition: expr.h:191
a != b
Definition: expr.h:365
!a
Definition: expr.h:530
a || b
Definition: expr.h:490
Load value from the result produced by the producer.
Definition: expr.h:697
Construct a vector with lanes elements where its i-th element equals base + i * stride....
Definition: expr.h:747
Reduction operator.
Definition: expr.h:1039
return true_value if condition is true, otherwise return false_value.
Definition: expr.h:572
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ....
Definition: expr.h:930
A variable node that represents a tensor index size, whose value must be non-negative.
Definition: var.h:144
String constants, only used in asserts.
Definition: expr.h:53
a - b
Definition: expr.h:174
A variable node in the IR.
Definition: var.h:48
Defines the Functor data structures.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
a <= b
Definition: expr.h:399
TIR expressions.
#define IR_EXPR_FUNCTOR_DISPATCH(OP)
Definition: expr_functor.h:81
#define EXPR_FUNCTOR_DEFAULT
Definition: expr_functor.h:78