tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_RELAX_EXPR_H_
20 #define TVM_RELAX_EXPR_H_
21 
22 #include <tvm/ir/expr.h>
23 #include <tvm/ir/function.h>
24 #include <tvm/ir/source_map.h>
25 #include <tvm/node/node.h>
26 #include <tvm/relax/type.h>
29 #include <tvm/runtime/object.h>
30 #include <tvm/tir/expr.h>
31 #include <tvm/tir/op.h>
32 
33 #include <functional>
34 
35 namespace tvm {
36 namespace relax {
37 
// Within tvm::relax, `Expr` is an alias for the IR-level `RelaxExpr` reference type.
38 using Expr = RelaxExpr;
// IdNode: the identity object behind a relax variable.
// Only `name_hint` is reported to the attribute visitor. Structural equality and
// hashing do not compare fields; they delegate to the reducers' free-variable
// implementations (FreeVarEqualImpl / FreeVarHashImpl).
// NOTE(review): the `name_hint` member declaration and the TVM_DECLARE_* macro
// line are elided from this listing.
48 class IdNode : public Object {
49  public:
56 
57  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); }
58 
59  bool SEqualReduce(const IdNode* other, SEqualReducer equal) const {
60  return equal.FreeVarEqualImpl(this, other);
61  }
62 
63  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.FreeVarHashImpl(this); }
64 
65  static constexpr const char* _type_key = "relax.Id";
66  static constexpr const bool _type_has_method_sequal_reduce = true;
67  static constexpr const bool _type_has_method_shash_reduce = true;
69 };
70 
// Id: managed reference to IdNode, constructed from a name hint.
71 class Id : public ObjectRef {
72  public:
77  TVM_DLL explicit Id(String name_hint);
78 
80 };
81 
// StructInfoNode: base object for structure information.
// `span` is declared mutable, so it can be updated through a const node.
// Seven child type slots are reserved for subclasses.
111 class StructInfoNode : public Object {
112  public:
117  mutable Span span;
118 
119  static constexpr const char* _type_key = "StructInfo";
120  static constexpr const bool _type_has_method_sequal_reduce = true;
121  static constexpr const bool _type_has_method_shash_reduce = true;
122  static constexpr const uint32_t _type_child_slots = 7;
124 };
125 
// StructInfo: managed reference to StructInfoNode.
130 class StructInfo : public ObjectRef {
131  public:
133 };
134 
// CallNode: invocation of `op` on `args`, with optional `attrs` and `sinfo_args`.
// NOTE(review): the member declarations and the `VisitAttrs(tvm::AttrVisitor* v) {`
// signature line are elided from this listing. The Visit calls below are its body.
139 class CallNode : public ExprNode {
140  public:
148 
151 
154 
162 
164  v->Visit("op", &op);
165  v->Visit("args", &args);
166  v->Visit("attrs", &attrs);
167  v->Visit("sinfo_args", &sinfo_args);
168  v->Visit("struct_info_", &struct_info_);
169  v->Visit("_checked_type_", &checked_type_);
170  v->Visit("span", &span);
171  }
172 
// NOTE(review): the comment below says sinfo_args is skipped for primitive ops,
// but SHashReduce hashes sinfo_args unconditionally. If equality really skips it,
// two equal calls could hash differently. The continuation of this return
// expression is elided from the listing; confirm the two are consistent.
173  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
174  // skip sinfo_args check for primitive ops.
175  return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) &&
177  }
178 
// Hashes op, args, attrs, sinfo_args and struct_info_, in that order.
179  void SHashReduce(SHashReducer hash_reduce) const {
180  hash_reduce(op);
181  hash_reduce(args);
182  hash_reduce(attrs);
183  hash_reduce(sinfo_args);
184  hash_reduce(struct_info_);
185  }
186 
187  static constexpr const char* _type_key = "relax.expr.Call";
189 };
190 
// Call: managed reference to CallNode.
// attrs, sinfo_args and span default to empty values.
191 class Call : public Expr {
192  public:
201  TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
202  Array<StructInfo> sinfo_args = Array<StructInfo>(), Span span = Span());
203 
206 };
207 
// NOTE(review): the leading lines of this declaration are elided from the listing.
// Only its trailing parameters are visible, each defaulting to an empty Optional.
214  Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(),
215  Optional<Attrs> opt_attrs = Optional<Attrs>(),
217  Optional<Span> opt_span = Optional<Span>());
218 
// TupleNode: a tuple of `fields`.
// Structural equality and hashing consider only `fields`.
// NOTE(review): the `fields` declaration and the VisitAttrs signature line are
// elided from this listing.
220 class TupleNode : public ExprNode {
221  public:
224 
226  v->Visit("fields", &fields);
227  v->Visit("_checked_type_", &checked_type_);
228  v->Visit("struct_info_", &struct_info_);
229  v->Visit("span", &span);
230  }
231 
232  bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
233  // struct info can be deterministically derived from fields.
234  return equal(fields, other->fields);
235  }
236 
237  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); }
238 
239  static constexpr const char* _type_key = "relax.expr.Tuple";
241 };
242 
// Tuple: managed reference to TupleNode.
243 class Tuple : public Expr {
244  public:
250  TVM_DLL explicit Tuple(tvm::Array<Expr> fields, Span span = Span());
251 
// Accepts an Array of any Expr subclass. Each element is upcast to Expr via Map,
// then construction is delegated to the Array<Expr> constructor above.
// NOTE(review): the template parameter name `RelaxExpr` shadows the `RelaxExpr`
// type that `Expr` aliases. `Expr` here still names the aliased type, so this
// works, but a different parameter name would read less ambiguously.
266  template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
267  TVM_DLL explicit Tuple(tvm::Array<RelaxExpr> fields, Span span = Span())
268  : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
269 
272 };
273 
// NOTE(review): only the last parameter of this declaration is visible in the listing.
280  Optional<Span> opt_span = Optional<Span>());
281 
// TupleGetItemNode: selects element `index` of the expression `tuple`.
// The visitor reports the `tuple` field under the key "tuple_value".
// Equality and hashing consider only `tuple` and `index`.
283 class TupleGetItemNode : public ExprNode {
284  public:
288  int index;
289 
291  v->Visit("tuple_value", &tuple);
292  v->Visit("index", &index);
293  v->Visit("struct_info_", &struct_info_);
294  v->Visit("_checked_type_", &checked_type_);
295  v->Visit("span", &span);
296  }
297 
298  bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
299  // struct info can be deterministically derived from tuple and index.
300  return equal(tuple, other->tuple) && equal(index, other->index);
301  }
302 
303  void SHashReduce(SHashReducer hash_reduce) const {
304  hash_reduce(tuple);
305  hash_reduce(index);
306  }
307 
308  static constexpr const char* _type_key = "relax.expr.TupleGetItem";
310 };
311 
// TupleGetItem: managed reference to TupleGetItemNode.
312 class TupleGetItem : public Expr {
313  public:
320  TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
321 
324 };
325 
// NOTE(review): the leading lines of this declaration are elided from the listing.
332  Optional<Integer> opt_index = Optional<Integer>(),
333  Optional<Span> opt_span = Optional<Span>());
334 
// LeafExprNode: base node for leaf expressions. Seven child type slots are reserved.
339 class LeafExprNode : public ExprNode {
340  public:
341  static constexpr const char* _type_key = "relax.expr.LeafExpr";
342  static constexpr const uint32_t _type_child_slots = 7;
344 };
345 
// LeafExpr: managed reference to LeafExprNode.
350 class LeafExpr : public Expr {
351  public:
353 };
354 
// ShapeExprNode: a leaf expression holding shape `values`.
// Equality and hashing consider only `values`.
357 class ShapeExprNode : public LeafExprNode {
358  public:
361 
363  v->Visit("values", &values);
364  v->Visit("struct_info_", &struct_info_);
365  v->Visit("_checked_type_", &checked_type_);
366  v->Visit("span", &span);
367  }
368 
369  bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const {
370  // struct info can be deterministically derived from values.
371  return equal(values, other->values);
372  }
373 
374  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(values); }
375 
376  static constexpr const char* _type_key = "relax.expr.ShapeExpr";
377  static constexpr const bool _type_has_method_sequal_reduce = true;
378  static constexpr const bool _type_has_method_shash_reduce = true;
380 };
381 
// ShapeExpr: managed reference to ShapeExprNode, built from an Array<PrimExpr>.
382 class ShapeExpr : public LeafExpr {
383  public:
384  TVM_DLL explicit ShapeExpr(Array<PrimExpr> values, Span span = Span());
387 };
388 
// VarNode: a variable, identified by its `vid` (an Id).
// Equality and hashing consider `vid` and `struct_info_`.
// One child type slot is reserved for a subclass.
// NOTE(review): SEqualReduce calls MarkGraphNode but SHashReduce does not. IfNode
// and FunctionNode in this header mark in both. Confirm the asymmetry is intended.
390 class VarNode : public LeafExprNode {
391  public:
395 
// Forwards to the name hint stored on the variable's Id.
397  const String& name_hint() const { return vid->name_hint; }
398 
400  v->Visit("vid", &vid);
401  v->Visit("struct_info_", &struct_info_);
402  v->Visit("_checked_type_", &checked_type_);
403  v->Visit("span", &span);
404  }
405 
406  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
407  equal->MarkGraphNode();
408  return equal(vid, other->vid) && equal(struct_info_, other->struct_info_);
409  }
410 
411  void SHashReduce(SHashReducer hash_reduce) const {
412  hash_reduce(vid);
413  hash_reduce(struct_info_);
414  }
415 
416  static constexpr const char* _type_key = "relax.expr.Var";
417  static constexpr const bool _type_has_method_sequal_reduce = true;
418  static constexpr const bool _type_has_method_shash_reduce = true;
419  static constexpr const uint32_t _type_child_slots = 1;
421 };
422 
// Var: managed reference to VarNode.
// The String overload wraps the name in a fresh Id and delegates to the Id overload.
423 class Var : public LeafExpr {
424  public:
425  TVM_DLL explicit Var(String name_hint, Optional<StructInfo> struct_info_annotation,
426  Span span = Span())
427  : Var(Id(name_hint), struct_info_annotation, span) {}
428 
429  TVM_DLL explicit Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span = Span());
431 
433 };
434 
// DataflowVarNode: a VarNode subtype with its own type key.
// VisitAttrs, SEqualReduce and SHashReduce mirror VarNode's field for field.
// They are presumably repeated so the reducers take a DataflowVarNode* — TODO confirm.
438 class DataflowVarNode : public VarNode {
439  public:
441  v->Visit("vid", &vid);
442  v->Visit("struct_info_", &struct_info_);
443  v->Visit("_checked_type_", &checked_type_);
444  v->Visit("span", &span);
445  }
446 
447  bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const {
448  equal->MarkGraphNode();
449  return equal(vid, other->vid) && equal(struct_info_, other->struct_info_);
450  }
451 
452  void SHashReduce(SHashReducer hash_reduce) const {
453  hash_reduce(vid);
454  hash_reduce(struct_info_);
455  }
456 
457  static constexpr const char* _type_key = "relax.expr.DataflowVar";
458  static constexpr const bool _type_has_method_sequal_reduce = true;
459  static constexpr const bool _type_has_method_shash_reduce = true;
461 };
462 
// DataflowVar: managed reference to DataflowVarNode.
// As with Var, the String overload wraps the name in a fresh Id and delegates.
463 class DataflowVar : public Var {
464  public:
465  TVM_DLL explicit DataflowVar(String name_hint, Optional<StructInfo> struct_info_annotation,
466  Span span = Span())
467  : DataflowVar(Id(name_hint), struct_info_annotation, span) {}
468 
469  TVM_DLL explicit DataflowVar(Id vid, Optional<StructInfo> struct_info_annotation,
470  Span span = Span());
471 
474 };
475 
// ConstantNode: a leaf expression wrapping a runtime NDArray `data`.
481 class ConstantNode : public LeafExprNode {
482  public:
485 
488 
// True when the wrapped array has zero dimensions.
490  bool is_scalar() const { return data->ndim == 0; }
491 
493  v->Visit("data", &data);
494  v->Visit("struct_info_", &struct_info_);
495  v->Visit("_checked_type_", &checked_type_);
496  v->Visit("span", &span);
497  }
498 
499  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
500  // Unlike the other leaf nodes here, struct_info_ is compared in addition to data.
// The constructor accepts an explicit struct_info_annotation, so the struct info
// is presumably not always derivable from data alone — TODO confirm.
501  return equal(data, other->data) && equal(struct_info_, other->struct_info_);
502  }
503 
504  void SHashReduce(SHashReducer hash_reduce) const {
505  hash_reduce(data);
506  hash_reduce(struct_info_);
507  }
508 
509  static constexpr const char* _type_key = "relax.expr.Constant";
511 };
512 
// Constant: managed reference to ConstantNode.
// struct_info_annotation defaults to NullOpt.
513 class Constant : public LeafExpr {
514  public:
522  TVM_DLL explicit Constant(runtime::NDArray data,
523  Optional<StructInfo> struct_info_annotation = NullOpt,
524  Span span = Span());
525 
528 };
529 
// PrimValueNode: a leaf expression wrapping a PrimExpr `value`.
// Equality and hashing consider only `value`.
535 class PrimValueNode : public LeafExprNode {
536  public:
539 
541  v->Visit("value", &value);
542  v->Visit("struct_info_", &struct_info_);
543  v->Visit("_checked_type_", &checked_type_);
544  v->Visit("span", &span);
545  }
546 
547  bool SEqualReduce(const PrimValueNode* other, SEqualReducer equal) const {
548  // struct info can be deterministically derived from value.
549  return equal(value, other->value);
550  }
551 
552  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
553 
554  static constexpr const char* _type_key = "relax.expr.PrimValue";
556 };
557 
// PrimValue: managed reference to PrimValueNode.
// Int64 is a static factory that takes an int64_t.
562 class PrimValue : public LeafExpr {
563  public:
569  TVM_DLL explicit PrimValue(PrimExpr value, Span span = Span());
570 
577  TVM_DLL static PrimValue Int64(int64_t value, Span span = Span());
578 
581 };
582 
// StringImmNode: a leaf expression wrapping a String `value`.
// Equality and hashing consider only `value`.
586 class StringImmNode : public LeafExprNode {
587  public:
590 
592  v->Visit("value", &value);
593  v->Visit("struct_info_", &struct_info_);
594  v->Visit("_checked_type_", &checked_type_);
595  v->Visit("span", &span);
596  }
597 
598  bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
599  // struct info can be deterministically derived from value.
600  return equal(value, other->value);
601  }
602 
603  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
604 
605  static constexpr const char* _type_key = "relax.expr.StringImm";
607 };
608 
// StringImm: managed reference to StringImmNode.
613 class StringImm : public LeafExpr {
614  public:
620  TVM_DLL explicit StringImm(String value, Span span = Span());
621 
624 };
625 
// NOTE(review): the `class DataTypeImmNode ...` header line is elided from this
// listing. The members below belong to DataTypeImmNode, which wraps a DataType
// `value`. Equality and hashing consider only `value`.
630  public:
633 
635  v->Visit("value", &value);
636  v->Visit("struct_info_", &struct_info_);
637  v->Visit("_checked_type_", &checked_type_);
638  v->Visit("span", &span);
639  }
640 
641  bool SEqualReduce(const DataTypeImmNode* other, SEqualReducer equal) const {
642  // struct info can be deterministically derived from value.
643  return equal(value, other->value);
644  }
645 
646  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
647 
648  static constexpr const char* _type_key = "relax.expr.DataTypeImm";
650 };
651 
// DataTypeImm: managed reference to DataTypeImmNode.
656 class DataTypeImm : public LeafExpr {
657  public:
663  TVM_DLL explicit DataTypeImm(DataType value, Span span = Span());
664 
667 };
668 
// BindingNode: base object for bindings. `span` is declared mutable.
670 class BindingNode : public Object {
671  public:
674  mutable Span span;
675 
676  static constexpr const char* _type_key = "relax.expr.Binding";
677  static constexpr const bool _type_has_method_sequal_reduce = true;
678  static constexpr const bool _type_has_method_shash_reduce = true;
680 };
681 
// Binding: reference to a BindingNode.
// The default constructor is protected, so only subclasses can default-construct.
682 class Binding : public ObjectRef {
683  protected:
684  Binding() = default;
685 
686  public:
687  explicit Binding(ObjectPtr<Object> n) : ObjectRef(n) {}
// Unchecked static_cast: assumes data_ holds a BindingNode (or subclass).
689  const BindingNode* operator->() const { return static_cast<const BindingNode*>(data_.get()); }
690  const BindingNode* get() const { return operator->(); }
692 };
693 
// MatchCastNode: a binding carrying `var`, `value` and a `struct_info` pattern.
// SEqualReduce and SHashReduce are only declared here; their definitions are not
// part of this header.
701 class MatchCastNode : public BindingNode {
702  public:
707 
709  v->Visit("var", &var);
710  v->Visit("value", &value);
711  v->Visit("struct_info", &struct_info);
712  v->Visit("span", &span);
713  }
714 
715  bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const;
716  void SHashReduce(SHashReducer hash_reduce) const;
717 
718  static constexpr const char* _type_key = "relax.expr.MatchCast";
719  static constexpr const bool _type_has_method_sequal_reduce = true;
720  static constexpr const bool _type_has_method_shash_reduce = true;
722 };
723 
// MatchCast: managed reference to MatchCastNode.
728 class MatchCast : public Binding {
729  public:
730  TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span());
731 
734 };
735 
// VarBindingNode: a binding of `value` to `var`.
// SEqualReduce and SHashReduce are only declared here; their definitions are not
// part of this header.
736 class VarBindingNode : public BindingNode {
737  public:
740 
742  v->Visit("var", &var);
743  v->Visit("value", &value);
744  v->Visit("span", &span);
745  }
746 
747  bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const;
748  void SHashReduce(SHashReducer hash_reduce) const;
749 
750  static constexpr const char* _type_key = "relax.expr.VarBinding";
751  static constexpr const bool _type_has_method_sequal_reduce = true;
752  static constexpr const bool _type_has_method_shash_reduce = true;
754 };
755 
// VarBinding: managed reference to VarBindingNode.
756 class VarBinding : public Binding {
757  public:
758  TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
761 };
762 
// BindingBlockNode: an ordered collection of `bindings`. `span` is declared mutable.
// Equality and hashing consider only `bindings`.
763 class BindingBlockNode : public Object {
764  public:
765  mutable Span span;
767 
769  v->Visit("span", &span);
770  v->Visit("bindings", &bindings);
771  }
772 
773  bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const {
774  return equal(bindings, other->bindings);
775  }
776 
777  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }
778 
779  static constexpr const char* _type_key = "relax.expr.BindingBlock";
780  static constexpr const bool _type_has_method_sequal_reduce = true;
781  static constexpr const bool _type_has_method_shash_reduce = true;
783 };
784 
// BindingBlock: managed reference to BindingBlockNode.
785 class BindingBlock : public ObjectRef {
786  public:
787  TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
789 
791 };
792 
// NOTE(review): the `class DataflowBlockNode ...` header line is elided from this
// listing. The members below belong to DataflowBlockNode. It redefines
// SEqualReduce/SHashReduce over `bindings` with its own type key.
794  public:
795  bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const {
796  return equal(bindings, other->bindings);
797  }
798 
799  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }
800 
801  static constexpr const char* _type_key = "relax.expr.DataflowBlock";
802  static constexpr const bool _type_has_method_sequal_reduce = true;
803  static constexpr const bool _type_has_method_shash_reduce = true;
805 };
806 
// DataflowBlock: managed reference to DataflowBlockNode.
807 class DataflowBlock : public BindingBlock {
808  public:
809  TVM_DLL explicit DataflowBlock(Array<Binding> bindings, Span span = Span());
812 };
813 
// SeqExprNode: a sequence of binding `blocks` followed by a `body` expression.
// Equality and hashing consider `blocks`, `body` and `struct_info_`.
818 class SeqExprNode : public ExprNode {
819  public:
822 
824  v->Visit("blocks", &blocks);
825  v->Visit("body", &body);
826  v->Visit("struct_info_", &struct_info_);
827  v->Visit("_checked_type_", &checked_type_);
828  v->Visit("span", &span);
829  }
830 
831  bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const {
832  return equal(blocks, other->blocks) && equal(body, other->body) &&
833  equal(struct_info_, other->struct_info_);
834  }
835 
836  void SHashReduce(SHashReducer hash_reduce) const {
837  hash_reduce(blocks);
838  hash_reduce(body);
839  hash_reduce(struct_info_);
840  }
841 
842  static constexpr const char* _type_key = "relax.expr.SeqExpr";
843  static constexpr const bool _type_has_method_sequal_reduce = true;
844  static constexpr const bool _type_has_method_shash_reduce = true;
846 };
847 
// SeqExpr: managed reference to SeqExprNode.
848 class SeqExpr : public Expr {
849  public:
850  /* \brief Implicit conversion constructor
851  *
852  * Relax nodes that introduce a new scope (e.g. `relax::Function`)
853  * are required to be held as SeqExpr. This implicit conversion
854  * allows callsites to use these member variables when the
855  * C++ compile-time type is a `relax::Expr`. For example,
856  * a transform may use `func.CopyOnWrite()->body = expr;`.
857  *
858  * If the expression is already a `relax::SeqExpr`, the same
859  * underlying `relax::SeqExprNode` is used, and no copies are made.
860  */
861  TVM_DLL SeqExpr(Expr body); // NOLINT(*)
862 
863  TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = Span());
866 };
867 
// IfNode: conditional expression with `cond`, `true_branch` and `false_branch`.
// Both SEqualReduce and SHashReduce call MarkGraphNode before comparing or
// hashing cond, both branches and struct_info_.
879 class IfNode : public ExprNode {
880  public:
887 
889  v->Visit("cond", &cond);
890  v->Visit("true_branch", &true_branch);
891  v->Visit("false_branch", &false_branch);
892  v->Visit("_checked_type_", &checked_type_);
893  v->Visit("struct_info_", &struct_info_);
894  v->Visit("span", &span);
895  }
896 
897  bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
898  equal->MarkGraphNode();
899  return equal(cond, other->cond) && equal(true_branch, other->true_branch) &&
900  equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_);
901  }
902 
903  void SHashReduce(SHashReducer hash_reduce) const {
904  hash_reduce->MarkGraphNode();
905  hash_reduce(cond);
906  hash_reduce(true_branch);
907  hash_reduce(false_branch);
908  hash_reduce(struct_info_);
909  }
910 
911  static constexpr const char* _type_key = "relax.expr.If";
913 };
914 
// If: managed reference to IfNode.
// Branches are passed as plain Expr. They are presumably wrapped via the implicit
// SeqExpr(Expr) conversion — confirm in the implementation.
915 class If : public Expr {
916  public:
934  TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
935 
938 };
939 
// NOTE(review): the leading lines of this declaration are elided from the listing.
946  Optional<Expr> opt_true_branch = Optional<Expr>(),
947  Optional<Expr> opt_false_branch = Optional<Expr>(),
948  Optional<Span> opt_span = Optional<Span>());
949 
// FunctionNode: a relax function with params, body, ret_struct_info, is_pure and
// attrs. `params` goes through the reducers' DefEqual/DefHash entry points; all
// other fields use the plain reducer call. Both reducers call MarkGraphNode first.
951 class FunctionNode : public BaseFuncNode {
952  public:
960  bool is_pure;
961 
963  v->Visit("params", &params);
964  v->Visit("body", &body);
965  v->Visit("is_pure", &is_pure);
966  v->Visit("ret_struct_info", &ret_struct_info);
967  v->Visit("attrs", &attrs);
968  v->Visit("struct_info_", &struct_info_);
969  v->Visit("_checked_type_", &checked_type_);
970  v->Visit("span", &span);
971  }
972 
973  bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
974  equal->MarkGraphNode();
975  return equal.DefEqual(params, other->params) && equal(body, other->body) &&
976  equal(ret_struct_info, other->ret_struct_info) && equal(is_pure, other->is_pure) &&
977  equal(attrs, other->attrs) && equal(struct_info_, other->struct_info_);
978  }
979 
980  void SHashReduce(SHashReducer hash_reduce) const {
981  hash_reduce->MarkGraphNode();
982  hash_reduce.DefHash(params);
983  hash_reduce(body);
984  hash_reduce(ret_struct_info);
985  hash_reduce(is_pure);
986  hash_reduce(attrs);
987  hash_reduce(struct_info_);
988  }
989 
990  static constexpr const char* _type_key = "relax.expr.Function";
991  static constexpr const bool _type_has_method_sequal_reduce = true;
992  static constexpr const bool _type_has_method_shash_reduce = true;
994 };
995 
// Function: managed reference to FunctionNode.
// is_pure defaults to true; attrs and span default to empty values.
// CreateEmpty builds a function from params and a return struct info without a body.
996 class Function : public BaseFunc {
997  public:
1019  TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
1020  bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());
1021 
1026  TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
1027  bool is_pure = true, DictAttrs attrs = DictAttrs(),
1028  Span span = Span());
1029 
1032 };
1033 
1034 // TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and
1035 // kPrimitive.
// String-valued attribute-name constants. These are presumably keys into a
// function's DictAttrs — TODO confirm at the use sites.
// Namespace-scope `constexpr const char*` has internal linkage, so including this
// header in many translation units does not cause ODR conflicts.
1036 namespace attr {
1038 constexpr const char* kPrimitive = "Primitive";
1043 constexpr const char* kCodegen = "Codegen";
1045 constexpr const char* kComposite = "Composite";
1047 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
1049 constexpr const char* kWorkspaceSize = "WorkspaceSize";
1050 
1051 // Note: in the future, we prefer snake_case instead of CamelCase for attributes.
1052 // Past ones will be kept for backwards compatibility.
1055 constexpr const char* kForcePure = "relax.force_pure";
1056 
1062 constexpr const char* kNumInput = "num_input";
1063 } // namespace attr
1064 
// NOTE(review): the `class ExternFuncNode ...` header line is elided from this
// listing. The members below belong to ExternFuncNode, identified by
// `global_symbol`. Equality and hashing consider global_symbol and struct_info_.
1067  public:
1070 
1072  v->Visit("global_symbol", &global_symbol);
1073  v->Visit("struct_info_", &struct_info_);
1074  v->Visit("_checked_type_", &checked_type_);
1075  v->Visit("span", &span);
1076  }
1077 
1078  bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const {
1079  return equal(global_symbol, other->global_symbol) && equal(struct_info_, other->struct_info_);
1080  }
1081 
1082  void SHashReduce(SHashReducer hash_reduce) const {
1083  hash_reduce(global_symbol);
1084  hash_reduce(struct_info_);
1085  }
1086 
1087  static constexpr const char* _type_key = "relax.expr.ExternFunc";
1088  static constexpr const bool _type_has_method_sequal_reduce = true;
1089  static constexpr const bool _type_has_method_shash_reduce = true;
1091 };
1092 
// ExternFunc: managed reference to ExternFuncNode.
// NOTE(review): neither constructor is `explicit`. The first one (with the span
// default) therefore allows an implicit String -> ExternFunc conversion. Confirm
// that this is intended, unlike the other single-argument constructors here.
1093 class ExternFunc : public BaseFunc {
1094  public:
1095  TVM_DLL ExternFunc(String global_symbol, Span span = Span());
1096  TVM_DLL ExternFunc(String global_symbol, StructInfo struct_info, Span span = Span());
1097 
1100 };
1101 
// GetShapeOf: declared here; the implementation is not part of this header.
// Per its name it presumably returns an expression for the shape of `expr` — TODO
// confirm against the definition (its doc comment is elided from this listing).
1113 TVM_DLL Expr GetShapeOf(const Expr& expr);
1114 
1115 } // namespace relax
1116 } // namespace tvm
1117 
1118 /* \brief Allow relax.Var as key in STL tables
1119  *
1120  * For most Relax expressions, it would be ambiguous whether the
1121  * expression should follow reference equality or structural equality.
1122  * This is not the case for variables, which do not contain nested
1123  * internal structure, and are frequently used as keys in lookup
1124  * tables.
1125  *
1126  * Providing `std::hash` and `std::equal_to` specializations for
1127  * `relax::Var` allows it to be used as a key in STL tables. For
1128  * `relax::Expr`, the user must specify the type of equality used
1129  * (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
1130  * `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
1131  */
// Hashes a Var by its underlying object pointer (reference identity), not structurally.
1132 template <>
1133 struct std::hash<tvm::relax::Var> {
1134  std::size_t operator()(const tvm::relax::Var& var) const {
1135  return tvm::runtime::ObjectPtrHash()(var);
1136  }
1137 };
1138 
// Two Var handles compare equal only when they refer to the same object.
// This matches the pointer-based std::hash specialization above.
1139 template <>
1140 struct std::equal_to<tvm::relax::Var> {
1141  bool operator()(const tvm::relax::Var& var_a, const tvm::relax::Var& var_b) const {
1142  return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
1143  }
1144 };
1145 
1146 #endif // TVM_RELAX_EXPR_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:56
Base node of all functions.
Definition: function.h:139
DictAttrs attrs
Additional attributes storing the meta-data.
Definition: function.h:142
Managed reference to BaseFuncNode.
Definition: function.h:230
Managed reference to DictAttrsNode.
Definition: attrs.h:227
Reference to PrimExprNode.
Definition: expr.h:115
Base node of all non-primitive expressions.
Definition: expr.h:362
Type checked_type_
Stores the result of type inference(type checking).
Definition: expr.h:370
Optional< ObjectRef > struct_info_
Stores the result of structure information of the expression that encapsulates both static shape and r...
Definition: expr.h:377
Managed reference to RelaxExprNode.
Definition: expr.h:405
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:135
virtual void MarkGraphNode()=0
Mark the current comparison as a graph node in hashing. Graph node hash will depend on the graph structure...
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void FreeVarHashImpl(const runtime::Object *var) const
Implementation for hash for a free var.
Definition: structural_hash.h:203
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
Definition: expr.h:763
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:777
bool SEqualReduce(const BindingBlockNode *other, SEqualReducer equal) const
Definition: expr.h:773
Array< Binding > bindings
Definition: expr.h:766
Span span
Definition: expr.h:765
TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object)
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:768
Definition: expr.h:785
BindingBlock(Array< Binding > bindings, Span span=Span())
BindingBlockNode * CopyOnWrite()
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode)
The base class of a variable binding in Relax.
Definition: expr.h:670
TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object)
Var var
The return variable to be bound to.
Definition: expr.h:673
Span span
Definition: expr.h:674
Definition: expr.h:682
const BindingNode * operator->() const
Definition: expr.h:689
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding)
Binding(ObjectPtr< Object > n)
Definition: expr.h:687
const BindingNode * get() const
Definition: expr.h:690
Call corresponds to callable invocation. Corresponds to operation in computational graph terminology.
Definition: expr.h:139
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:163
bool SEqualReduce(const CallNode *other, SEqualReducer equal) const
Definition: expr.h:173
tvm::Array< Expr > args
The arguments (inputs) of the call.
Definition: expr.h:150
static constexpr const char * _type_key
Definition: expr.h:187
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:179
Expr op
The operator(function) being invoked.
Definition: expr.h:147
Attrs attrs
The additional attributes.
Definition: expr.h:153
Array< StructInfo > sinfo_args
The structure info arguments of a CallNode. sinfo_args is designed to be non-empty only for intrinsic...
Definition: expr.h:161
Definition: expr.h:191
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
Call(Expr op, Array< Expr > args, Attrs attrs=Attrs(), Array< StructInfo > sinfo_args=Array< StructInfo >(), Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode)
Constant tensor.
Definition: expr.h:481
runtime::NDArray data
The data of the tensor.
Definition: expr.h:484
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, LeafExprNode)
bool SEqualReduce(const ConstantNode *other, SEqualReducer equal) const
Definition: expr.h:499
bool is_scalar() const
Definition: expr.h:490
TensorType tensor_type() const
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:504
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:492
Definition: expr.h:513
TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode)
Constant(runtime::NDArray data, Optional< StructInfo > struct_info_annotation=NullOpt, Span span=Span())
The constructor.
Represent a data type constant.
Definition: expr.h:629
DataType value
The data value.
Definition: expr.h:632
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:646
bool SEqualReduce(const DataTypeImmNode *other, SEqualReducer equal) const
Definition: expr.h:641
TVM_DECLARE_FINAL_OBJECT_INFO(DataTypeImmNode, LeafExprNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:634
Managed reference to DataTypeImm.
Definition: expr.h:656
TVM_DEFINE_OBJECT_REF_METHODS(DataTypeImm, LeafExpr, DataTypeImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode)
DataTypeImm(DataType value, Span span=Span())
The constructor.
Definition: expr.h:793
bool SEqualReduce(const DataflowBlockNode *other, SEqualReducer equal) const
Definition: expr.h:795
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:799
Definition: expr.h:807
TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode)
DataflowBlock(Array< Binding > bindings, Span span=Span())
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:438
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:440
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:452
bool SEqualReduce(const DataflowVarNode *other, SEqualReducer equal) const
Definition: expr.h:447
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode)
Definition: expr.h:463
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode)
DataflowVar(Id vid, Optional< StructInfo > struct_info_annotation, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode)
DataflowVar(String name_hint, Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:465
The extern function, which can represent packed function.
Definition: expr.h:1066
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1082
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1071
String global_symbol
The name of global symbol.
Definition: expr.h:1069
TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode)
bool SEqualReduce(const ExternFuncNode *other, SEqualReducer equal) const
Definition: expr.h:1078
Definition: expr.h:1093
TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode)
ExternFunc(String global_symbol, StructInfo struct_info, Span span=Span())
ExternFunc(String global_symbol, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode)
A Relax function.
Definition: expr.h:951
Array< Var > params
The parameters to the function.
Definition: expr.h:954
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode)
SeqExpr body
The body of the function.
Definition: expr.h:956
bool SEqualReduce(const FunctionNode *other, SEqualReducer equal) const
Definition: expr.h:973
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:962
StructInfo ret_struct_info
The return struct info of the function.
Definition: expr.h:958
bool is_pure
Whether the function is annotated as pure or not.
Definition: expr.h:960
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:980
Definition: expr.h:996
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode)
static Function CreateEmpty(Array< Var > params, StructInfo ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Mimics the constructor but without body Expr.
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode)
Function(Array< Var > params, Expr body, Optional< StructInfo > ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Construct a Relax Function.
The unique identifier of variables.
Definition: expr.h:48
bool SEqualReduce(const IdNode *other, SEqualReducer equal) const
Definition: expr.h:59
TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object)
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:67
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:57
static constexpr const char * _type_key
Definition: expr.h:65
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:66
String name_hint
The name of the variable. It only acts as a hint to the user and is not used for equality.
Definition: expr.h:55
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:63
Definition: expr.h:71
Id(String name_hint)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode)
Condition expression.
Definition: expr.h:879
SeqExpr true_branch
The expression evaluated when condition is true.
Definition: expr.h:884
bool SEqualReduce(const IfNode *other, SEqualReducer equal) const
Definition: expr.h:897
Expr cond
The condition.
Definition: expr.h:882
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:903
SeqExpr false_branch
The expression evaluated when condition is false.
Definition: expr.h:886
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:888
Definition: expr.h:915
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode)
If(Expr cond, Expr true_branch, Expr false_branch, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode)
Base type of all (non-function) leaf Exprs.
Definition: expr.h:339
TVM_DECLARE_BASE_OBJECT_INFO(LeafExprNode, ExprNode)
Managed reference to LeafExprNode.
Definition: expr.h:350
TVM_DEFINE_OBJECT_REF_METHODS(LeafExpr, Expr, LeafExprNode)
Runtime-match the value to the struct info.
Definition: expr.h:701
Expr value
The input value to match cast.
Definition: expr.h:704
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:708
bool SEqualReduce(const MatchCastNode *other, SEqualReducer equal) const
void SHashReduce(SHashReducer hash_reduce) const
StructInfo struct_info
The struct info pattern to match to.
Definition: expr.h:706
TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode)
Managed reference to MatchCastNode.
Definition: expr.h:728
TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode)
MatchCast(Var var, Expr value, StructInfo struct_info, Span span=Span())
PrimValue.
Definition: expr.h:535
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:540
PrimExpr value
The prim expr representing the value.
Definition: expr.h:538
TVM_DECLARE_FINAL_OBJECT_INFO(PrimValueNode, LeafExprNode)
bool SEqualReduce(const PrimValueNode *other, SEqualReducer equal) const
Definition: expr.h:547
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:552
Managed reference to PrimValueNode.
Definition: expr.h:562
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode)
static PrimValue Int64(int64_t value, Span span=Span())
Create an int64 prim value.
PrimValue(PrimExpr value, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(PrimValue, LeafExpr, PrimValueNode)
A sequence of blocks followed by an expression.
Definition: expr.h:818
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:823
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:836
bool SEqualReduce(const SeqExprNode *other, SEqualReducer equal) const
Definition: expr.h:831
Expr body
Definition: expr.h:821
TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode)
Array< BindingBlock > blocks
Definition: expr.h:820
Definition: expr.h:848
SeqExpr(Array< BindingBlock > blocks, Expr body, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode)
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:357
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, LeafExprNode)
Array< PrimExpr > values
Definition: expr.h:360
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:374
bool SEqualReduce(const ShapeExprNode *other, SEqualReducer equal) const
Definition: expr.h:369
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:362
Definition: expr.h:382
TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, LeafExpr, ShapeExprNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode)
ShapeExpr(Array< PrimExpr > values, Span span=Span())
Represent a string literal constant.
Definition: expr.h:586
bool SEqualReduce(const StringImmNode *other, SEqualReducer equal) const
Definition: expr.h:598
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, LeafExprNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:591
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:603
String value
The data value.
Definition: expr.h:589
Managed reference to StringImm.
Definition: expr.h:613
StringImm(String value, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(StringImm, LeafExpr, StringImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
Base type of all structure information.
Definition: expr.h:111
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:117
static constexpr const uint32_t _type_child_slots
Definition: expr.h:122
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:121
static constexpr const char * _type_key
Definition: expr.h:119
TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object)
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:120
Managed reference to StructInfoNode.
Definition: expr.h:130
TVM_DEFINE_OBJECT_REF_METHODS(StructInfo, ObjectRef, StructInfoNode)
Managed reference to TensorTypeNode.
Definition: type.h:111
Get index-th field out of a tuple.
Definition: expr.h:283
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:290
bool SEqualReduce(const TupleGetItemNode *other, SEqualReducer equal) const
Definition: expr.h:298
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode)
int index
Which value to get.
Definition: expr.h:288
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:303
Expr tuple
The tuple Expression.
Definition: expr.h:286
Definition: expr.h:312
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode)
TupleGetItem(Expr tuple, int index, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode)
Tuple container.
Definition: expr.h:220
tvm::Array< Expr > fields
The fields of the tuple.
Definition: expr.h:223
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode)
bool SEqualReduce(const TupleNode *other, SEqualReducer equal) const
Definition: expr.h:232
static constexpr const char * _type_key
Definition: expr.h:239
void VisitAttrs(tvm::AttrVisitor *v)
Definition: expr.h:225
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:237
Definition: expr.h:243
Tuple(tvm::Array< Expr > fields, Span span=Span())
The constructor.
Tuple(tvm::Array< RelaxExpr > fields, Span span=Span())
Utility constructor to handle conversion to relax::Expr.
Definition: expr.h:267
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode)
Definition: expr.h:736
Expr value
The binding value.
Definition: expr.h:739
TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode)
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:741
void SHashReduce(SHashReducer hash_reduce) const
bool SEqualReduce(const VarBindingNode *other, SEqualReducer equal) const
Definition: expr.h:756
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode)
VarBinding(Var var, Expr value, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode)
The variable class for all Relax bindings.
Definition: expr.h:390
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:411
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:399
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode)
Id vid
The identifier of the variable, which is used for comparing stable equality across transformations.
Definition: expr.h:394
const String & name_hint() const
Definition: expr.h:397
bool SEqualReduce(const VarNode *other, SEqualReducer equal) const
Definition: expr.h:406
Definition: expr.h:423
TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode)
Var(String name_hint, Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:425
Var(Id vid, Optional< StructInfo > struct_info_annotation, Span span=Span())
VarNode * CopyOnWrite()
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
A custom smart pointer for Object.
Definition: object.h:363
Base class of all object reference.
Definition: object.h:520
base class of all object containers.
Definition: object.h:172
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:97
Base expr nodes in TVM.
Function nodes.
Runtime Map container types.
constexpr const char * kForcePure
Override checking purity for this function and treat as pure (is_pure must be set to true)
Definition: expr.h:1055
constexpr const char * kWorkspaceSize
The required workspace for an external function.
Definition: expr.h:1049
constexpr const char * kNumInput
The number of inputs of a function. If a function has the num_input attribute, the last func->params....
Definition: expr.h:1062
constexpr const char * kComposite
Treat the function as a composite operator.
Definition: expr.h:1045
constexpr const char * kCodegen
Indicate the codegen that should be used for building this function. When this is unset or set to "de...
Definition: expr.h:1043
constexpr const char * kPrimitive
Mark the function as a primitive function.
Definition: expr.h:1038
constexpr const char * kPartitionedFromPattern
Indicate the function was created by the Pattern Partitioning Pass.
Definition: expr.h:1047
If WithFields(If if_expr, Optional< Expr > opt_cond=Optional< Expr >(), Optional< Expr > opt_true_branch=Optional< Expr >(), Optional< Expr > opt_false_branch=Optional< Expr >(), Optional< Span > opt_span=Optional< Span >())
Returns if_expr with the given properties. A null property denotes 'no change'. Returns if_expr if al...
Call WithFields(Call call, Optional< Expr > opt_op=Optional< Expr >(), Optional< Array< Expr >> opt_args=Optional< Array< Expr >>(), Optional< Attrs > opt_attrs=Optional< Attrs >(), Optional< Array< StructInfo >> opt_sinfo_args=Optional< Array< StructInfo >>(), Optional< Span > opt_span=Optional< Span >())
Returns call with the given properties. A null property denotes 'no change'. Returns call if all prop...
Expr GetShapeOf(const Expr &expr)
Get the shape of Expr.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
Definitions and helper macros for IR/AST nodes.
A managed object in the TVM runtime.
Relax Types.
A map from source names to source code.
ObjectRef equal functor.
Definition: object.h:666
ObjectRef hash functor.
Definition: object.h:656
TIR expressions.
Common operators defined for Expr.