tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
adt.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_ADT_H_
25 #define TVM_RELAY_ADT_H_
26 
27 #include <tvm/ir/adt.h>
28 #include <tvm/ir/attrs.h>
29 #include <tvm/relay/base.h>
30 #include <tvm/relay/expr.h>
31 #include <tvm/relay/type.h>
32 
33 #include <functional>
34 #include <string>
35 #include <utility>
36 
37 namespace tvm {
38 namespace relay {
39 
42 
45 
47 class PatternNode : public RelayNode {
48  public:
49  static constexpr const char* _type_key = "relay.Pattern";
50  static constexpr const bool _type_has_method_sequal_reduce = true;
51  static constexpr const bool _type_has_method_shash_reduce = true;
53 };
54 
63 class Pattern : public ObjectRef {
64  public:
65  Pattern() {}
67 
69 };
70 
72 class PatternWildcard;
75  public:
76  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
77 
78  bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; }
79 
80  void SHashReduce(SHashReducer hash_reduce) const {}
81 
82  static constexpr const char* _type_key = "relay.PatternWildcard";
84 };
85 
86 class PatternWildcard : public Pattern {
87  public:
88  /* \brief Overload the default constructors. */
89  TVM_DLL PatternWildcard();
91  /* \brief Copy constructor. */
92  PatternWildcard(const PatternWildcard& pat) : PatternWildcard(pat.data_) {}
93  /* \brief Move constructor. */
94  PatternWildcard(PatternWildcard&& pat) : PatternWildcard(std::move(pat.data_)) {}
95  /* \brief Copy assignment. */
96  PatternWildcard& operator=(const PatternWildcard& other) {
97  (*this).data_ = other.data_;
98  return *this;
99  }
100  /* \brief Move assignment. */
101  PatternWildcard& operator=(PatternWildcard&& other) {
102  (*this).data_ = std::move(other.data_);
103  return *this;
104  }
105 
107  return static_cast<const PatternWildcardNode*>(get());
108  }
109 
111 };
112 
114 class PatternVar;
116 class PatternVarNode : public PatternNode {
117  public:
120 
122  v->Visit("var", &var);
123  v->Visit("span", &span);
124  }
125 
126  bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
127  return equal.DefEqual(var, other->var);
128  }
129 
130  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); }
131 
132  static constexpr const char* _type_key = "relay.PatternVar";
134 };
135 
136 class PatternVar : public Pattern {
137  public:
142  TVM_DLL explicit PatternVar(tvm::relay::Var var);
143 
145 };
146 
148 class PatternConstructor;
151  public:
156 
158  v->Visit("constructor", &constructor);
159  v->Visit("patterns", &patterns);
160  v->Visit("span", &span);
161  }
162 
164  return equal(constructor, other->constructor) && equal(patterns, other->patterns);
165  }
166 
167  void SHashReduce(SHashReducer hash_reduce) const {
168  hash_reduce(constructor);
169  hash_reduce(patterns);
170  }
171 
172  static constexpr const char* _type_key = "relay.PatternConstructor";
174 };
175 
176 class PatternConstructor : public Pattern {
177  public:
183  TVM_DLL PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns);
184 
186 };
187 
189 class PatternTuple;
192  public:
193  /* TODO(@jroesch): rename to field_pats */
196 
198  v->Visit("patterns", &patterns);
199  v->Visit("span", &span);
200  }
201 
202  bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
203  return equal(patterns, other->patterns);
204  }
205 
206  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); }
207 
208  static constexpr const char* _type_key = "relay.PatternTuple";
210 };
211 
212 class PatternTuple : public Pattern {
213  public:
218  TVM_DLL explicit PatternTuple(tvm::Array<Pattern> patterns);
219 
221 };
222 
224 class Clause;
226 class ClauseNode : public Object {
227  public:
232 
234  v->Visit("lhs", &lhs);
235  v->Visit("rhs", &rhs);
236  }
237 
238  bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
239  return equal(lhs, other->lhs) && equal(rhs, other->rhs);
240  }
241 
242  void SHashReduce(SHashReducer hash_reduce) const {
243  hash_reduce(lhs);
244  hash_reduce(rhs);
245  }
246 
247  static constexpr const char* _type_key = "relay.Clause";
248  static constexpr const bool _type_has_method_sequal_reduce = true;
249  static constexpr const bool _type_has_method_shash_reduce = true;
251 };
252 
253 class Clause : public ObjectRef {
254  public:
260  TVM_DLL explicit Clause(Pattern lhs, Expr rhs);
261 
264 };
265 
272  Optional<Expr> opt_rhs = Optional<Expr>());
273 
275 class Match;
277 class MatchNode : public ExprNode {
278  public:
281 
284 
288  bool complete;
289 
291  v->Visit("data", &data);
292  v->Visit("clauses", &clauses);
293  v->Visit("complete", &complete);
294  v->Visit("virtual_device_", &virtual_device_);
295  v->Visit("span", &span);
296  v->Visit("_checked_type_", &checked_type_);
297  }
298 
299  bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
300  equal->MarkGraphNode();
301  return equal(data, other->data) && equal(clauses, other->clauses) &&
302  equal(complete, other->complete);
303  }
304 
305  void SHashReduce(SHashReducer hash_reduce) const {
306  hash_reduce->MarkGraphNode();
307  hash_reduce(data);
308  hash_reduce(clauses);
309  hash_reduce(complete);
310  }
311 
312  static constexpr const char* _type_key = "relay.Match";
314 };
315 
316 class Match : public Expr {
317  public:
325  TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true, Span span = Span());
326 
329 };
330 
337  Optional<Array<Clause>> opt_clauses = Optional<Array<Clause>>(),
338  Optional<Bool> opt_complete = Optional<Bool>(),
339  Optional<Span> opt_span = Optional<Span>());
340 
341 } // namespace relay
342 } // namespace tvm
343 
344 #endif // TVM_RELAY_ADT_H_
bool SEqualReduce(const MatchNode *other, SEqualReducer equal) const
Definition: adt.h:299
tvm::Span Span
Definition: base.h:65
Definition: adt.h:176
Definition: adt.h:136
bool SEqualReduce(const PatternNode *other, SEqualReducer equal) const
Definition: adt.h:78
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
Match container node.
Definition: adt.h:277
A custom smart pointer for Object.
Definition: object.h:358
Constructor constructor
Definition: adt.h:153
bool SEqualReduce(const PatternVarNode *other, SEqualReducer equal) const
Definition: adt.h:126
PatternWildcard(const PatternWildcard &pat)
Definition: adt.h:92
Pattern(ObjectPtr< tvm::Object > p)
Definition: adt.h:66
PatternConstructor container node.
Definition: adt.h:150
bool SEqualReduce(const PatternTupleNode *other, SEqualReducer equal) const
Definition: adt.h:202
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
bool SEqualReduce(const ClauseNode *other, SEqualReducer equal) const
Definition: adt.h:238
Expr data
The input being deconstructed.
Definition: adt.h:280
ADT constructor. Constructors compare by pointer equality.
Definition: adt.h:47
Relay expression language.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
const PatternWildcardNode * operator->() const
Definition: adt.h:106
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
tvm::relay::Var var
Variable that stores the matched value.
Definition: adt.h:119
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
tvm::TypeDataNode TypeDataNode
Definition: adt.h:44
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:305
Definition: loop_state.h:456
Span span
The location of the program in a SourceFragment; it can be null, so check with span.defined() ...
Definition: base.h:75
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:76
static constexpr const bool _type_has_method_sequal_reduce
Definition: adt.h:50
Clause WithFields(Clause clause, Optional< Pattern > opt_lhs=Optional< Pattern >(), Optional< Expr > opt_rhs=Optional< Expr >())
Returns clause with the given properties. A null property denotes 'no change'. Returns clause if all ...
base class of all object containers.
Definition: object.h:167
PatternWildcard & operator=(PatternWildcard &&other)
Definition: adt.h:101
Pattern()
Definition: adt.h:65
Managed reference to ConstructorNode.
Definition: adt.h:88
Helpers for attribute objects.
virtual void MarkGraphNode()=0
Mark current comparison as graph node equal comparison.
Definition: adt.h:212
PatternVar container node.
Definition: adt.h:116
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:167
Expr rhs
The resulting value.
Definition: adt.h:231
This is the base node container of all relay structures.
Definition: base.h:71
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
bool SEqualReduce(const PatternConstructorNode *other, SEqualReducer equal) const
Definition: adt.h:163
Pattern lhs
The pattern the clause matches.
Definition: adt.h:229
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:157
Definition: source_map.h:120
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:206
Definition: adt.h:86
Definition: adt.h:316
virtual void MarkGraphNode()=0
Mark current comparison as graph node in hashing. Graph node hash will depend on the graph structure...
tvm::Array< Pattern > patterns
Definition: adt.h:195
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Relay typed AST nodes.
static constexpr const char * _type_key
Definition: adt.h:49
TypeData container node.
Definition: adt.h:102
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:574
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:242
Managed reference to RelayExprNode.
Definition: expr.h:433
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:121
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Pattern is the base type for an ADT match pattern in Relay.
Definition: adt.h:63
Base type for declaring relay pattern.
Definition: adt.h:47
tvm::Array< Clause > clauses
The match node clauses.
Definition: adt.h:283
Algebraic data type definitions.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Base class of all object reference.
Definition: object.h:511
TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object)
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
tvm::Constructor Constructor
Definition: adt.h:40
tvm::ConstructorNode ConstructorNode
Definition: adt.h:41
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
bool complete
Should this match be complete (cover all cases)? If yes, the type checker will generate an error if t...
Definition: adt.h:288
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:130
PatternWildcard(ObjectPtr< Object > n)
Definition: adt.h:90
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:197
static constexpr const bool _type_has_method_shash_reduce
Definition: adt.h:51
Stores all data for an Algebraic Data Type (ADT).
Definition: adt.h:149
Clause container node.
Definition: adt.h:226
Base classes for the Relay IR.
Optional container that represents a nullable variant of T.
Definition: optional.h:51
PatternWildcard(PatternWildcard &&pat)
Definition: adt.h:94
Definition: expr.h:234
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:290
tvm::Array< Pattern > patterns
Definition: adt.h:155
void VisitAttrs(tvm::AttrVisitor *v)
Definition: adt.h:233
tvm::TypeData TypeData
Definition: adt.h:43
Definition: adt.h:253
Base node of all non-primitive expressions.
Definition: expr.h:361
PatternWildcard & operator=(const PatternWildcard &other)
Definition: adt.h:96
void SHashReduce(SHashReducer hash_reduce) const
Definition: adt.h:80
PatternTuple container node.
Definition: adt.h:191
PatternWildcard container node.
Definition: adt.h:74
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:187