tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_RELAX_EXPR_H_
20 #define TVM_RELAX_EXPR_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/container/map.h>
24 #include <tvm/ffi/reflection/registry.h>
25 #include <tvm/ir/expr.h>
26 #include <tvm/ir/function.h>
27 #include <tvm/ir/source_map.h>
28 #include <tvm/node/node.h>
29 #include <tvm/relax/type.h>
30 #include <tvm/runtime/object.h>
31 #include <tvm/tir/expr.h>
32 #include <tvm/tir/op.h>
33 
34 #include <functional>
35 
36 namespace tvm {
37 namespace relax {
38 
/*!
 * \brief Within the relax namespace, `Expr` names the managed reference type RelaxExpr
 *        (the reference to RelaxExprNode); all relax expression references derive from it.
 */
using Expr = RelaxExpr;
49 class IdNode : public Object {
50  public:
56  String name_hint;
57 
58  static void RegisterReflection() {
59  namespace refl = tvm::ffi::reflection;
60  refl::ObjectDef<IdNode>().def_ro("name_hint", &IdNode::name_hint,
61  refl::AttachFieldFlag::SEqHashIgnore());
62  }
63 
64  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar;
65  static constexpr const char* _type_key = "relax.Id";
66 
68 };
69 
70 class Id : public ObjectRef {
71  public:
76  TVM_DLL explicit Id(String name_hint);
77 
79 };
80 
110 class StructInfoNode : public Object {
111  public:
116  mutable Span span;
117 
118  static void RegisterReflection() {
119  namespace refl = tvm::ffi::reflection;
120  refl::ObjectDef<StructInfoNode>().def_ro("span", &StructInfoNode::span,
121  refl::AttachFieldFlag::SEqHashIgnore());
122  }
123 
124  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
125  static constexpr const char* _type_key = "ir.StructInfo";
126 
127  static constexpr const uint32_t _type_child_slots = 7;
129 };
130 
135 class StructInfo : public ObjectRef {
136  public:
138 };
139 
144 class CallNode : public ExprNode {
145  public:
153 
155  tvm::Array<Expr> args;
156 
159 
166  Array<StructInfo> sinfo_args;
167 
168  static void RegisterReflection() {
169  namespace refl = tvm::ffi::reflection;
170  refl::ObjectDef<CallNode>()
171  .def_ro("op", &CallNode::op)
172  .def_ro("args", &CallNode::args)
173  .def_ro("attrs", &CallNode::attrs)
174  .def_ro("sinfo_args", &CallNode::sinfo_args);
175  }
176 
177  static constexpr const char* _type_key = "relax.expr.Call";
179 };
180 
181 class Call : public Expr {
182  public:
191  TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
192  Array<StructInfo> sinfo_args = Array<StructInfo>(), Span span = Span());
193 
196 };
197 
203 Call WithFields(Call call, Optional<Expr> opt_op = Optional<Expr>(),
204  Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(),
205  Optional<Attrs> opt_attrs = Optional<Attrs>(),
206  Optional<Array<StructInfo>> opt_sinfo_args = Optional<Array<StructInfo>>(),
207  Optional<Span> opt_span = Optional<Span>());
208 
210 class TupleNode : public ExprNode {
211  public:
213  tvm::Array<Expr> fields;
214 
215  static void RegisterReflection() {
216  namespace refl = tvm::ffi::reflection;
217  refl::ObjectDef<TupleNode>().def_ro("fields", &TupleNode::fields);
218  }
219 
220  static constexpr const char* _type_key = "relax.expr.Tuple";
222 };
223 
224 class Tuple : public Expr {
225  public:
231  TVM_DLL explicit Tuple(tvm::Array<Expr> fields, Span span = Span());
232 
247  template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
248  TVM_DLL explicit Tuple(tvm::Array<RelaxExpr> fields, Span span = Span())
249  : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
250 
253 };
254 
260 Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
261  Optional<Span> opt_span = Optional<Span>());
262 
264 class TupleGetItemNode : public ExprNode {
265  public:
269  int index;
270 
271  static void RegisterReflection() {
272  namespace refl = tvm::ffi::reflection;
273  refl::ObjectDef<TupleGetItemNode>()
274  .def_ro("tuple_value", &TupleGetItemNode::tuple)
275  .def_ro("index", &TupleGetItemNode::index);
276  }
277 
278  static constexpr const char* _type_key = "relax.expr.TupleGetItem";
280 };
281 
282 class TupleGetItem : public Expr {
283  public:
290  TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
291 
294 };
295 
301 TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple = Optional<Expr>(),
302  Optional<Integer> opt_index = Optional<Integer>(),
303  Optional<Span> opt_span = Optional<Span>());
304 
309 class LeafExprNode : public ExprNode {
310  public:
311  static constexpr const char* _type_key = "relax.expr.LeafExpr";
312  static constexpr const uint32_t _type_child_slots = 7;
314 };
315 
320 class LeafExpr : public Expr {
321  public:
323 };
324 
327 class ShapeExprNode : public LeafExprNode {
328  public:
330  Array<PrimExpr> values;
331 
332  static void RegisterReflection() {
333  namespace refl = tvm::ffi::reflection;
334  refl::ObjectDef<ShapeExprNode>().def_ro("values", &ShapeExprNode::values);
335  }
336 
337  static constexpr const char* _type_key = "relax.expr.ShapeExpr";
339 };
340 
341 class ShapeExpr : public LeafExpr {
342  public:
343  TVM_DLL explicit ShapeExpr(Array<PrimExpr> values, Span span = Span());
346 };
347 
349 class VarNode : public LeafExprNode {
350  public:
354 
356  const String& name_hint() const { return vid->name_hint; }
357 
358  static void RegisterReflection() {
359  namespace refl = tvm::ffi::reflection;
360  refl::ObjectDef<VarNode>().def_ro("vid", &VarNode::vid);
361  // customize structural equal and hash to include struct_info_
362  refl::TypeAttrDef<VarNode>()
363  .def("__s_equal__", &VarNode::SEqual)
364  .def("__s_hash__", &VarNode::SHash);
365  }
366 
367  bool SEqual(const VarNode* other,
368  ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const {
369  return equal(vid, other->vid, false, "vid") &&
370  equal(struct_info_, other->struct_info_, false, "struct_info_");
371  }
372 
373  uint64_t SHash(uint64_t init_hash,
374  ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const {
375  uint64_t hash_value = init_hash;
376  hash_value = hash(vid, hash_value, false);
377  hash_value = hash(struct_info_, hash_value, false);
378  return hash_value;
379  }
380 
381  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
382  static constexpr const char* _type_key = "relax.expr.Var";
383  static constexpr const uint32_t _type_child_slots = 1;
385 };
386 
387 class Var : public LeafExpr {
388  public:
389  TVM_DLL explicit Var(String name_hint, Optional<StructInfo> struct_info_annotation,
390  Span span = Span())
391  : Var(Id(name_hint), struct_info_annotation, span) {}
392 
393  TVM_DLL explicit Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span = Span());
395 
397 };
398 
402 class DataflowVarNode : public VarNode {
403  public:
404  static void RegisterReflection() {
405  namespace refl = tvm::ffi::reflection;
406  refl::ObjectDef<DataflowVarNode>();
407  }
408 
409  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
410  static constexpr const char* _type_key = "relax.expr.DataflowVar";
412 };
413 
414 class DataflowVar : public Var {
415  public:
416  TVM_DLL explicit DataflowVar(String name_hint, Optional<StructInfo> struct_info_annotation,
417  Span span = Span())
418  : DataflowVar(Id(name_hint), struct_info_annotation, span) {}
419 
420  TVM_DLL explicit DataflowVar(Id vid, Optional<StructInfo> struct_info_annotation,
421  Span span = Span());
422 
425 };
426 
432 class ConstantNode : public LeafExprNode {
433  public:
436 
439 
441  bool is_scalar() const { return data->ndim == 0; }
442 
443  static void RegisterReflection() {
444  namespace refl = tvm::ffi::reflection;
445  refl::ObjectDef<ConstantNode>().def_ro("data", &ConstantNode::data);
446  }
447 
448  static constexpr const char* _type_key = "relax.expr.Constant";
450 };
451 
452 class Constant : public LeafExpr {
453  public:
461  TVM_DLL explicit Constant(runtime::NDArray data,
462  Optional<StructInfo> struct_info_annotation = std::nullopt,
463  Span span = Span());
464 
467 };
468 
474 class PrimValueNode : public LeafExprNode {
475  public:
478 
479  static void RegisterReflection() {
480  namespace refl = tvm::ffi::reflection;
481  refl::ObjectDef<PrimValueNode>().def_ro("value", &PrimValueNode::value);
482  }
483 
484  static constexpr const char* _type_key = "relax.expr.PrimValue";
486 };
487 
492 class PrimValue : public LeafExpr {
493  public:
499  TVM_DLL explicit PrimValue(PrimExpr value, Span span = Span());
500 
507  TVM_DLL static PrimValue Int64(int64_t value, Span span = Span());
508 
511 };
512 
516 class StringImmNode : public LeafExprNode {
517  public:
519  String value;
520 
521  static void RegisterReflection() {
522  namespace refl = tvm::ffi::reflection;
523  refl::ObjectDef<StringImmNode>().def_ro("value", &StringImmNode::value);
524  }
525 
526  static constexpr const char* _type_key = "relax.expr.StringImm";
528 };
529 
534 class StringImm : public LeafExpr {
535  public:
541  TVM_DLL explicit StringImm(String value, Span span = Span());
542 
545 };
546 
551  public:
554 
555  static void RegisterReflection() {
556  namespace refl = tvm::ffi::reflection;
557  refl::ObjectDef<DataTypeImmNode>().def_ro("value", &DataTypeImmNode::value);
558  }
559 
560  static constexpr const char* _type_key = "relax.expr.DataTypeImm";
562 };
563 
568 class DataTypeImm : public LeafExpr {
569  public:
575  TVM_DLL explicit DataTypeImm(DataType value, Span span = Span());
576 
579 };
580 
582 class BindingNode : public Object {
583  public:
584  mutable Span span;
587 
588  static void RegisterReflection() {
589  namespace refl = tvm::ffi::reflection;
590  refl::ObjectDef<BindingNode>()
591  .def_ro("span", &BindingNode::span, refl::AttachFieldFlag::SEqHashIgnore())
592  .def_ro("var", &BindingNode::var, refl::AttachFieldFlag::SEqHashDef());
593  }
594 
595  static constexpr const char* _type_key = "relax.expr.Binding";
596  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
597 
599 };
600 
601 class Binding : public ObjectRef {
602  protected:
603  Binding() = default;
604 
605  public:
606  explicit Binding(ObjectPtr<Object> n) : ObjectRef(n) {}
608  const BindingNode* operator->() const { return static_cast<const BindingNode*>(data_.get()); }
609  const BindingNode* get() const { return operator->(); }
611 };
612 
620 class MatchCastNode : public BindingNode {
621  public:
626 
627  static void RegisterReflection() {
628  namespace refl = tvm::ffi::reflection;
629  refl::ObjectDef<MatchCastNode>()
630  .def_ro("value", &MatchCastNode::value)
631  .def_ro("struct_info", &MatchCastNode::struct_info, refl::AttachFieldFlag::SEqHashDef());
632  }
633 
634  static constexpr const char* _type_key = "relax.expr.MatchCast";
636 };
637 
642 class MatchCast : public Binding {
643  public:
644  TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span());
645 
648 };
649 
650 class VarBindingNode : public BindingNode {
651  public:
654 
655  static void RegisterReflection() {
656  namespace refl = tvm::ffi::reflection;
657  refl::ObjectDef<VarBindingNode>().def_ro("value", &VarBindingNode::value);
658  // customize the SEqual and SHash methods for better error messages
659  refl::TypeAttrDef<VarBindingNode>()
660  .def("__s_equal__", &VarBindingNode::SEqual)
661  .def("__s_hash__", &VarBindingNode::SHash);
662  }
663 
664  bool SEqual(const VarBindingNode* other,
665  ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const;
666  uint64_t SHash(uint64_t init_hash,
667  ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const;
668 
669  static constexpr const char* _type_key = "relax.expr.VarBinding";
670 
672 };
673 
674 class VarBinding : public Binding {
675  public:
676  TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
679 };
680 
681 class BindingBlockNode : public Object {
682  public:
683  Array<Binding> bindings;
684  mutable Span span;
685 
686  static void RegisterReflection() {
687  namespace refl = tvm::ffi::reflection;
688  refl::ObjectDef<BindingBlockNode>()
689  .def_ro("bindings", &BindingBlockNode::bindings)
690  .def_ro("span", &BindingBlockNode::span, refl::AttachFieldFlag::SEqHashIgnore(),
691  refl::DefaultValue(Span()));
692  }
693 
694  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
695  static constexpr const char* _type_key = "relax.expr.BindingBlock";
696 
698 };
699 
700 class BindingBlock : public ObjectRef {
701  public:
702  TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
704 
706 };
707 
709  public:
710  static void RegisterReflection() {
711  namespace refl = tvm::ffi::reflection;
712  refl::ObjectDef<DataflowBlockNode>();
713  }
714 
715  static constexpr const char* _type_key = "relax.expr.DataflowBlock";
716 
718 };
719 
720 class DataflowBlock : public BindingBlock {
721  public:
722  TVM_DLL explicit DataflowBlock(Array<Binding> bindings, Span span = Span());
725 };
726 
731 class SeqExprNode : public ExprNode {
732  public:
733  Array<BindingBlock> blocks;
735 
736  static void RegisterReflection() {
737  namespace refl = tvm::ffi::reflection;
738  refl::ObjectDef<SeqExprNode>()
739  .def_ro("blocks", &SeqExprNode::blocks)
740  .def_ro("body", &SeqExprNode::body);
741  }
742 
743  static constexpr const char* _type_key = "relax.expr.SeqExpr";
744 
746 };
747 
748 class SeqExpr : public Expr {
749  public:
750  /* \brief Implicit conversion constructor
751  *
752  * Relax nodes that introduce a new scope (e.g. `relax::Function`)
753  * are required to be held as SeqExpr. This implicit conversion
754  * provides allows callsites to use these member variables when the
755  * C++ compile-time type is a `relax::Expr`. For example,
756  * a transform may use `func.CopyOnWrite()->body = expr;`.
757  *
758  * If the expression is already a `relax::SeqExpr`, the same
759  * underlying `relax::SeqExprNode` is used, and no copies are made.
760  */
761  TVM_DLL SeqExpr(Expr body); // NOLINT(*)
762 
763  TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = Span());
766 };
767 
779 class IfNode : public ExprNode {
780  public:
787 
788  static void RegisterReflection() {
789  namespace refl = tvm::ffi::reflection;
790  refl::ObjectDef<IfNode>()
791  .def_ro("cond", &IfNode::cond)
792  .def_ro("true_branch", &IfNode::true_branch)
793  .def_ro("false_branch", &IfNode::false_branch);
794  }
795 
796  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
797  static constexpr const char* _type_key = "relax.expr.If";
799 };
800 
801 class If : public Expr {
802  public:
820  TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
821 
824 };
825 
831 If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
832  Optional<Expr> opt_true_branch = Optional<Expr>(),
833  Optional<Expr> opt_false_branch = Optional<Expr>(),
834  Optional<Span> opt_span = Optional<Span>());
835 
837 class FunctionNode : public BaseFuncNode {
838  public:
840  Array<Var> params;
846  bool is_pure;
847 
848  static void RegisterReflection() {
849  namespace refl = tvm::ffi::reflection;
850  refl::ObjectDef<FunctionNode>()
851  .def_ro("params", &FunctionNode::params, refl::AttachFieldFlag::SEqHashDef())
852  .def_ro("body", &FunctionNode::body)
853  .def_ro("ret_struct_info", &FunctionNode::ret_struct_info)
854  .def_ro("is_pure", &FunctionNode::is_pure);
855  }
856 
857  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
858  static constexpr const char* _type_key = "relax.expr.Function";
860 };
861 
862 class Function : public BaseFunc {
863  public:
885  TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
886  bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());
887 
892  TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
893  bool is_pure = true, DictAttrs attrs = DictAttrs(),
894  Span span = Span());
895 
898 };
899 
900 // TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and
901 // kPrimitive.
namespace attr {
// These keys are defined in a header, so they are C++17 inline variables: every
// translation unit that includes this file refers to the same single object instead of
// a private internal-linkage copy (plain namespace-scope `constexpr` pointers are const
// and therefore have internal linkage). They remain usable in constant expressions.

/*! \brief Mark the function as a primitive function. */
inline constexpr const char* kPrimitive = "Primitive";
/*! \brief Indicate the codegen that should be used for building this function. */
inline constexpr const char* kCodegen = "Codegen";
/*! \brief Treat the function as a composite operator. */
inline constexpr const char* kComposite = "Composite";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
inline constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
/*! \brief The required workspace for an external function. */
inline constexpr const char* kWorkspaceSize = "WorkspaceSize";

// Note: in the future, we prefer snake_case instead of CamelCase for attributes.
// Past ones will be kept for backwards compatibility.

/*!
 * \brief Override checking purity for this function and treat it as pure
 *        (is_pure must be set to true).
 */
inline constexpr const char* kForcePure = "relax.force_pure";

/*!
 * \brief The number of inputs of a function.
 *
 * If a function has the num_input attribute, the last
 * `func->params.size() - num_input` parameters are handled differently from the
 * leading inputs. NOTE(review): confirm their exact role (e.g. weights that are fixed
 * across calls) against the passes that consume this attribute.
 */
inline constexpr const char* kNumInput = "num_input";
}  // namespace attr
930 
932 class ExternFuncNode : public BaseFuncNode {
933  public:
936 
937  static void RegisterReflection() {
938  namespace refl = tvm::ffi::reflection;
939  refl::ObjectDef<ExternFuncNode>().def_ro("global_symbol", &ExternFuncNode::global_symbol);
940  }
941 
942  static constexpr const char* _type_key = "relax.expr.ExternFunc";
944 };
945 
946 class ExternFunc : public BaseFunc {
947  public:
948  TVM_DLL ExternFunc(String global_symbol, Span span = Span());
949  TVM_DLL ExternFunc(String global_symbol, StructInfo struct_info, Span span = Span());
950 
953 };
954 
/*!
 * \brief Get the shape of Expr.
 * \param expr The input expression.
 * \return An expression describing the shape of \p expr.
 * \note NOTE(review): this presumably requires \p expr to be normalized and to carry
 *       tensor struct info; confirm the exact preconditions against the implementation.
 */
