tvm
adt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_ADT_H_
25 #define TVM_RELAY_ADT_H_
26 
27 #include <tvm/ir/adt.h>
28 #include <tvm/ir/attrs.h>
29 #include <tvm/relay/base.h>
30 #include <tvm/relay/expr.h>
31 #include <tvm/relay/type.h>
32 
33 #include <functional>
34 #include <string>
35 #include <utility>
36 
37 namespace tvm {
38 namespace relay {
39 
42 
45 
/*! \brief Base type for declaring relay pattern. */
47 class PatternNode : public RelayNode {
48  public:
 /* Type key string for this node type. */
49  static constexpr const char* _type_key = "relay.Pattern";
 /* Both flags are set to true; the pattern node classes in this listing each define
  * SEqualReduce and SHashReduce methods. */
50  static constexpr const bool _type_has_method_sequal_reduce = true;
51  static constexpr const bool _type_has_method_shash_reduce = true;
 /* NOTE(review): listing line 52 is not rendered here; the page index lists
  * TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object) for this class. */
53 };
54 
/*! \brief Pattern is the base type for an ADT match pattern in Relay. */
63 class Pattern : public ObjectRef {
64  public:
 /* Default constructor with an empty body. */
65  Pattern() {}
 /* NOTE(review): listing lines 66 and 68 are not rendered here; the page index lists
  * Pattern(ObjectPtr<tvm::Object> p) at line 66. */
67 
69 };
70 
/* Forward declaration of the PatternWildcard reference class defined further down. */
72 class PatternWildcard;
/*!
 * \brief PatternWildcard container node.
 *
 * NOTE(review): the class header (listing line 74) and listing line 83 are not rendered
 * here; the page index lists TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode)
 * for this class.
 */
75  public:
 /* Exposes the only visited attribute, "span", to the attribute visitor. */
76  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
77 
 /* Always returns true without inspecting `other` or `equal`. */
78  bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; }
79 
 /* Empty body: nothing is fed into the hash reducer. */
80  void SHashReduce(SHashReducer hash_reduce) const {}
81 
 /* Type key string for this node type. */
82  static constexpr const char* _type_key = "relay.PatternWildcard";
84 };
85 
/*!
 * \brief Reference class for PatternWildcardNode, derived from Pattern.
 *
 * NOTE(review): the signature lines of several members (listing lines 90, 92, 94, 96, 101,
 * 106 and 110) are not rendered here. The page index lists, in that order:
 * PatternWildcard(ObjectPtr<Object> n), the copy constructor, the move constructor,
 * copy assignment, move assignment, and operator->() const.
 */
86 class PatternWildcard : public Pattern {
87  public:
88  /* \brief Overload the default constructors. */
89  TVM_DLL PatternWildcard();
91  /* \brief Copy constructor. */
93  /* \brief Move constructor. */
95  /* \brief Copy assignment. */
 /* Body of copy assignment: copies the internal data_ pointer from `other`. */
97  (*this).data_ = other.data_;
98  return *this;
99  }
100  /* \brief Move assignment. */
 /* Body of move assignment: moves the internal data_ pointer out of `other`. */
102  (*this).data_ = std::move(other.data_);
103  return *this;
104  }
105 
 /* Body of operator->() const (per the page index): casts the stored object pointer
  * returned by get() to const PatternWildcardNode*. */
107  return static_cast<const PatternWildcardNode*>(get());
108  }
109 
111 };
112 
/* Forward declaration of the PatternVar reference class defined further down. */
114 class PatternVar;
/*! \brief PatternVar container node. */
116 class PatternVarNode : public PatternNode {
117  public:
 /* NOTE(review): listing lines 119 and 121 are not rendered here; the page index lists the
  * field `tvm::relay::Var var` ("Variable that stores the matched value.") at line 119 and
  * `void VisitAttrs(tvm::AttrVisitor* v)` at line 121. The statements below are its body. */
120 
 /* Exposes the "var" and "span" attributes to the attribute visitor. */
122  v->Visit("var", &var);
123  v->Visit("span", &span);
124  }
125 
 /* Structural equality: compares the two `var` fields through equal.DefEqual. */
126  bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
127  return equal.DefEqual(var, other->var);
128  }
129 
 /* Structural hash: feeds `var` to the reducer through DefHash. */
130  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); }
131 
 /* Type key string for this node type. */
132  static constexpr const char* _type_key = "relay.PatternVar";
 /* NOTE(review): listing line 133 is not rendered here; the page index lists
  * TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode). */
