tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_RELAX_EXPR_H_
20 #define TVM_RELAX_EXPR_H_
21 
22 #include <tvm/ir/expr.h>
23 #include <tvm/ir/source_map.h>
24 #include <tvm/node/node.h>
25 #include <tvm/relax/type.h>
28 #include <tvm/runtime/object.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/op.h>
31 
32 #include <functional>
33 
34 namespace tvm {
35 namespace relax {
36 
37 using Expr = RelayExpr;
47 class IdNode : public Object {
48  public:
55 
56  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); }
57 
58  bool SEqualReduce(const IdNode* other, SEqualReducer equal) const {
59  return equal.FreeVarEqualImpl(this, other);
60  }
61 
62  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.FreeVarHashImpl(this); }
63 
64  static constexpr const char* _type_key = "relax.Id";
65  static constexpr const bool _type_has_method_sequal_reduce = true;
66  static constexpr const bool _type_has_method_shash_reduce = true;
68 };
69 
70 class Id : public ObjectRef {
71  public:
76  TVM_DLL explicit Id(String name_hint);
77 
79 };
80 
110 class StructInfoNode : public Object {
111  public:
116  mutable Span span;
117 
118  static constexpr const char* _type_key = "StructInfo";
119  static constexpr const bool _type_has_method_sequal_reduce = true;
120  static constexpr const bool _type_has_method_shash_reduce = true;
121  static constexpr const uint32_t _type_child_slots = 5;
123 };
124 
129 class StructInfo : public ObjectRef {
130  public:
132 };
133 
138 class CallNode : public ExprNode {
139  public:
147 
150 
153 
161 
163  v->Visit("op", &op);
164  v->Visit("args", &args);
165  v->Visit("attrs", &attrs);
166  v->Visit("sinfo_args", &sinfo_args);
167  v->Visit("struct_info_", &struct_info_);
168  v->Visit("_checked_type_", &checked_type_);
169  v->Visit("span", &span);
170  }
171 
172  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
173  // skip sinfo_args check for primitive ops.
174  return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) &&
176  }
177 
178  void SHashReduce(SHashReducer hash_reduce) const {
179  hash_reduce(op);
180  hash_reduce(args);
181  hash_reduce(attrs);
182  hash_reduce(sinfo_args);
183  hash_reduce(struct_info_);
184  }
185 
186  static constexpr const char* _type_key = "relax.expr.Call";
188 };
189 
190 class Call : public Expr {
191  public:
200  TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
201  Array<StructInfo> sinfo_args = Array<StructInfo>(), Span span = Span());
202 
205 };
206 
213  Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(),
214  Optional<Attrs> opt_attrs = Optional<Attrs>(),
216  Optional<Span> opt_span = Optional<Span>());
217 
219 class TupleNode : public ExprNode {
220  public:
223 
225  v->Visit("fields", &fields);
226  v->Visit("_checked_type_", &checked_type_);
227  v->Visit("struct_info_", &struct_info_);
228  v->Visit("span", &span);
229  }
230 
231  bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
232  // struct info can be deterministically derived from fields.
233  return equal(fields, other->fields);
234  }
235 
236  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); }
237 
238  static constexpr const char* _type_key = "relax.expr.Tuple";
240 };
241 
242 class Tuple : public Expr {
243  public:
249  TVM_DLL explicit Tuple(tvm::Array<Expr> fields, Span span = Span());
250 
265  template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
266  TVM_DLL explicit Tuple(tvm::Array<RelaxExpr> fields, Span span = Span())
267  : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
268 
271 };
272 
279  Optional<Span> opt_span = Optional<Span>());
280 
282 class TupleGetItemNode : public ExprNode {
283  public:
287  int index;
288 
290  v->Visit("tuple_value", &tuple);
291  v->Visit("index", &index);
292  v->Visit("struct_info_", &struct_info_);
293  v->Visit("_checked_type_", &checked_type_);
294  v->Visit("span", &span);
295  }
296 
297  bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
298  // struct info can be deterministically derived from tuple and index.
299  return equal(tuple, other->tuple) && equal(index, other->index);
300  }
301 
302  void SHashReduce(SHashReducer hash_reduce) const {
303  hash_reduce(tuple);
304  hash_reduce(index);
305  }
306 
307  static constexpr const char* _type_key = "relax.expr.TupleGetItem";
309 };
310 
311 class TupleGetItem : public Expr {
312  public:
319  TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
320 
323 };
324 
331  Optional<Integer> opt_index = Optional<Integer>(),
332  Optional<Span> opt_span = Optional<Span>());
333 
338 class LeafExprNode : public ExprNode {
339  public:
340  static constexpr const char* _type_key = "relax.expr.LeafExpr";
341  static constexpr const uint32_t _type_child_slots = 7;
343 };
344 
349 class LeafExpr : public Expr {
350  public:
352 };
353 
356 class ShapeExprNode : public LeafExprNode {
357  public:
360 
362  v->Visit("values", &values);
363  v->Visit("struct_info_", &struct_info_);
364  v->Visit("_checked_type_", &checked_type_);
365  v->Visit("span", &span);
366  }
367 
368  bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const {
369  // struct info can be deterministically derived from values.
370  return equal(values, other->values);
371  }
372 
373  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(values); }
374 
375  static constexpr const char* _type_key = "relax.expr.ShapeExpr";
376  static constexpr const bool _type_has_method_sequal_reduce = true;
377  static constexpr const bool _type_has_method_shash_reduce = true;
379 };
380 
381 class ShapeExpr : public LeafExpr {
382  public:
383  TVM_DLL explicit ShapeExpr(Array<PrimExpr> values, Span span = Span());
386 };
387 
389 class VarNode : public LeafExprNode {
390  public:
394 
396  const String& name_hint() const { return vid->name_hint; }
397 
399  v->Visit("vid", &vid);
400  v->Visit("struct_info_", &struct_info_);
401  v->Visit("_checked_type_", &checked_type_);
402  v->Visit("span", &span);
403  }
404 
405  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
406  equal->MarkGraphNode();
407  return equal(vid, other->vid) && equal(struct_info_, other->struct_info_);
408  }
409 
410  void SHashReduce(SHashReducer hash_reduce) const {
411  hash_reduce(vid);
412  hash_reduce(struct_info_);
413  }
414 
415  static constexpr const char* _type_key = "relax.expr.Var";
416  static constexpr const bool _type_has_method_sequal_reduce = true;
417  static constexpr const bool _type_has_method_shash_reduce = true;
418  static constexpr const uint32_t _type_child_slots = 2;
420 };
421 
422 class Var : public LeafExpr {
423  public:
424  TVM_DLL explicit Var(String name_hint, Optional<StructInfo> struct_info_annotation,
425  Span span = Span())
426  : Var(Id(name_hint), struct_info_annotation, span) {}
427 
428  TVM_DLL explicit Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span = Span());
430 
432 };
433 
437 class DataflowVarNode : public VarNode {
438  public:
440  v->Visit("vid", &vid);
441  v->Visit("struct_info_", &struct_info_);
442  v->Visit("_checked_type_", &checked_type_);
443  v->Visit("span", &span);
444  }
445 
446  bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const {
447  equal->MarkGraphNode();
448  return equal(vid, other->vid) && equal(struct_info_, other->struct_info_);
449  }
450 
451  void SHashReduce(SHashReducer hash_reduce) const {
452  hash_reduce(vid);
453  hash_reduce(struct_info_);
454  }
455 
456  static constexpr const char* _type_key = "relax.expr.DataflowVar";
457  static constexpr const bool _type_has_method_sequal_reduce = true;
458  static constexpr const bool _type_has_method_shash_reduce = true;
460 };
461 
462 class DataflowVar : public Var {
463  public:
464  TVM_DLL explicit DataflowVar(String name_hint, Optional<StructInfo> struct_info_annotation,
465  Span span = Span())
466  : DataflowVar(Id(name_hint), struct_info_annotation, span) {}
467 
468  TVM_DLL explicit DataflowVar(Id vid, Optional<StructInfo> struct_info_annotation,
469  Span span = Span());
470 
473 };
474 
480 class ConstantNode : public LeafExprNode {
481  public:
484 
487 
489  bool is_scalar() const { return data->ndim == 0; }
490 
492  v->Visit("data", &data);
493  v->Visit("struct_info_", &struct_info_);
494  v->Visit("_checked_type_", &checked_type_);
495  v->Visit("span", &span);
496  }
497 
498  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
499  // struct info is compared together with data (the Constant constructor accepts an optional struct info annotation).
500  return equal(data, other->data) && equal(struct_info_, other->struct_info_);
501  }
502 
503  void SHashReduce(SHashReducer hash_reduce) const {
504  hash_reduce(data);
505  hash_reduce(struct_info_);
506  }
507 
508  static constexpr const char* _type_key = "relax.expr.Constant";
510 };
511 
512 class Constant : public LeafExpr {
513  public:
521  TVM_DLL explicit Constant(runtime::NDArray data,
522  Optional<StructInfo> struct_info_annotation = NullOpt,
523  Span span = Span());
524 
527 };
528 
534 class PrimValueNode : public LeafExprNode {
535  public:
538 
540  v->Visit("value", &value);
541  v->Visit("struct_info_", &struct_info_);
542  v->Visit("_checked_type_", &checked_type_);
543  v->Visit("span", &span);
544  }
545 
546  bool SEqualReduce(const PrimValueNode* other, SEqualReducer equal) const {
547  // struct info can be deterministically derived from value.
548  return equal(value, other->value);
549  }
550 
551  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
552 
553  static constexpr const char* _type_key = "relax.expr.PrimValue";
555 };
556 
561 class PrimValue : public LeafExpr {
562  public:
568  TVM_DLL explicit PrimValue(PrimExpr value, Span span = Span());
569 
576  TVM_DLL static PrimValue Int64(int64_t value, Span span = Span());
577 
580 };
581 
585 class StringImmNode : public LeafExprNode {
586  public:
589 
591  v->Visit("value", &value);
592  v->Visit("struct_info_", &struct_info_);
593  v->Visit("_checked_type_", &checked_type_);
594  v->Visit("span", &span);
595  }
596 
597  bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
598  // struct info can be deterministically derived from value.
599  return equal(value, other->value);
600  }
601 
602  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
603 
604  static constexpr const char* _type_key = "relax.expr.StringImm";
606 };
607 
612 class StringImm : public LeafExpr {
613  public:
619  TVM_DLL explicit StringImm(String value, Span span = Span());
620 
623 };
624 
629  public:
632 
634  v->Visit("value", &value);
635  v->Visit("struct_info_", &struct_info_);
636  v->Visit("_checked_type_", &checked_type_);
637  v->Visit("span", &span);
638  }
639 
640  bool SEqualReduce(const DataTypeImmNode* other, SEqualReducer equal) const {
641  // struct info can be deterministically derived from value.
642  return equal(value, other->value);
643  }
644 
645  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
646 
647  static constexpr const char* _type_key = "relax.expr.DataTypeImm";
649 };
650 
655 class DataTypeImm : public LeafExpr {
656  public:
662  TVM_DLL explicit DataTypeImm(DataType value, Span span = Span());
663 
666 };
667 
669 class BindingNode : public Object {
670  public:
673  mutable Span span;
674 
675  static constexpr const char* _type_key = "relax.expr.Binding";
676  static constexpr const bool _type_has_method_sequal_reduce = true;
677  static constexpr const bool _type_has_method_shash_reduce = true;
679 };
680 
681 class Binding : public ObjectRef {
682  protected:
683  Binding() = default;
684 
685  public:
686  explicit Binding(ObjectPtr<Object> n) : ObjectRef(n) {}
688  const BindingNode* operator->() const { return static_cast<const BindingNode*>(data_.get()); }
689  const BindingNode* get() const { return operator->(); }
691 };
692 
700 class MatchCastNode : public BindingNode {
701  public:
706 
708  v->Visit("var", &var);
709  v->Visit("value", &value);
710  v->Visit("struct_info", &struct_info);
711  v->Visit("span", &span);
712  }
713 
714  bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const;
715  void SHashReduce(SHashReducer hash_reduce) const;
716 
717  static constexpr const char* _type_key = "relax.expr.MatchCast";
718  static constexpr const bool _type_has_method_sequal_reduce = true;
719  static constexpr const bool _type_has_method_shash_reduce = true;
721 };
722 
727 class MatchCast : public Binding {
728  public:
729  TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span());
730 
733 };
734 
735 class VarBindingNode : public BindingNode {
736  public:
739 
741  v->Visit("var", &var);
742  v->Visit("value", &value);
743  v->Visit("span", &span);
744  }
745 
746  bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const;
747  void SHashReduce(SHashReducer hash_reduce) const;
748 
749  static constexpr const char* _type_key = "relax.expr.VarBinding";
750  static constexpr const bool _type_has_method_sequal_reduce = true;
751  static constexpr const bool _type_has_method_shash_reduce = true;
753 };
754 
755 class VarBinding : public Binding {
756  public:
757  TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
760 };
761 
762 class BindingBlockNode : public Object {
763  public:
764  mutable Span span;
766 
768  v->Visit("span", &span);
769  v->Visit("bindings", &bindings);
770  }
771 
772  bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const {
773  return equal(bindings, other->bindings);
774  }
775 
776  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }
777 
778  static constexpr const char* _type_key = "relax.expr.BindingBlock";
779  static constexpr const bool _type_has_method_sequal_reduce = true;
780  static constexpr const bool _type_has_method_shash_reduce = true;
782 };
783 
784 class BindingBlock : public ObjectRef {
785  public:
786  TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
788 
790 };
791 
793  public:
794  bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const {
795  return equal(bindings, other->bindings);
796  }
797 
798  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }
799 
800  static constexpr const char* _type_key = "relax.expr.DataflowBlock";
801  static constexpr const bool _type_has_method_sequal_reduce = true;
802  static constexpr const bool _type_has_method_shash_reduce = true;
804 };
805 
806 class DataflowBlock : public BindingBlock {
807  public:
808  TVM_DLL explicit DataflowBlock(Array<Binding> bindings, Span span = Span());
811 };
812 
817 class SeqExprNode : public ExprNode {
818  public:
821 
823  v->Visit("blocks", &blocks);
824  v->Visit("body", &body);
825  v->Visit("struct_info_", &struct_info_);
826  v->Visit("_checked_type_", &checked_type_);
827  v->Visit("span", &span);
828  }
829 
830  bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const {
831  return equal(blocks, other->blocks) && equal(body, other->body) &&
832  equal(struct_info_, other->struct_info_);
833  }
834 
835  void SHashReduce(SHashReducer hash_reduce) const {
836  hash_reduce(blocks);
837  hash_reduce(body);
838  hash_reduce(struct_info_);
839  }
840 
841  static constexpr const char* _type_key = "relax.expr.SeqExpr";
842  static constexpr const bool _type_has_method_sequal_reduce = true;
843  static constexpr const bool _type_has_method_shash_reduce = true;
845 };
846 
847 class SeqExpr : public Expr {
848  public:
849  /* \brief Implicit conversion constructor
850  *
851  * Relax nodes that introduce a new scope (e.g. `relax::Function`)
852  * are required to be held as SeqExpr. This implicit conversion
853  * allows callsites to use these member variables when the
854  * C++ compile-time type is a `relax::Expr`. For example,
855  * a transform may use `func.CopyOnWrite()->body = expr;`.
856  *
857  * If the expression is already a `relax::SeqExpr`, the same
858  * underlying `relax::SeqExprNode` is used, and no copies are made.
859  */
860  TVM_DLL SeqExpr(Expr body); // NOLINT(*)
861 
862  TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = Span());
865 };
866 
878 class IfNode : public ExprNode {
879  public:
886 
888  v->Visit("cond", &cond);
889  v->Visit("true_branch", &true_branch);
890  v->Visit("false_branch", &false_branch);
891  v->Visit("_checked_type_", &checked_type_);
892  v->Visit("struct_info_", &struct_info_);
893  v->Visit("span", &span);
894  }
895 
896  bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
897  equal->MarkGraphNode();
898  return equal(cond, other->cond) && equal(true_branch, other->true_branch) &&
899  equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_);
900  }
901 
902  void SHashReduce(SHashReducer hash_reduce) const {
903  hash_reduce->MarkGraphNode();
904  hash_reduce(cond);
905  hash_reduce(true_branch);
906  hash_reduce(false_branch);
907  hash_reduce(struct_info_);
908  }
909 
910  static constexpr const char* _type_key = "relax.expr.If";
912 };
913 
914 class If : public Expr {
915  public:
933  TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
934 
937 };
938 
945  Optional<Expr> opt_true_branch = Optional<Expr>(),
946  Optional<Expr> opt_false_branch = Optional<Expr>(),
947  Optional<Span> opt_span = Optional<Span>());
948 
950 class FunctionNode : public BaseFuncNode {
951  public:
959  bool is_pure;
960 
962  v->Visit("params", &params);
963  v->Visit("body", &body);
964  v->Visit("is_pure", &is_pure);
965  v->Visit("ret_struct_info", &ret_struct_info);
966  v->Visit("attrs", &attrs);
967  v->Visit("struct_info_", &struct_info_);
968  v->Visit("_checked_type_", &checked_type_);
969  v->Visit("span", &span);
970  }
971 
972  bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
973  equal->MarkGraphNode();
974  return equal.DefEqual(params, other->params) && equal(body, other->body) &&
975  equal(ret_struct_info, other->ret_struct_info) && equal(is_pure, other->is_pure) &&
976  equal(attrs, other->attrs) && equal(struct_info_, other->struct_info_);
977  }
978 
979  void SHashReduce(SHashReducer hash_reduce) const {
980  hash_reduce->MarkGraphNode();
981  hash_reduce.DefHash(params);
982  hash_reduce(body);
983  hash_reduce(ret_struct_info);
984  hash_reduce(is_pure);
985  hash_reduce(attrs);
986  hash_reduce(struct_info_);
987  }
988 
989  static constexpr const char* _type_key = "relax.expr.Function";
990  static constexpr const bool _type_has_method_sequal_reduce = true;
991  static constexpr const bool _type_has_method_shash_reduce = true;
993 };
994 
995 class Function : public BaseFunc {
996  public:
1018  TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
1019  bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());
1020 
1025  TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
1026  bool is_pure = true, DictAttrs attrs = DictAttrs(),
1027  Span span = Span());
1028 
1031 };
1032 
1033 // TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and
1034 // kPrimitive.
1035 namespace attr {
1037 constexpr const char* kPrimitive = "Primitive";
1042 constexpr const char* kCodegen = "Codegen";
1044 constexpr const char* kComposite = "Composite";
1046 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
1048 constexpr const char* kWorkspaceSize = "WorkspaceSize";
1049 
1050 // Note: in the future, we prefer snake_case instead of CamelCase for attributes.
1051 // Past ones will be kept for backwards compatibility.
1054 constexpr const char* kForcePure = "relax.force_pure";
1055 
1061 constexpr const char* kNumInput = "num_input";
1062 } // namespace attr
1063 
1066  public:
1069 
1071  v->Visit("global_symbol", &global_symbol);
1072  v->Visit("struct_info_", &struct_info_);
1073  v->Visit("_checked_type_", &checked_type_);
1074  v->Visit("span", &span);
1075  }
1076 
1077  bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const {
1078  return equal(global_symbol, other->global_symbol) && equal(struct_info_, other->struct_info_);
1079  }
1080 
1081  void SHashReduce(SHashReducer hash_reduce) const {
1082  hash_reduce(global_symbol);
1083  hash_reduce(struct_info_);
1084  }
1085 
1086  static constexpr const char* _type_key = "relax.expr.ExternFunc";
1087  static constexpr const bool _type_has_method_sequal_reduce = true;
1088  static constexpr const bool _type_has_method_shash_reduce = true;
1090 };
1091 
1092 class ExternFunc : public BaseFunc {
1093  public:
1094  TVM_DLL ExternFunc(String global_symbol, Span span = Span());
1095  TVM_DLL ExternFunc(String global_symbol, StructInfo struct_info, Span span = Span());
1096 
1099 };
1100 
1112 TVM_DLL Expr GetShapeOf(const Expr& expr);
1113 
1114 } // namespace relax
1115 } // namespace tvm
1116 
1117 /* \brief Allow relax.Var as key in STL tables
1118  *
1119  * For most Relax expressions, it would be ambiguous whether the
1120  * expression should follow reference equality or structural equality.
1121  * This is not the case for variables, which do not contain nested
1122  * internal structure, and are frequently used as keys in lookup
1123  * tables.
1124  *
1125  * Providing `std::hash` and `std::equal_to` specializations for
1126  * `relax::Var` allows it to be used as a key in STL tables. For
1127  * `relax::Expr`, the user must specify the type of equality used
1128  * (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
1129  * `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
1130  */
1131 template <>
1132 struct std::hash<tvm::relax::Var> {  // std::hash specialization so relax::Var can be used as an STL table key
1133  std::size_t operator()(const tvm::relax::Var& var) const {
1134  return tvm::runtime::ObjectPtrHash()(var);  // delegates to ObjectPtrHash; presumably hashes by object pointer identity - confirm
1135  }
1136 };
1137 
1138 template <>
1139 struct std::equal_to<tvm::relax::Var> {  // std::equal_to specialization paired with the std::hash specialization for relax::Var
1140  bool operator()(const tvm::relax::Var& var_a, const tvm::relax::Var& var_b) const {
1141  return tvm::runtime::ObjectPtrEqual()(var_a, var_b);  // delegates to ObjectPtrEqual; presumably compares object pointer identity - confirm
1142  }
1143 };
1144 
1145 #endif // TVM_RELAX_EXPR_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:56
Base node of all functions.
Definition: function.h:139
DictAttrs attrs
Additional attributes storing the meta-data.
Definition: function.h:142
Managed reference to BaseFuncNode.
Definition: function.h:230
Managed reference to DictAttrsNode.
Definition: attrs.h:227
Reference to PrimExprNode.
Definition: expr.h:115
Base node of all non-primitive expressions.
Definition: expr.h:362
Optional< ObjectRef > struct_info_
Stores the result of structure information of the expression that encapsulates both static shape and r...
Definition: expr.h:377
Type checked_type_
Stores the result of type inference(type checking).
Definition: expr.h:370
Managed reference to RelayExprNode.
Definition: expr.h:442
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
virtual void MarkGraphNode()=0
Mark current comparison as graph node in hashing. Graph node hash will depend on the graph structure...
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void FreeVarHashImpl(const runtime::Object *var) const
Implementation for hash for a free var.
Definition: structural_hash.h:203
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
Managed reference to TensorTypeNode.
Definition: tensor_type.h:99
Definition: expr.h:762
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:776
bool SEqualReduce(const BindingBlockNode *other, SEqualReducer equal) const
Definition: expr.h:772
Array< Binding > bindings
Definition: expr.h:765
Span span
Definition: expr.h:764
TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object)
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:767
Definition: expr.h:784
BindingBlock(Array< Binding > bindings, Span span=Span())
BindingBlockNode * CopyOnWrite()
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode)
The base class of a variable binding in Relax.
Definition: expr.h:669
TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object)
Var var
The return variable to bind to.
Definition: expr.h:672
Span span
Definition: expr.h:673
Definition: expr.h:681
const BindingNode * operator->() const
Definition: expr.h:688
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding)
Binding(ObjectPtr< Object > n)
Definition: expr.h:686
const BindingNode * get() const
Definition: expr.h:689
Call corresponds to callable invocation. Corresponds to operation in computational graph terminology.
Definition: expr.h:138
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:162
bool SEqualReduce(const CallNode *other, SEqualReducer equal) const
Definition: expr.h:172
tvm::Array< Expr > args
The arguments(inputs) of the call.
Definition: expr.h:149
static constexpr const char * _type_key
Definition: expr.h:186
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:178
Expr op
The operator(function) being invoked.
Definition: expr.h:146
Attrs attrs
The additional attributes.
Definition: expr.h:152
Array< StructInfo > sinfo_args
The structure info arguments of a CallNode. sinfo_args is designed to be non-empty only for intrinsic...
Definition: expr.h:160
Definition: expr.h:190
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
Call(Expr op, Array< Expr > args, Attrs attrs=Attrs(), Array< StructInfo > sinfo_args=Array< StructInfo >(), Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode)
Constant tensor.
Definition: expr.h:480
runtime::NDArray data
The data of the tensor.
Definition: expr.h:483
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, LeafExprNode)
bool SEqualReduce(const ConstantNode *other, SEqualReducer equal) const
Definition: expr.h:498
bool is_scalar() const
Definition: expr.h:489
TensorType tensor_type() const
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:503
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:491
Definition: expr.h:512
TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode)
Constant(runtime::NDArray data, Optional< StructInfo > struct_info_annotation=NullOpt, Span span=Span())
The constructor.
Represent a data type constant.
Definition: expr.h:628
DataType value
The data value.
Definition: expr.h:631
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:645
bool SEqualReduce(const DataTypeImmNode *other, SEqualReducer equal) const
Definition: expr.h:640
TVM_DECLARE_FINAL_OBJECT_INFO(DataTypeImmNode, LeafExprNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:633
Managed reference to DataTypeImm.
Definition: expr.h:655
TVM_DEFINE_OBJECT_REF_METHODS(DataTypeImm, LeafExpr, DataTypeImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode)
DataTypeImm(DataType value, Span span=Span())
The constructor.
Definition: expr.h:792
bool SEqualReduce(const DataflowBlockNode *other, SEqualReducer equal) const
Definition: expr.h:794
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:798
Definition: expr.h:806
TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode)
DataflowBlock(Array< Binding > bindings, Span span=Span())
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:437
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:439
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:451
bool SEqualReduce(const DataflowVarNode *other, SEqualReducer equal) const
Definition: expr.h:446
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode)
Definition: expr.h:462
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode)
DataflowVar(Id vid, Optional< StructInfo > struct_info_annotation, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode)
DataflowVar(String name_hint, Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:464
The extern function, which can represent packed function.
Definition: expr.h:1065
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1081
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1070
String global_symbol
The name of global symbol.
Definition: expr.h:1068
TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode)
bool SEqualReduce(const ExternFuncNode *other, SEqualReducer equal) const
Definition: expr.h:1077
Definition: expr.h:1092
TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode)
ExternFunc(String global_symbol, StructInfo struct_info, Span span=Span())
ExternFunc(String global_symbol, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode)
A Relax function.
Definition: expr.h:950
Array< Var > params
The parameters to the function.
Definition: expr.h:953
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode)
SeqExpr body
The body of the function.
Definition: expr.h:955
bool SEqualReduce(const FunctionNode *other, SEqualReducer equal) const
Definition: expr.h:972
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:961
StructInfo ret_struct_info
The return type of the function.
Definition: expr.h:957
bool is_pure
Whether the function is annotated as pure or not.
Definition: expr.h:959
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:979
Definition: expr.h:995
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode)
static Function CreateEmpty(Array< Var > params, StructInfo ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Mimics the constructor but without body Expr.
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode)
Function(Array< Var > params, Expr body, Optional< StructInfo > ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Construct a Relax Function.
The unique identifier of variables.
Definition: expr.h:47
bool SEqualReduce(const IdNode *other, SEqualReducer equal) const
Definition: expr.h:58
TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object)
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:66
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:56
static constexpr const char * _type_key
Definition: expr.h:64
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:65
String name_hint
The name of the variable, this only acts as a hint to the user, and is not used for equality.
Definition: expr.h:54
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:62
Definition: expr.h:70
Id(String name_hint)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode)
Condition expression.
Definition: expr.h:878
SeqExpr true_branch
The expression evaluated when condition is true.
Definition: expr.h:883
bool SEqualReduce(const IfNode *other, SEqualReducer equal) const
Definition: expr.h:896
Expr cond
The condition.
Definition: expr.h:881
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:902
SeqExpr false_branch
The expression evaluated when condition is false.
Definition: expr.h:885
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:887
Definition: expr.h:914
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode)
If(Expr cond, Expr true_branch, Expr false_branch, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode)
Base type of all (non-function) leaf Exprs.
Definition: expr.h:338
TVM_DECLARE_BASE_OBJECT_INFO(LeafExprNode, ExprNode)
Managed reference to LeafExprNode.
Definition: expr.h:349
TVM_DEFINE_OBJECT_REF_METHODS(LeafExpr, Expr, LeafExprNode)
Runtime-match the value to the struct info.
Definition: expr.h:700
Expr value
The input value to match cast.
Definition: expr.h:703
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:707
bool SEqualReduce(const MatchCastNode *other, SEqualReducer equal) const
void SHashReduce(SHashReducer hash_reduce) const
StructInfo struct_info
The struct info pattern to match to.
Definition: expr.h:705
TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode)
Managed reference to MatchCastNode.
Definition: expr.h:727
TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode)
MatchCast(Var var, Expr value, StructInfo struct_info, Span span=Span())
PrimValue.
Definition: expr.h:534
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:539
PrimExpr value
The prim expr representing the value.
Definition: expr.h:537
TVM_DECLARE_FINAL_OBJECT_INFO(PrimValueNode, LeafExprNode)
bool SEqualReduce(const PrimValueNode *other, SEqualReducer equal) const
Definition: expr.h:546
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:551
Managed reference to PrimValueNode.
Definition: expr.h:561
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode)
static PrimValue Int64(int64_t value, Span span=Span())
Create an int64 prim value.
PrimValue(PrimExpr value, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(PrimValue, LeafExpr, PrimValueNode)
A sequence of blocks followed by an expression.
Definition: expr.h:817
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:822
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:835
bool SEqualReduce(const SeqExprNode *other, SEqualReducer equal) const
Definition: expr.h:830
Expr body
Definition: expr.h:820
TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode)
Array< BindingBlock > blocks
Definition: expr.h:819
Definition: expr.h:847
SeqExpr(Array< BindingBlock > blocks, Expr body, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode)
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:356
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, LeafExprNode)
Array< PrimExpr > values
Definition: expr.h:359
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:373
bool SEqualReduce(const ShapeExprNode *other, SEqualReducer equal) const
Definition: expr.h:368
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:361
Definition: expr.h:381
TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, LeafExpr, ShapeExprNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode)
ShapeExpr(Array< PrimExpr > values, Span span=Span())
Represent a string literal constant.
Definition: expr.h:585
bool SEqualReduce(const StringImmNode *other, SEqualReducer equal) const
Definition: expr.h:597
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, LeafExprNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:590
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:602
String value
The data value.
Definition: expr.h:588
Managed reference to StringImmNode.
Definition: expr.h:612
StringImm(String value, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(StringImm, LeafExpr, StringImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
Base type of all structure information.
Definition: expr.h:110
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:116
static constexpr const uint32_t _type_child_slots
Definition: expr.h:121
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:120
static constexpr const char * _type_key
Definition: expr.h:118
TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object)
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:119
Managed reference to StructInfoNode.
Definition: expr.h:129
TVM_DEFINE_OBJECT_REF_METHODS(StructInfo, ObjectRef, StructInfoNode)
Get index-th field out of a tuple.
Definition: expr.h:282
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:289
bool SEqualReduce(const TupleGetItemNode *other, SEqualReducer equal) const
Definition: expr.h:297
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode)
int index
The index of the tuple field to get.
Definition: expr.h:287
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:302
Expr tuple
The tuple Expression.
Definition: expr.h:285
Definition: expr.h:311
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode)
TupleGetItem(Expr tuple, int index, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode)
Tuple container.
Definition: expr.h:219
tvm::Array< Expr > fields
the fields of the tuple
Definition: expr.h:222
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode)
bool SEqualReduce(const TupleNode *other, SEqualReducer equal) const
Definition: expr.h:231
static constexpr const char * _type_key
Definition: expr.h:238
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:224
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:236
Definition: expr.h:242
Tuple(tvm::Array< Expr > fields, Span span=Span())
The constructor.
Tuple(tvm::Array< RelaxExpr > fields, Span span=Span())
Utility constructor to handle conversion to relax::Expr.
Definition: expr.h:266
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode)
Definition: expr.h:735
Expr value
The binding value.
Definition: expr.h:738
TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode)
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:740
void SHashReduce(SHashReducer hash_reduce) const
bool SEqualReduce(const VarBindingNode *other, SEqualReducer equal) const
Definition: expr.h:755
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode)
VarBinding(Var var, Expr value, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode)
The variable class for all Relax bindings.
Definition: expr.h:389
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:410
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:398
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode)
Id vid
The identifier of the variable, which is used for comparing stable equality across transformations.
Definition: expr.h:393
const String & name_hint() const
Definition: expr.h:396
bool SEqualReduce(const VarNode *other, SEqualReducer equal) const
Definition: expr.h:405
Definition: expr.h:422
TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode)
Var(String name_hint, Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:424
Var(Id vid, Optional< StructInfo > struct_info_annotation, Span span=Span())
VarNode * CopyOnWrite()
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object references.
Definition: object.h:519
Base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Base expr nodes in TVM.
Runtime Map container types.
constexpr const char * kForcePure
Override checking purity for this function and treat as pure (is_pure must be set to true)
Definition: expr.h:1054
constexpr const char * kWorkspaceSize
The required workspace for an external function.
Definition: expr.h:1048
constexpr const char * kNumInput
The number of inputs of a function. If a function has the num_input attribute, the last func->params....
Definition: expr.h:1061
constexpr const char * kComposite
Treat the function as a composite operator.
Definition: expr.h:1044
constexpr const char * kCodegen
Indicate the codegen that should be used for building this function. When this is unset or set to "de...
Definition: expr.h:1042
constexpr const char * kPrimitive
Mark the function as a primitive function.
Definition: expr.h:1037
constexpr const char * kPartitionedFromPattern
Indicate the function was created by the Pattern Partitioning Pass.
Definition: expr.h:1046
Call WithFields(Call call, Optional< Expr > opt_op=Optional< Expr >(), Optional< Array< Expr >> opt_args=Optional< Array< Expr >>(), Optional< Attrs > opt_attrs=Optional< Attrs >(), Optional< Array< StructInfo >> opt_sinfo_args=Optional< Array< StructInfo >>(), Optional< Span > opt_span=Optional< Span >())
Returns call with the given properties. A null property denotes 'no change'. Returns call if all prop...
Expr GetShapeOf(const Expr &expr)
Get the shape of Expr.
tvm::Span Span
Definition: base.h:65
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
GlobalVar WithFields(GlobalVar global_var, Optional< String > opt_name_hint={}, Optional< Type > opt_type={}, Optional< VirtualDevice > opt_virtual_device={}, Optional< Span > opt_span={})
Returns global_var with the given properties. A null property denotes 'no change'....
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
Definitions and helper macros for IR/AST nodes.
A managed object in the TVM runtime.
Relax Types.
A map from source names to source code.
ObjectRef equal functor.
Definition: object.h:665
ObjectRef hash functor.
Definition: object.h:655
TIR expressions.
Common operators defined for Expr.