tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_EXPR_H_
25 #define TVM_RELAY_EXPR_H_
26 
27 #include <tvm/ir/attrs.h>
28 #include <tvm/ir/expr.h>
29 #include <tvm/ir/module.h>
30 #include <tvm/ir/op.h>
31 
32 #include <functional>
33 #include <stack>
34 #include <string>
35 #include <utility>
36 
37 #include "./base.h"
38 #include "./type.h"
39 
40 namespace tvm {
41 namespace relay {
42 
49 using tvm::PrettyPrint;
50 
57 class Constant;
61 class ConstantNode : public ExprNode {
62  public:
65 
67  TensorType tensor_type() const;
68 
70  bool is_scalar() const { return data->ndim == 0; }
71 
73  v->Visit("data", &data);
74  v->Visit("span", &span);
75  v->Visit("_checked_type_", &checked_type_);
76  }
77 
78  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
79  return equal(data, other->data);
80  }
81 
82  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }
83 
84  static constexpr const char* _type_key = "relay.Constant";
86 };
87 
88 class Constant : public Expr {
89  public:
95  TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span());
96 
98 };
99 
101 class Tuple;
103 class TupleNode : public ExprNode {
104  public:
107 
109  v->Visit("fields", &fields);
110  v->Visit("span", &span);
111  v->Visit("_checked_type_", &checked_type_);
112  }
113 
114  bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
115  // specially handle empty tuple as a constant is not a graph node.
116  if (fields.size() == other->fields.size() && fields.size() == 0) {
117  return true;
118  } else {
119  equal->MarkGraphNode();
120  return equal(fields, other->fields);
121  }
122  }
123 
124  void SHashReduce(SHashReducer hash_reduce) const {
125  if (fields.size() != 0) {
126  hash_reduce->MarkGraphNode();
127  hash_reduce(fields);
128  }
129  }
130 
131  static constexpr const char* _type_key = "relay.Tuple";
133 };
134 
135 class Tuple : public Expr {
136  public:
142  TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());
143 
145 };
146 
155 class Var;
157 class VarNode : public ExprNode {
158  public:
175 
177  const String& name_hint() const { return vid->name_hint; }
178 
180  v->Visit("vid", &vid);
181  v->Visit("type_annotation", &type_annotation);
182  v->Visit("span", &span);
183  v->Visit("_checked_type_", &checked_type_);
184  }
185 
186  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
187  equal->MarkGraphNode();
188  return equal(type_annotation, other->type_annotation) && equal(vid, other->vid);
189  }
190 
191  void SHashReduce(SHashReducer hash_reduce) const {
192  hash_reduce->MarkGraphNode();
193  hash_reduce(type_annotation);
194  hash_reduce(vid);
195  }
196 
197  static constexpr const char* _type_key = "relay.Var";
199 };
200 
201 class Var : public Expr {
202  public:
209  TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span())
210  : Var(Id(name_hint), type_annotation, span) {}
211 
218  TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());
219 
221 };
222 
227 class Call;
229 class CallNode : public ExprNode {
230  protected:
231  // CallNode uses own deleter to indirectly call non-recursive destructor
233  static void Deleter_(Object* ptr);
234 
235  public:
243 
246 
249 
269 
271  v->Visit("op", &op);
272  v->Visit("args", &args);
273  v->Visit("attrs", &attrs);
274  v->Visit("type_args", &type_args);
275  v->Visit("span", &span);
276  v->Visit("_checked_type_", &checked_type_);
277  }
278 
279  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
280  // skip type_args check for primitive ops.
281  equal->MarkGraphNode();
282  return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) &&
283  (IsPrimitiveOp(op) || equal(type_args, other->type_args));
284  }
285 
286  void SHashReduce(SHashReducer hash_reduce) const {
287  hash_reduce->MarkGraphNode();
288  hash_reduce(op);
289  hash_reduce(args);
290  hash_reduce(attrs);
291  if (!IsPrimitiveOp(op)) {
292  hash_reduce(type_args);
293  }
294  }
295 
296  static constexpr const char* _type_key = "relay.Call";
298  friend class Call;
299 };
300 
301 class Call : public Expr {
302  public:
306  ~Call();
307 
316  TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
317  Array<Type> type_args = Array<Type>(), Span span = Span());
318 
320 };
321 
333 class Let;
335 class LetNode : public ExprNode {
336  public:
343 
345  v->Visit("var", &var);
346  v->Visit("value", &value);
347  v->Visit("body", &body);
348  v->Visit("span", &span);
349  v->Visit("_checked_type_", &checked_type_);
350  }
351 
352  bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
353  equal->MarkGraphNode();
354  return equal.DefEqual(var, other->var) && equal(value, other->value) &&
355  equal(body, other->body);
356  }
357 
358  void SHashReduce(SHashReducer hash_reduce) const {
359  hash_reduce->MarkGraphNode();
360  hash_reduce.DefHash(var);
361  hash_reduce(value);
362  hash_reduce(body);
363  }
364 
365  static constexpr const char* _type_key = "relay.Let";
367 };
368 
369 class Let : public Expr {
370  public:
378  TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span());
379 
381 };
382 
394 class If;
396 class IfNode : public ExprNode {
397  public:
404 
406  v->Visit("cond", &cond);
407  v->Visit("true_branch", &true_branch);
408  v->Visit("false_branch", &false_branch);
409  v->Visit("span", &span);
410  v->Visit("_checked_type_", &checked_type_);
411  }
412 
413  bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
414  equal->MarkGraphNode();
415  return equal(cond, other->cond) && equal(true_branch, other->true_branch) &&
416  equal(false_branch, other->false_branch);
417  }
418 
419  void SHashReduce(SHashReducer hash_reduce) const {
420  hash_reduce->MarkGraphNode();
421  hash_reduce(cond);
422  hash_reduce(true_branch);
423  hash_reduce(false_branch);
424  }
425 
426  static constexpr const char* _type_key = "relay.If";
428 };
429 
430 class If : public Expr {
431  public:
439  TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
440 
442 };
443 
445 class TupleGetItem;
446 class TupleGetItemNode : public ExprNode {
447  public:
451  int index;
452 
454  v->Visit("tuple_value", &tuple);
455  v->Visit("index", &index);
456  v->Visit("span", &span);
457  v->Visit("_checked_type_", &checked_type_);
458  }
459 
460  bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
461  return equal(tuple, other->tuple) && equal(index, other->index);
462  }
463 
464  void SHashReduce(SHashReducer hash_reduce) const {
465  hash_reduce(tuple);
466  hash_reduce(index);
467  }
468 
469  static constexpr const char* _type_key = "relay.TupleGetItem";
471 };
472 
473 class TupleGetItem : public Expr {
474  public:
481  TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
482 
484 };
485 
487 class RefCreate;
488 class RefCreateNode : public ExprNode {
489  public:
492 
494  v->Visit("value", &value);
495  v->Visit("span", &span);
496  v->Visit("_checked_type_", &checked_type_);
497  }
498 
499  bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const {
500  equal->MarkGraphNode();
501  return equal(value, other->value);
502  }
503 
504  void SHashReduce(SHashReducer hash_reduce) const {
505  hash_reduce->MarkGraphNode();
506  hash_reduce(value);
507  }
508 
509  static constexpr const char* _type_key = "relay.RefCreate";
511 };
512 
513 class RefCreate : public Expr {
514  public:
520  TVM_DLL explicit RefCreate(Expr value, Span span = Span());
521 
523 };
524 
526 class RefRead;
527 class RefReadNode : public ExprNode {
528  public:
531 
533  v->Visit("ref", &ref);
534  v->Visit("span", &span);
535  v->Visit("_checked_type_", &checked_type_);
536  }
537 
538  bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const {
539  equal->MarkGraphNode();
540  return equal(ref, other->ref);
541  }
542 
543  void SHashReduce(SHashReducer hash_reduce) const {
544  hash_reduce->MarkGraphNode();
545  hash_reduce(ref);
546  }
547 
548  static constexpr const char* _type_key = "relay.RefRead";
550 };
551 
552 class RefRead : public Expr {
553  public:
559  TVM_DLL explicit RefRead(Expr ref, Span span = Span());
560 
562 };
564 class RefWrite;
565 class RefWriteNode : public ExprNode {
566  public:
571 
573  v->Visit("ref", &ref);
574  v->Visit("value", &value);
575  v->Visit("span", &span);
576  v->Visit("_checked_type_", &checked_type_);
577  }
578 
579  bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const {
580  equal->MarkGraphNode();
581  return equal(ref, other->ref) && equal(value, other->value);
582  }
583 
584  void SHashReduce(SHashReducer hash_reduce) const {
585  hash_reduce->MarkGraphNode();
586  hash_reduce(ref);
587  hash_reduce(value);
588  }
589 
590  static constexpr const char* _type_key = "relay.RefWrite";
592 };
593 
594 class RefWrite : public Expr {
595  public:
602  TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span());
603 
605 };
606 
619 class TempExprNode : public ExprNode {
620  public:
622  virtual ~TempExprNode() {}
627  virtual Expr Realize() const = 0;
628 
629  static constexpr const char* _type_key = "relay.TempExpr";
630  static constexpr const bool _type_has_method_sequal_reduce = false;
631  static constexpr const bool _type_has_method_shash_reduce = false;
632  static constexpr const uint32_t _type_child_slots = 0;
634 };
635 
636 class TempExpr : public Expr {
637  public:
639 };
640 
641 } // namespace relay
642 } // namespace tvm
643 #endif // TVM_RELAY_EXPR_H_
tvm::Span Span
Definition: base.h:65
Definition: expr.h:594
Expr true_branch
The expression evaluated when condition is true.
Definition: expr.h:401
Definition: expr.h:135
bool SEqualReduce(const LetNode *other, SEqualReducer equal) const
Definition: expr.h:352
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
Definition: structural_equal.h:165
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:504
tvm::Array< relay::Expr > args
The arguments(inputs) of the call.
Definition: expr.h:245
tvm::BaseFuncNode BaseFuncNode
Definition: expr.h:46
Attrs attrs
The additional attributes.
Definition: expr.h:248
Base class of the temporary expression.
Definition: expr.h:619
bool SEqualReduce(const VarNode *other, SEqualReducer equal) const
Definition: expr.h:186
bool SEqualReduce(const TupleNode *other, SEqualReducer equal) const
Definition: expr.h:114
Managed reference to TensorTypeNode.
Definition: tensor_type.h:99
Definition: expr.h:636
Call container.
Definition: expr.h:229
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
Base expr nodes in TVM.
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:56
IRModule that holds the functions and type definitions.
Var(String name_hint, Type type_annotation, Span span=Span())
The constructor.
Definition: expr.h:209
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Definition: expr.h:488
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:101
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:82
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:179
tvm::Array< relay::Expr > fields
the fields of the tuple
Definition: expr.h:106
Definition: expr.h:565
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Constant tensor type.
Definition: expr.h:61
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:419
bool is_scalar() const
Definition: expr.h:70
Expr cond
The condition.
Definition: expr.h:399
Definition: expr.h:552
bool SEqualReduce(const RefWriteNode *other, SEqualReducer equal) const
Definition: expr.h:579
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:572
Definition: expr.h:88
Type type_annotation
type annotaion of the variable. This field records user provided type annotation of the Var...
Definition: expr.h:174
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:453
Definition: expr.h:301
static constexpr const char * _type_key
Definition: expr.h:84
tvm::BaseFunc BaseFunc
Definition: expr.h:45
Expr body
The body of the let binding.
Definition: expr.h:342
base class of all object containers.
Definition: object.h:165
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:493
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:59
Container for Var.
Definition: expr.h:157
tvm::Array< Type > type_args
The type arguments passed to polymorphic(template) function.
Definition: expr.h:268
Expr ref
The Reference Expression.
Definition: expr.h:530
Helpers for attribute objects.
virtual void MarkGraphNode()=0
Mark current comparison as graph node equal comparison.
Expr tuple
The tuple Expression.
Definition: expr.h:449
Primitive operators(builtin intrinsics) and registry for them.
Expr op
The operator(function) being invoked.
Definition: expr.h:242
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:286
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
bool SEqualReduce(const IfNode *other, SEqualReducer equal) const
Definition: expr.h:413
Definition: span.h:115
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:191
tvm::GlobalVarNode GlobalVarNode
Definition: expr.h:48
A binding of a sub-network.
Definition: expr.h:335
size_t size() const
Definition: array.h:399
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:52
Var var
The variable we bind to.
Definition: expr.h:338
Definition: expr.h:513
tvm::GlobalVar GlobalVar
Definition: expr.h:47
virtual void MarkGraphNode()=0
Mark current comparison as graph node in hashing. Graph node hash will depends on the graph structure...
Type checked_type_
Stores the result of type inference(type checking).
Definition: expr.h:150
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:584
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:55
bool IsPrimitiveOp(const RelayExpr &expr)
Check that an expression is a "primitive operator".
Definition: op.h:509
Definition: base.h:112
Relay typed AST nodes.
void(* FDeleter)(Object *self)
Object deleter.
Definition: object.h:171
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:270
Managed reference to GlobalVarNode.
Definition: expr.h:220
Expr false_branch
The expression evaluated when condition is false.
Definition: expr.h:403
Tuple container.
Definition: expr.h:103
bool SEqualReduce(const ConstantNode *other, SEqualReducer equal) const
Definition: expr.h:78
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode)
Reference to string objects.
Definition: string.h:129
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:543
Managed reference to RelayExprNode.
Definition: expr.h:177
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:706
Object::FDeleter saved_deleter_
Definition: expr.h:232
Definition: expr.h:369
bool SEqualReduce(const CallNode *other, SEqualReducer equal) const
Definition: expr.h:279
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode)
String PrettyPrint(const ObjectRef &node)
Pretty print a node for debug purposes.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Definition: expr.h:446
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:464
const String & name_hint() const
Definition: expr.h:177
TensorType tensor_type() const
bool SEqualReduce(const RefCreateNode *other, SEqualReducer equal) const
Definition: expr.h:499
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:124
Expr value
The initial value of the Reference.
Definition: expr.h:491
Definition: expr.h:473
virtual ~TempExprNode()
virtual destructor
Definition: expr.h:622
Base node of all functions.
Definition: function.h:77
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:344
Expr ref
The Reference Expression.
Definition: expr.h:568
Expr value
The value we bind var to.
Definition: expr.h:340
Managed reference to BaseFuncNode.
Definition: function.h:143
Base classes for the Relay IR.
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:405
Managed reference to TypeNode.
Definition: type.h:93
Definition: expr.h:527
Definition: expr.h:201
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:72
Expr value
The value to write into.
Definition: expr.h:570
bool SEqualReduce(const TupleGetItemNode *other, SEqualReducer equal) const
Definition: expr.h:460
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:108
Global variable that lives in the top-level module.
Definition: expr.h:191
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:358
Base node of all non-primitive expressions.
Definition: expr.h:142
Definition: expr.h:430
int index
which value to get
Definition: expr.h:451
runtime::NDArray data
The data of the tensor.
Definition: expr.h:64
static constexpr const uint32_t _type_child_slots
Definition: expr.h:169
container of If
Definition: expr.h:396
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:532
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:178
bool SEqualReduce(const RefReadNode *other, SEqualReducer equal) const
Definition: expr.h:538
Id vid
The unique identifier of the Var.
Definition: expr.h:168