134 };
135 
/*! \brief Reference class for PatternVarNode, derived from Pattern. */
136 class PatternVar : public Pattern {
137  public:
 /*!
  * \brief Constructor.
  * \param var The relay variable; PatternVarNode stores a field of the same type and name.
  */
142  TVM_DLL explicit PatternVar(tvm::relay::Var var);
143 
 /* NOTE(review): listing line 144 is not rendered here; the page index lists
  * TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode). */
145 };
146 
/* Forward declaration of the PatternConstructor reference class defined further down. */
148 class PatternConstructor;
/*!
 * \brief PatternConstructor container node.
 *
 * NOTE(review): the class header (listing line 150) and the member declarations are not
 * rendered here. The page index lists `Constructor constructor` at line 153,
 * `tvm::Array<Pattern> patterns` at line 155, `void VisitAttrs(tvm::AttrVisitor* v)` at
 * line 157, `bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const`
 * at line 163, and TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode).
 */
151  public:
156 
 /* Body of VisitAttrs: exposes "constructor", "patterns" and "span" to the visitor. */
158  v->Visit("constructor", &constructor);
159  v->Visit("patterns", &patterns);
160  v->Visit("span", &span);
161  }
162 
 /* Body of SEqualReduce: equal only when both the constructor and the patterns compare equal. */
164  return equal(constructor, other->constructor) && equal(patterns, other->patterns);
165  }
166 
 /* Structural hash: feeds the constructor, then the patterns, to the reducer. */
167  void SHashReduce(SHashReducer hash_reduce) const {
168  hash_reduce(constructor);
169  hash_reduce(patterns);
170  }
171 
 /* Type key string for this node type. */
172  static constexpr const char* _type_key = "relay.PatternConstructor";
174 };
175 
/*! \brief Reference class for PatternConstructorNode, derived from Pattern. */
176 class PatternConstructor : public Pattern {
177  public:
 /*!
  * \brief Constructor.
  * \param constructor The ADT constructor; the node class has a field of the same name.
  * \param patterns The array of sub-patterns; the node class has a field of the same name.
  */
183  TVM_DLL PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns);
184 
 /* NOTE(review): listing line 185 is not rendered here; the page index lists
  * TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode). */
186 };
187 
/* Forward declaration of the PatternTuple reference class defined further down. */
189 class PatternTuple;
/*!
 * \brief PatternTuple container node.
 *
 * NOTE(review): the class header (listing line 191) and some member lines are not rendered
 * here. The page index lists `tvm::Array<Pattern> patterns` at line 195,
 * `void VisitAttrs(tvm::AttrVisitor* v)` at line 197, and
 * TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode).
 */
192  public:
193  /* TODO(@jroesch): rename to field_pats */
196 
 /* Body of VisitAttrs: exposes "patterns" and "span" to the visitor. */
198  v->Visit("patterns", &patterns);
199  v->Visit("span", &span);
200  }
201 
 /* Structural equality: compares only the `patterns` arrays. */
202  bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
203  return equal(patterns, other->patterns);
204  }
205 
 /* Structural hash: feeds only `patterns` to the reducer. */
206  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); }
207 
 /* Type key string for this node type. */
208  static constexpr const char* _type_key = "relay.PatternTuple";
210 };
211 
/*! \brief Reference class for PatternTupleNode, derived from Pattern. */
212 class PatternTuple : public Pattern {
213  public:
 /*!
  * \brief Constructor.
  * \param patterns The array of sub-patterns; the node class has a field of the same name.
  */
218  TVM_DLL explicit PatternTuple(tvm::Array<Pattern> patterns);
219 
 /* NOTE(review): listing line 220 is not rendered here; the page index lists
  * TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode). */
221 };
222 
/* Forward declaration of the Clause reference class defined further down. */
224 class Clause;
/*! \brief Clause container node. */
226 class ClauseNode : public Object {
227  public:
 /* NOTE(review): the field and VisitAttrs signature lines are not rendered here. The page
  * index lists `Pattern lhs` ("The pattern the clause matches.") at line 229, `Expr rhs`
  * ("The resulting value.") at line 231, and `void VisitAttrs(tvm::AttrVisitor* v)` at
  * line 233. The statements below are the VisitAttrs body. */
232 
 /* Exposes the "lhs" and "rhs" attributes to the attribute visitor. */
234  v->Visit("lhs", &lhs);
235  v->Visit("rhs", &rhs);
236  }
237 
 /* Structural equality: equal only when both lhs and rhs compare equal. */
238  bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
239  return equal(lhs, other->lhs) && equal(rhs, other->rhs);
240  }
241 
 /* Structural hash: feeds lhs, then rhs, to the reducer. */
242  void SHashReduce(SHashReducer hash_reduce) const {
243  hash_reduce(lhs);
244  hash_reduce(rhs);
245  }
246 
 /* Type key string for this node type. */
247  static constexpr const char* _type_key = "relay.Clause";
 /* Both flags are set to true; this class defines SEqualReduce and SHashReduce above. */