TVM_DLL Expr GetShapeOf(const Expr& expr);
967 
968 } // namespace relax
969 } // namespace tvm
970 
971 /* \brief Allow relax.Var as key in STL tables
972  *
973  * For most Relax expressions, it would be ambiguous whether the
974  * expression should follow reference equality or structural equality.
975  * This is not the case for variables, which do not contain nested
976  * internal structure, and are frequently used as keys in lookup
977  * tables.
978  *
979  * Providing `std::hash` and `std::equal_to` specializations for
980  * `relax::Var` allows it to be used as a key in STL tables. For
981  * `relax::Expr`, the user must specify the type of equality used
982  * (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
983  * `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
984  */
985 template <>
986 struct std::hash<tvm::relax::Var> {
987  std::size_t operator()(const tvm::relax::Var& var) const {
988  return tvm::runtime::ObjectPtrHash()(var);
989  }
990 };
991 
992 template <>
993 struct std::equal_to<tvm::relax::Var> {
994  bool operator()(const tvm::relax::Var& var_a, const tvm::relax::Var& var_b) const {
995  return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
996  }
997 };
998 
999 #endif // TVM_RELAX_EXPR_H_
Managed reference to BaseAttrsNode.
Definition: attrs.h:134
Base node of all functions.
Definition: function.h:139
Managed reference to BaseFuncNode.
Definition: function.h:234
Managed reference to DictAttrsNode.
Definition: attrs.h:166
Reference to PrimExprNode.
Definition: expr.h:129
Base node of all non-primitive expressions.
Definition: expr.h:422
Optional< ObjectRef > struct_info_
Stores the result of structure information of the expression that encapsulate both static shape and r...
Definition: expr.h:429
Managed reference to RelaxExprNode.
Definition: expr.h:446
Definition: source_map.h:113
Definition: expr.h:681
Array< Binding > bindings
Definition: expr.h:683
Span span
Definition: expr.h:684
static void RegisterReflection()
Definition: expr.h:686
TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object)
Definition: expr.h:700
BindingBlock(Array< Binding > bindings, Span span=Span())
BindingBlockNode * CopyOnWrite()
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode)
The base class of a variable binding in Relax.
Definition: expr.h:582
TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object)
Var var
The return variable to be bound to.
Definition: expr.h:586
static void RegisterReflection()
Definition: expr.h:588
Span span
Definition: expr.h:584
Definition: expr.h:601
const BindingNode * operator->() const
Definition: expr.h:608
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding)
Binding(ObjectPtr< Object > n)
Definition: expr.h:606
const BindingNode * get() const
Definition: expr.h:609
Call corresponds to callable invocation. Corresponds to operation in computational graph terminology.
Definition: expr.h:144
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode)
static void RegisterReflection()
Definition: expr.h:168
tvm::Array< Expr > args
The arguments(inputs) of the call.
Definition: expr.h:155
static constexpr const char * _type_key
Definition: expr.h:177
Expr op
The operator(function) being invoked.
Definition: expr.h:152
Attrs attrs
The additional attributes.
Definition: expr.h:158
Array< StructInfo > sinfo_args
The structure info arguments of a CallNode. sinfo_args is designed to be non-empty only for intrinsic...
Definition: expr.h:166
Definition: expr.h:181
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
Call(Expr op, Array< Expr > args, Attrs attrs=Attrs(), Array< StructInfo > sinfo_args=Array< StructInfo >(), Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode)
Constant tensor.
Definition: expr.h:432
runtime::NDArray data
The data of the tensor.
Definition: expr.h:435
static void RegisterReflection()
Definition: expr.h:443
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, LeafExprNode)
bool is_scalar() const
Definition: expr.h:441
TensorType tensor_type() const
Definition: expr.h:452
TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode)
Constant(runtime::NDArray data, Optional< StructInfo > struct_info_annotation=std::nullopt, Span span=Span())
The constructor.
Represent a data type constant.
Definition: expr.h:550
DataType value
The data value.
Definition: expr.h:553
TVM_DECLARE_FINAL_OBJECT_INFO(DataTypeImmNode, LeafExprNode)
static void RegisterReflection()
Definition: expr.h:555
Managed reference to DataTypeImm.
Definition: expr.h:568
TVM_DEFINE_OBJECT_REF_METHODS(DataTypeImm, LeafExpr, DataTypeImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode)
DataTypeImm(DataType value, Span span=Span())
The constructor.
Definition: expr.h:708
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode)
static void RegisterReflection()
Definition: expr.h:710
Definition: expr.h:720
TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode)
DataflowBlock(Array< Binding > bindings, Span span=Span())
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:402
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode)
static void RegisterReflection()
Definition: expr.h:404
Definition: expr.h:414
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode)
DataflowVar(Id vid, Optional< StructInfo > struct_info_annotation, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode)
DataflowVar(String name_hint, Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:416
The extern function, which can represent packed function.
Definition: expr.h:932
static void RegisterReflection()
Definition: expr.h:937
String global_symbol
The name of global symbol.
Definition: expr.h:935
TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode)
Definition: expr.h:946
TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode)
ExternFunc(String global_symbol, StructInfo struct_info, Span span=Span())
ExternFunc(String global_symbol, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode)
A Relax function.
Definition: expr.h:837
static void RegisterReflection()
Definition: expr.h:848
Array< Var > params
The parameters to the function.
Definition: expr.h:840
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode)
SeqExpr body
The body of the function.
Definition: expr.h:842
StructInfo ret_struct_info
The struct info of the function's return value.
Definition: expr.h:844
bool is_pure
Whether the function is annotated as pure or not.
Definition: expr.h:846
Definition: expr.h:862
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode)
static Function CreateEmpty(Array< Var > params, StructInfo ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Mimics the constructor but without body Expr.
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode)
Function(Array< Var > params, Expr body, Optional< StructInfo > ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Construct a Relax Function.
The unique identifier of variables.
Definition: expr.h:49
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:64
static void RegisterReflection()
Definition: expr.h:58
TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object)
static constexpr const char * _type_key
Definition: expr.h:65
String name_hint
The name of the variable, this only acts as a hint to the user, and is not used for equality.
Definition: expr.h:56
Definition: expr.h:70
Id(String name_hint)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode)
Condition expression.
Definition: expr.h:779
static void RegisterReflection()
Definition: expr.h:788
SeqExpr true_branch
The expression evaluated when condition is true.
Definition: expr.h:784
Expr cond
The condition.
Definition: expr.h:782
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode)
SeqExpr false_branch
The expression evaluated when condition is false.
Definition: expr.h:786
Definition: expr.h:801
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode)
If(Expr cond, Expr true_branch, Expr false_branch, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode)
Base type of all (non-function) leaf Exprs.
Definition: expr.h:309
TVM_DECLARE_BASE_OBJECT_INFO(LeafExprNode, ExprNode)
Managed reference to LeafExprNode.
Definition: expr.h:320
TVM_DEFINE_OBJECT_REF_METHODS(LeafExpr, Expr, LeafExprNode)
Runtime-match the value to the struct info.
Definition: expr.h:620
Expr value
The input value to match cast.
Definition: expr.h:623
StructInfo struct_info
The struct info pattern to match to.
Definition: expr.h:625
TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode)
static void RegisterReflection()
Definition: expr.h:627
Managed reference to MatchCastNode.
Definition: expr.h:642
TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode)
MatchCast(Var var, Expr value, StructInfo struct_info, Span span=Span())
PrimValue.
Definition: expr.h:474
static void RegisterReflection()
Definition: expr.h:479
PrimExpr value
The prim expr representing the value.
Definition: expr.h:477
TVM_DECLARE_FINAL_OBJECT_INFO(PrimValueNode, LeafExprNode)
Managed reference to PrimValueNode.
Definition: expr.h:492
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode)
static PrimValue Int64(int64_t value, Span span=Span())
Create an int64 prim value.
PrimValue(PrimExpr value, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(PrimValue, LeafExpr, PrimValueNode)
A sequence of blocks followed by an expression.
Definition: expr.h:731
Expr body
Definition: expr.h:734
TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode)
static void RegisterReflection()
Definition: expr.h:736
Array< BindingBlock > blocks
Definition: expr.h:733
Definition: expr.h:748
SeqExpr(Array< BindingBlock > blocks, Expr body, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode)
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:327
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, LeafExprNode)
Array< PrimExpr > values
Definition: expr.h:330
static void RegisterReflection()
Definition: expr.h:332
Definition: expr.h:341
TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, LeafExpr, ShapeExprNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode)
ShapeExpr(Array< PrimExpr > values, Span span=Span())
Represent a string literal constant.
Definition: expr.h:516
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, LeafExprNode)
String value
The data value.
Definition: expr.h:519
static void RegisterReflection()
Definition: expr.h:521
Managed reference to StringImm.
Definition: expr.h:534
StringImm(String value, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(StringImm, LeafExpr, StringImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
Base type of all structure information.
Definition: expr.h:110
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:116
static constexpr const uint32_t _type_child_slots
Definition: expr.h:127
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:124
static void RegisterReflection()
Definition: expr.h:118
static constexpr const char * _type_key
Definition: expr.h:125
TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object)
Managed reference to StructInfoNode.
Definition: expr.h:135
TVM_DEFINE_OBJECT_REF_METHODS(StructInfo, ObjectRef, StructInfoNode)
Managed reference to TensorTypeNode.
Definition: type.h:98
Get index-th field out of a tuple.
Definition: expr.h:264
static void RegisterReflection()
Definition: expr.h:271
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode)
int index
The index of the tuple field to get.
Definition: expr.h:269
Expr tuple
The tuple Expression.
Definition: expr.h:267
Definition: expr.h:282
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode)
TupleGetItem(Expr tuple, int index, Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode)
Tuple container.
Definition: expr.h:210
static void RegisterReflection()
Definition: expr.h:215
tvm::Array< Expr > fields
the fields of the tuple
Definition: expr.h:213
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode)
static constexpr const char * _type_key
Definition: expr.h:220
Definition: expr.h:224
Tuple(tvm::Array< Expr > fields, Span span=Span())
The constructor.
Tuple(tvm::Array< RelaxExpr > fields, Span span=Span())
Utility constructor to handle conversion to relax::Expr.
Definition: expr.h:248
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode)
Definition: expr.h:650
static void RegisterReflection()
Definition: expr.h:655
Expr value
The binding value.
Definition: expr.h:653
TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode)
bool SEqual(const VarBindingNode *other, ffi::TypedFunction< bool(AnyView, AnyView, bool, AnyView)> equal) const
uint64_t SHash(uint64_t init_hash, ffi::TypedFunction< uint64_t(AnyView, uint64_t, bool)> hash) const
Definition: expr.h:674
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode)
VarBinding(Var var, Expr value, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode)
The variable class for all Relax bindings.
Definition: expr.h:349
uint64_t SHash(uint64_t init_hash, ffi::TypedFunction< uint64_t(AnyView, uint64_t, bool)> hash) const
Definition: expr.h:373
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode)
Id vid
The identifier of the variable, which is used for comparing stable equality across transformations.
Definition: expr.h:353
static void RegisterReflection()
Definition: expr.h:358
const String & name_hint() const
Definition: expr.h:356
bool SEqual(const VarNode *other, ffi::TypedFunction< bool(AnyView, AnyView, bool, AnyView)> equal) const
Definition: expr.h:367
Definition: expr.h:387
TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode)
Var(String name_hint, Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:389
Var(Id vid, Optional< StructInfo > struct_info_annotation, Span span=Span())
VarNode * CopyOnWrite()
Runtime primitive data type.
Definition: data_type.h:47
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:53
Base expr nodes in TVM.
Function nodes.
Definition: repr_printer.h:91
constexpr const char * kForcePure
Override checking purity for this function and treat it as pure (is_pure must be set to true).
Definition: expr.h:921
constexpr const char * kWorkspaceSize
The required workspace for an external function.
Definition: expr.h:915
constexpr const char * kNumInput
The number of inputs of a function. If a function has the num_input attribute, the last func->params....
Definition: expr.h:928
constexpr const char * kComposite
Treat the function as a composite operator.
Definition: expr.h:911
constexpr const char * kCodegen
Indicate the codegen that should be used for building this function. When this is unset or set to "de...
Definition: expr.h:909
constexpr const char * kPrimitive
Mark the function as a primitive function.
Definition: expr.h:904
constexpr const char * kPartitionedFromPattern
Indicate the function was created by the Pattern Partitioning Pass.
Definition: expr.h:913
If WithFields(If if_expr, Optional< Expr > opt_cond=Optional< Expr >(), Optional< Expr > opt_true_branch=Optional< Expr >(), Optional< Expr > opt_false_branch=Optional< Expr >(), Optional< Span > opt_span=Optional< Span >())
Returns if_expr with the given properties. A null property denotes 'no change'. Returns if_expr if al...
Call WithFields(Call call, Optional< Expr > opt_op=Optional< Expr >(), Optional< Array< Expr >> opt_args=Optional< Array< Expr >>(), Optional< Attrs > opt_attrs=Optional< Attrs >(), Optional< Array< StructInfo >> opt_sinfo_args=Optional< Array< StructInfo >>(), Optional< Span > opt_span=Optional< Span >())
Returns call with the given properties. A null property denotes 'no change'. Returns call if all prop...
Expr GetShapeOf(const Expr &expr)
Get the shape of Expr.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Definitions and helper macros for IR/AST nodes.
A managed object in the TVM runtime.
Relax Types.
A map from source names to source code.
TIR expressions.
Common operators defined for Expr.