248  static constexpr const bool _type_has_method_sequal_reduce = true;
249  static constexpr const bool _type_has_method_shash_reduce = true;
 /* NOTE(review): listing line 250 is not rendered here; the page index lists
  * TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object). */
251 };
252 
/*! \brief Reference class for ClauseNode, derived from ObjectRef. */
253 class Clause : public ObjectRef {
254  public:
 /*!
  * \brief Constructor.
  * \param lhs The pattern; ClauseNode has a field of the same name.
  * \param rhs The expression; ClauseNode has a field of the same name.
  */
260  TVM_DLL explicit Clause(Pattern lhs, Expr rhs);
261 
 /* NOTE(review): listing lines 262-263 are not rendered here; the page index lists
  * TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode) and
  * TVM_DEFINE_OBJECT_REF_COW_METHOD(ClauseNode). */
264 };
265 
272  Optional<Expr> opt_rhs = Optional<Expr>());
273 
/* Forward declaration of the Match reference class defined further down. */
275 class Match;
/*! \brief Match container node. */
277 class MatchNode : public ExprNode {
278  public:
 /* NOTE(review): two field declarations are not rendered here. The page index lists
  * `Expr data` ("The input being deconstructed.") at line 280 and
  * `tvm::Array<Clause> clauses` ("The match node clauses.") at line 283. */
281 
284 
 /* Should this match be complete (cover all cases)? The index text for this field is
  * truncated in this rendering after "the type checker will generate an error if". */
288  bool complete;
289 
 /* NOTE(review): the VisitAttrs signature (listing line 290) is not rendered here; the
  * statements below are its body. It exposes six attributes to the visitor. */
291  v->Visit("data", &data);
292  v->Visit("clauses", &clauses);
293  v->Visit("complete", &complete);
294  v->Visit("virtual_device_", &virtual_device_);
295  v->Visit("span", &span);
296  v->Visit("_checked_type_", &checked_type_);
297  }
298 
 /* Structural equality: calls MarkGraphNode on the reducer, then compares data, clauses
  * and complete. virtual_device_, span and checked_type_ are not compared here. */
299  bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
300  equal->MarkGraphNode();
301  return equal(data, other->data) && equal(clauses, other->clauses) &&
302  equal(complete, other->complete);
303  }
304 
 /* Structural hash: calls MarkGraphNode on the reducer, then feeds it the same three
  * fields that SEqualReduce compares. */
305  void SHashReduce(SHashReducer hash_reduce) const {
306  hash_reduce->MarkGraphNode();
307  hash_reduce(data);
308  hash_reduce(clauses);
309  hash_reduce(complete);
310  }
311 
 /* Type key string for this node type. */
312  static constexpr const char* _type_key = "relay.Match";
 /* NOTE(review): listing line 313 is not rendered here; the page index lists
  * TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode). */
314 };
315 
/*! \brief Reference class for MatchNode, derived from Expr. */
316 class Match : public Expr {
317  public:
 /*!
  * \brief Constructor.
  * \param data The expression; MatchNode has a field of the same name.
  * \param clauses The array of clauses; MatchNode has a field of the same name.
  * \param complete Boolean flag, defaults to true; MatchNode has a field of the same name.
  * \param span The span, defaults to a default-constructed Span.
  */
325  TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true, Span span = Span());
326 
 /* NOTE(review): listing lines 327-328 are not rendered here; the page index lists
  * TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode) and
  * TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchNode). */
329 };
330 
337  Optional<Array<Clause>> opt_clauses = Optional<Array<Clause>>(),
338  Optional<Bool> opt_complete = Optional<Bool>(),
339  Optional<Span> opt_span = Optional<Span>());
340 
341 } // namespace relay
342 } // namespace tvm
343 
344 #endif // TVM_RELAY_ADT_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:56
ADT constructor. Constructors compare by pointer equality.
Definition: adt.h:47
Managed reference to ConstructorNode.
Definition: adt.h:88
Base node of all non-primitive expressions.
Definition: expr.h:362
ObjectRef virtual_device_
The virtual device (VirtualDevice) for this node (the result of device planning). For first-order exp...
Definition: expr.h:418
Type checked_type_
Stores the result of type inference(type checking).
Definition: expr.h:370
Managed reference to RelayExprNode.
Definition: expr.h:442
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
virtual void MarkGraphNode()=0
Mark the current comparison as a graph node in hashing. The graph node hash will depend on the graph structure...
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
TypeData container node.
Definition: adt.h:102
Stores all data for an Algebraic Data Type (ADT).
Definition: adt.h:149
Clause container node.
Definition: adt.h:226
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:242
static constexpr const bool _type_has_method_shash_reduce
Definition: adt.h:249
bool SEqualReduce(const ClauseNode *other, SEqualReducer equal) const
Definition: adt.h:238
static constexpr const bool _type_has_method_sequal_reduce
Definition: adt.h:248
TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object)
Pattern lhs
The pattern the clause matches.
Definition: adt.h:229
Expr rhs
The resulting value.
Definition: adt.h:231
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:233
static constexpr const char * _type_key
Definition: adt.h:247
Definition: adt.h:253
Clause(Pattern lhs, Expr rhs)
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ClauseNode)
Match container node.
Definition: adt.h:277
bool complete
Should this match be complete (cover all cases)? If yes, the type checker will generate an error if t...
Definition: adt.h:288
tvm::Array< Clause > clauses
The match node clauses.
Definition: adt.h:283
Expr data
The input being deconstructed.
Definition: adt.h:280
TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:290
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:305
static constexpr const char * _type_key
Definition: adt.h:312
bool SEqualReduce(const MatchNode *other, SEqualReducer equal) const
Definition: adt.h:299
Definition: adt.h:316
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchNode)
TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode)
Match(Expr data, tvm::Array< Clause > clauses, bool complete=true, Span span=Span())
Constructor.
PatternConstructor container node.
Definition: adt.h:150
tvm::Array< Pattern > patterns
Definition: adt.h:155
TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode)
bool SEqualReduce(const PatternConstructorNode *other, SEqualReducer equal) const
Definition: adt.h:163
Constructor constructor
Definition: adt.h:153
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:167
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:157
static constexpr const char * _type_key
Definition: adt.h:172
Definition: adt.h:176
PatternConstructor(Constructor constructor, tvm::Array< Pattern > patterns)
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode)
Base type for declaring relay pattern.
Definition: adt.h:47
static constexpr const bool _type_has_method_shash_reduce
Definition: adt.h:51
static constexpr const bool _type_has_method_sequal_reduce
Definition: adt.h:50
static constexpr const char * _type_key
Definition: adt.h:49
TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object)
PatternTuple container node.
Definition: adt.h:191
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:197
TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode)
static constexpr const char * _type_key
Definition: adt.h:208
tvm::Array< Pattern > patterns
Definition: adt.h:195
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:206
bool SEqualReduce(const PatternTupleNode *other, SEqualReducer equal) const
Definition: adt.h:202
Definition: adt.h:212
TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode)
PatternTuple(tvm::Array< Pattern > patterns)
Constructor.
PatternVar container node.
Definition: adt.h:116
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:130
bool SEqualReduce(const PatternVarNode *other, SEqualReducer equal) const
Definition: adt.h:126
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:121
TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode)
tvm::relay::Var var
Variable that stores the matched value.
Definition: adt.h:119
static constexpr const char * _type_key
Definition: adt.h:132
Definition: adt.h:136
TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode)
PatternVar(tvm::relay::Var var)
Constructor.
PatternWildcard container node.
Definition: adt.h:74
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:80
TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:76
bool SEqualReduce(const PatternNode *other, SEqualReducer equal) const
Definition: adt.h:78
static constexpr const char * _type_key
Definition: adt.h:82
Definition: adt.h:86
const PatternWildcardNode * operator->() const
Definition: adt.h:106
PatternWildcard(ObjectPtr< Object > n)
Definition: adt.h:90
PatternWildcard & operator=(PatternWildcard &&other)
Definition: adt.h:101
PatternWildcard(PatternWildcard &&pat)
Definition: adt.h:94
PatternWildcard(const PatternWildcard &pat)
Definition: adt.h:92
PatternWildcard & operator=(const PatternWildcard &other)
Definition: adt.h:96
Pattern is the base type for an ADT match pattern in Relay.
Definition: adt.h:63
Pattern()
Definition: adt.h:65
Pattern(ObjectPtr< tvm::Object > p)
Definition: adt.h:66
This is the base node container of all relay structures.
Definition: base.h:71
Span span
The location of the program in a SourceFragment. Can be null; check with span.defined().
Definition: base.h:75
Definition: expr.h:234
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
const Object * get() const
Definition: object.h:554
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Algebraic data type definitions.
Helpers for attribute objects.
tvm::TypeDataNode TypeDataNode
Definition: adt.h:44
tvm::ConstructorNode ConstructorNode
Definition: adt.h:41
tvm::TypeData TypeData
Definition: adt.h:43
Clause WithFields(Clause clause, Optional< Pattern > opt_lhs=Optional< Pattern >(), Optional< Expr > opt_rhs=Optional< Expr >())
Returns clause with the given properties. A null property denotes 'no change'. Returns clause if all ...
tvm::Constructor Constructor
Definition: adt.h:40
tvm::Span Span
Definition: base.h:65
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Base classes for the Relay IR.
Relay expression language.
Relay typed AST nodes.