tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_RELAX_EXPR_H_
20 #define TVM_RELAX_EXPR_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/container/map.h>
24 #include <tvm/ffi/reflection/registry.h>
25 #include <tvm/ir/expr.h>
26 #include <tvm/ir/function.h>
27 #include <tvm/ir/source_map.h>
28 #include <tvm/node/node.h>
29 #include <tvm/relax/type.h>
30 #include <tvm/runtime/object.h>
31 #include <tvm/tir/expr.h>
32 #include <tvm/tir/op.h>
33 
34 #include <functional>
35 
36 namespace tvm {
37 namespace relax {
38 
39 using Expr = RelaxExpr;
49 class IdNode : public Object {
50  public:
56  ffi::String name_hint;
57 
58  static void RegisterReflection() {
59  namespace refl = tvm::ffi::reflection;
60  refl::ObjectDef<IdNode>().def_ro("name_hint", &IdNode::name_hint,
61  refl::AttachFieldFlag::SEqHashIgnore());
62  }
63 
64  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar;
66 };
67 
68 class Id : public ObjectRef {
69  public:
74  TVM_DLL explicit Id(ffi::String name_hint);
75 
77 };
78 
108 class StructInfoNode : public Object {
109  public:
114  mutable Span span;
115 
116  static void RegisterReflection() {
117  namespace refl = tvm::ffi::reflection;
118  refl::ObjectDef<StructInfoNode>().def_ro("span", &StructInfoNode::span,
119  refl::AttachFieldFlag::SEqHashIgnore());
120  }
121 
122  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
123 
124  static constexpr const uint32_t _type_child_slots = 7;
125  TVM_FFI_DECLARE_OBJECT_INFO("ir.StructInfo", StructInfoNode, Object);
126 };
127 
132 class StructInfo : public ObjectRef {
133  public:
135 };
136 
141 class CallNode : public ExprNode {
142  public:
150 
152  tvm::ffi::Array<Expr> args;
153 
156 
163  ffi::Array<StructInfo> sinfo_args;
164 
165  static void RegisterReflection() {
166  namespace refl = tvm::ffi::reflection;
167  refl::ObjectDef<CallNode>()
168  .def_ro("op", &CallNode::op)
169  .def_ro("args", &CallNode::args)
170  .def_ro("attrs", &CallNode::attrs)
171  .def_ro("sinfo_args", &CallNode::sinfo_args);
172  }
174 };
175 
176 class Call : public Expr {
177  public:
186  TVM_DLL Call(Expr op, ffi::Array<Expr> args, Attrs attrs = Attrs(),
187  ffi::Array<StructInfo> sinfo_args = ffi::Array<StructInfo>(), Span span = Span());
188 
191 };
192 
199  Call call, ffi::Optional<Expr> opt_op = ffi::Optional<Expr>(),
200  ffi::Optional<ffi::Array<Expr>> opt_args = ffi::Optional<ffi::Array<Expr>>(),
201  ffi::Optional<Attrs> opt_attrs = ffi::Optional<Attrs>(),
202  ffi::Optional<ffi::Array<StructInfo>> opt_sinfo_args = ffi::Optional<ffi::Array<StructInfo>>(),
203  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
204 
206 class TupleNode : public ExprNode {
207  public:
209  tvm::ffi::Array<Expr> fields;
210 
211  static void RegisterReflection() {
212  namespace refl = tvm::ffi::reflection;
213  refl::ObjectDef<TupleNode>().def_ro("fields", &TupleNode::fields);
214  }
216 };
217 
218 class Tuple : public Expr {
219  public:
225  TVM_DLL explicit Tuple(tvm::ffi::Array<Expr> fields, Span span = Span());
226 
241  template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
242  TVM_DLL explicit Tuple(tvm::ffi::Array<RelaxExpr> fields, Span span = Span())
243  : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
244 
247 };
248 
255  ffi::Optional<ffi::Array<Expr>> opt_fields = ffi::Optional<ffi::Array<Expr>>(),
256  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
257 
259 class TupleGetItemNode : public ExprNode {
260  public:
264  int index;
265 
266  static void RegisterReflection() {
267  namespace refl = tvm::ffi::reflection;
268  refl::ObjectDef<TupleGetItemNode>()
269  .def_ro("tuple_value", &TupleGetItemNode::tuple)
270  .def_ro("index", &TupleGetItemNode::index);
271  }
273 };
274 
275 class TupleGetItem : public Expr {
276  public:
283  TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
284 
287 };
288 
295  ffi::Optional<Expr> opt_tuple = ffi::Optional<Expr>(),
296  ffi::Optional<Integer> opt_index = ffi::Optional<Integer>(),
297  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
298 
303 class LeafExprNode : public ExprNode {
304  public:
305  static constexpr const uint32_t _type_child_slots = 7;
307 };
308 
313 class LeafExpr : public Expr {
314  public:
316 };
317 
320 class ShapeExprNode : public LeafExprNode {
321  public:
323  ffi::Array<PrimExpr> values;
324 
325  static void RegisterReflection() {
326  namespace refl = tvm::ffi::reflection;
327  refl::ObjectDef<ShapeExprNode>().def_ro("values", &ShapeExprNode::values);
328  }
330 };
331 
332 class ShapeExpr : public LeafExpr {
333  public:
334  TVM_DLL explicit ShapeExpr(ffi::Array<PrimExpr> values, Span span = Span());
337 };
338 
340 class VarNode : public LeafExprNode {
341  public:
345 
347  const ffi::String& name_hint() const { return vid->name_hint; }
348 
349  static void RegisterReflection() {
350  namespace refl = tvm::ffi::reflection;
351  refl::ObjectDef<VarNode>().def_ro("vid", &VarNode::vid);
352  // customize structural equal and hash to include struct_info_
353  refl::TypeAttrDef<VarNode>()
354  .def("__s_equal__", &VarNode::SEqual)
355  .def("__s_hash__", &VarNode::SHash);
356  }
357 
358  bool SEqual(const VarNode* other,
359  ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const {
360  return equal(vid, other->vid, false, "vid") &&
361  equal(struct_info_, other->struct_info_, false, "struct_info_");
362  }
363 
364  uint64_t SHash(uint64_t init_hash,
365  ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const {
366  uint64_t hash_value = init_hash;
367  hash_value = hash(vid, hash_value, false);
368  hash_value = hash(struct_info_, hash_value, false);
369  return hash_value;
370  }
371 
372  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
373  static constexpr const uint32_t _type_child_slots = 1;
375 };
376 
377 class Var : public LeafExpr {
378  public:
379  TVM_DLL explicit Var(ffi::String name_hint, ffi::Optional<StructInfo> struct_info_annotation,
380  Span span = Span())
381  : Var(Id(name_hint), struct_info_annotation, span) {}
382 
383  TVM_DLL explicit Var(Id vid, ffi::Optional<StructInfo> struct_info_annotation,
384  Span span = Span());
386 
388 };
389 
393 class DataflowVarNode : public VarNode {
394  public:
395  static void RegisterReflection() {
396  namespace refl = tvm::ffi::reflection;
397  refl::ObjectDef<DataflowVarNode>();
398  }
399 
400  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
402 };
403 
404 class DataflowVar : public Var {
405  public:
406  TVM_DLL explicit DataflowVar(ffi::String name_hint,
407  ffi::Optional<StructInfo> struct_info_annotation, Span span = Span())
408  : DataflowVar(Id(name_hint), struct_info_annotation, span) {}
409 
410  TVM_DLL explicit DataflowVar(Id vid, ffi::Optional<StructInfo> struct_info_annotation,
411  Span span = Span());
412 
415 };
416 
422 class ConstantNode : public LeafExprNode {
423  public:
426 
429 
431  bool is_scalar() const { return data->ndim == 0; }
432 
433  static void RegisterReflection() {
434  namespace refl = tvm::ffi::reflection;
435  refl::ObjectDef<ConstantNode>().def_ro("data", &ConstantNode::data);
436  }
438 };
439 
440 class Constant : public LeafExpr {
441  public:
449  TVM_DLL explicit Constant(runtime::Tensor data,
450  ffi::Optional<StructInfo> struct_info_annotation = std::nullopt,
451  Span span = Span());
452 
455 };
456 
462 class PrimValueNode : public LeafExprNode {
463  public:
466 
467  static void RegisterReflection() {
468  namespace refl = tvm::ffi::reflection;
469  refl::ObjectDef<PrimValueNode>().def_ro("value", &PrimValueNode::value);
470  }
472 };
473 
478 class PrimValue : public LeafExpr {
479  public:
485  TVM_DLL explicit PrimValue(PrimExpr value, Span span = Span());
486 
493  TVM_DLL static PrimValue Int64(int64_t value, Span span = Span());
494 
497 };
498 
502 class StringImmNode : public LeafExprNode {
503  public:
505  ffi::String value;
506 
507  static void RegisterReflection() {
508  namespace refl = tvm::ffi::reflection;
509  refl::ObjectDef<StringImmNode>().def_ro("value", &StringImmNode::value);
510  }
512 };
513 
518 class StringImm : public LeafExpr {
519  public:
525  TVM_DLL explicit StringImm(ffi::String value, Span span = Span());
526 
529 };
530 
535  public:
538 
539  static void RegisterReflection() {
540  namespace refl = tvm::ffi::reflection;
541  refl::ObjectDef<DataTypeImmNode>().def_ro("value", &DataTypeImmNode::value);
542  }
544 };
545 
550 class DataTypeImm : public LeafExpr {
551  public:
557  TVM_DLL explicit DataTypeImm(DataType value, Span span = Span());
558 
561 };
562 
564 class BindingNode : public Object {
565  public:
566  mutable Span span;
569 
570  static void RegisterReflection() {
571  namespace refl = tvm::ffi::reflection;
572  refl::ObjectDef<BindingNode>()
573  .def_ro("span", &BindingNode::span, refl::AttachFieldFlag::SEqHashIgnore())
574  .def_ro("var", &BindingNode::var, refl::AttachFieldFlag::SEqHashDef());
575  }
576 
577  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
578 
579  TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Binding", BindingNode, Object);
580 };
581 
582 class Binding : public ObjectRef {
583  protected:
584  Binding() = default;
585 
586  public:
587  explicit Binding(ObjectPtr<BindingNode> n) : ObjectRef(n) {}
588  explicit Binding(ffi::UnsafeInit tag) : ObjectRef(tag) {}
590  const BindingNode* operator->() const { return static_cast<const BindingNode*>(data_.get()); }
591  const BindingNode* get() const { return operator->(); }
593 };
594 
602 class MatchCastNode : public BindingNode {
603  public:
608 
609  static void RegisterReflection() {
610  namespace refl = tvm::ffi::reflection;
611  refl::ObjectDef<MatchCastNode>()
612  .def_ro("value", &MatchCastNode::value)
613  .def_ro("struct_info", &MatchCastNode::struct_info, refl::AttachFieldFlag::SEqHashDef());
614  }
616 };
617 
622 class MatchCast : public Binding {
623  public:
624  TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span());
625 
628 };
629 
630 class VarBindingNode : public BindingNode {
631  public:
634 
635  static void RegisterReflection() {
636  namespace refl = tvm::ffi::reflection;
637  refl::ObjectDef<VarBindingNode>().def_ro("value", &VarBindingNode::value);
638  // customize the SEqual and SHash methods for better error messages
639  refl::TypeAttrDef<VarBindingNode>()
640  .def("__s_equal__", &VarBindingNode::SEqual)
641  .def("__s_hash__", &VarBindingNode::SHash);
642  }
643 
644  bool SEqual(const VarBindingNode* other,
645  ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const;
646  uint64_t SHash(uint64_t init_hash,
647  ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const;
649 };
650 
651 class VarBinding : public Binding {
652  public:
653  TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
656 };
657 
658 class BindingBlockNode : public Object {  // Node holding a sequence of bindings plus a source span.
659  public:
660  ffi::Array<Binding> bindings;  // The bindings contained in this block.
661  mutable Span span;  // Source location; excluded from structural equal/hash (see SEqHashIgnore below).
662 
663  static void RegisterReflection() {  // Exposes `bindings` and `span` read-only through FFI reflection.
664  namespace refl = tvm::ffi::reflection;
665  refl::ObjectDef<BindingBlockNode>()
666  .def_ro("bindings", &BindingBlockNode::bindings)
667  .def_ro("span", &BindingBlockNode::span, refl::AttachFieldFlag::SEqHashIgnore(),
668  refl::DefaultValue(Span()));  // `span` falls back to an empty Span() when not supplied.
669  }
670 
671  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
672  TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.BindingBlock", BindingBlockNode, Object);  // Non-final: DataflowBlockNode derives from it.
673 };
674 
675 class BindingBlock : public ObjectRef {
676  public:
677  TVM_DLL explicit BindingBlock(ffi::Array<Binding> bindings, Span span = Span());
679 
681 };
682 
684  public:
685  static void RegisterReflection() {
686  namespace refl = tvm::ffi::reflection;
687  refl::ObjectDef<DataflowBlockNode>();
688  }
691 };
692 
693 class DataflowBlock : public BindingBlock {
694  public:
695  TVM_DLL explicit DataflowBlock(ffi::Array<Binding> bindings, Span span = Span());
698 };
699 
704 class SeqExprNode : public ExprNode {
705  public:
706  ffi::Array<BindingBlock> blocks;
708 
709  static void RegisterReflection() {
710  namespace refl = tvm::ffi::reflection;
711  refl::ObjectDef<SeqExprNode>()
712  .def_ro("blocks", &SeqExprNode::blocks)
713  .def_ro("body", &SeqExprNode::body);
714  }
716 };
717 
718 class SeqExpr : public Expr {
719  public:
720  /* \brief Implicit conversion constructor
721  *
722  * Relax nodes that introduce a new scope (e.g. `relax::Function`)
723  * are required to be held as SeqExpr. This implicit conversion
724  * allows callsites to assign a `relax::Expr` to these member variables
725  * when the C++ compile-time type is a `relax::Expr`. For example,
726  * a transform may use `func.CopyOnWrite()->body = expr;`.
727  *
728  * If the expression is already a `relax::SeqExpr`, the same
729  * underlying `relax::SeqExprNode` is used, and no copies are made.
730  */
731  TVM_DLL SeqExpr(Expr body); // NOLINT(*)
732 
733  TVM_DLL explicit SeqExpr(ffi::Array<BindingBlock> blocks, Expr body, Span span = Span());
736 };
737 
749 class IfNode : public ExprNode {
750  public:
757 
758  static void RegisterReflection() {
759  namespace refl = tvm::ffi::reflection;
760  refl::ObjectDef<IfNode>()
761  .def_ro("cond", &IfNode::cond)
762  .def_ro("true_branch", &IfNode::true_branch)
763  .def_ro("false_branch", &IfNode::false_branch);
764  }
765 
766  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
768 };
769 
770 class If : public Expr {
771  public:
789  TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
790 
793 };
794 
800 If WithFields(If if_expr, ffi::Optional<Expr> opt_cond = ffi::Optional<Expr>(),
801  ffi::Optional<Expr> opt_true_branch = ffi::Optional<Expr>(),
802  ffi::Optional<Expr> opt_false_branch = ffi::Optional<Expr>(),
803  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
804 
806 class FunctionNode : public BaseFuncNode {
807  public:
809  ffi::Array<Var> params;
815  bool is_pure;
816 
817  static void RegisterReflection() {
818  namespace refl = tvm::ffi::reflection;
819  refl::ObjectDef<FunctionNode>()
820  .def_ro("params", &FunctionNode::params, refl::AttachFieldFlag::SEqHashDef())
821  .def_ro("body", &FunctionNode::body)
822  .def_ro("ret_struct_info", &FunctionNode::ret_struct_info)
823  .def_ro("is_pure", &FunctionNode::is_pure);
824  }
825 
826  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
828 };
829 
830 class Function : public BaseFunc {
831  public:
853  TVM_DLL explicit Function(ffi::Array<Var> params, Expr body,
854  ffi::Optional<StructInfo> ret_struct_info, bool is_pure = true,
855  DictAttrs attrs = DictAttrs(), Span span = Span());
856 
861  TVM_DLL static Function CreateEmpty(ffi::Array<Var> params, StructInfo ret_struct_info,
862  bool is_pure = true, DictAttrs attrs = DictAttrs(),
863  Span span = Span());
864 
867 };
868 
869 // TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and
870 // kPrimitive.
871 namespace attr {
873 constexpr const char* kPrimitive = "Primitive";
878 constexpr const char* kCodegen = "Codegen";
880 constexpr const char* kComposite = "Composite";
882 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
884 constexpr const char* kWorkspaceSize = "WorkspaceSize";
885 
886 // Note: in the future, we prefer snake_case instead of CamelCase for attributes.
887 // Past ones will be kept for backwards compatibility.
890 constexpr const char* kForcePure = "relax.force_pure";
891 
897 constexpr const char* kNumInput = "num_input";
898 } // namespace attr
899 
901 class ExternFuncNode : public BaseFuncNode {
902  public:
904  ffi::String global_symbol;
905 
906  static void RegisterReflection() {
907  namespace refl = tvm::ffi::reflection;
908  refl::ObjectDef<ExternFuncNode>().def_ro("global_symbol", &ExternFuncNode::global_symbol);
909  }
911 };
912 
913 class ExternFunc : public BaseFunc {
914  public:
915  TVM_DLL ExternFunc(ffi::String global_symbol, Span span = Span());
916  TVM_DLL ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span = Span());
917 
920 };
921 
933 TVM_DLL Expr GetShapeOf(const Expr& expr);
934 
935 } // namespace relax
936 } // namespace tvm
937 
938 /* \brief Allow relax.Var as key in STL tables
939  *
940  * For most Relax expressions, it would be ambiguous whether the
941  * expression should follow reference equality or structural equality.
942  * This is not the case for variables, which do not contain nested
943  * internal structure, and are frequently used as keys in lookup
944  * tables.
945  *
946  * Providing `std::hash` and `std::equal_to` specializations for
947  * `relax::Var` allows it to be used as a key in STL tables. For
948  * `relax::Expr`, the user must specify the type of equality used
949  * (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
950  * `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
951  */
952 template <>
953 struct std::hash<tvm::relax::Var> {  // Hashes a Var by object identity (ObjectPtrHash), not by structure.
954  std::size_t operator()(const tvm::relax::Var& var) const {
955  return tvm::runtime::ObjectPtrHash()(var);
956  }
957 };
958 
959 template <>
960 struct std::equal_to<tvm::relax::Var> {  // Reference equality (ObjectPtrEqual): equal only if both refer to the same object.
961  bool operator()(const tvm::relax::Var& var_a, const tvm::relax::Var& var_b) const {
962  return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
963  }
964 };
965 
966 #endif // TVM_RELAX_EXPR_H_
Managed reference to BaseAttrsNode.
Definition: attrs.h:131
Base node of all functions.
Definition: function.h:139
Managed reference to BaseFuncNode.
Definition: function.h:233
Managed reference to DictAttrsNode.
Definition: attrs.h:162
Reference to PrimExprNode.
Definition: expr.h:124
Base node of all non-primitive expressions.
Definition: expr.h:416
ffi::Optional< ObjectRef > struct_info_
Stores the result of structure information of the expression that encapsulate both static shape and r...
Definition: expr.h:423
Managed reference to RelaxExprNode.
Definition: expr.h:439
Definition: source_map.h:111
Definition: expr.h:658
ffi::Array< Binding > bindings
Definition: expr.h:660
Span span
Definition: expr.h:661
static void RegisterReflection()
Definition: expr.h:663
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.BindingBlock", BindingBlockNode, Object)
Definition: expr.h:675
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BindingBlock, ObjectRef, BindingBlockNode)
BindingBlockNode * CopyOnWrite()
BindingBlock(ffi::Array< Binding > bindings, Span span=Span())
The base class of a variable binding in Relax.
Definition: expr.h:564
Var var
The return variable to bound to.
Definition: expr.h:568
static void RegisterReflection()
Definition: expr.h:570
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Binding", BindingNode, Object)
Span span
Definition: expr.h:566
Definition: expr.h:582
const BindingNode * operator->() const
Definition: expr.h:590
Binding(ObjectPtr< BindingNode > n)
Definition: expr.h:587
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding)
const BindingNode * get() const
Definition: expr.h:591
Binding(ffi::UnsafeInit tag)
Definition: expr.h:588
Call corresponds to callable invocation. Corresponds to operation in computational graph terminology.
Definition: expr.h:141
ffi::Array< StructInfo > sinfo_args
The structure info arguments of a CallNode. sinfo_args is designed to be non-empty only for intrinsic...
Definition: expr.h:163
static void RegisterReflection()
Definition: expr.h:165
tvm::ffi::Array< Expr > args
The arguments (inputs) of the call.
Definition: expr.h:152
Expr op
The operator(function) being invoked.
Definition: expr.h:149
Attrs attrs
The additional attributes.
Definition: expr.h:155
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Call", CallNode, ExprNode)
Definition: expr.h:176
Call(Expr op, ffi::Array< Expr > args, Attrs attrs=Attrs(), ffi::Array< StructInfo > sinfo_args=ffi::Array< StructInfo >(), Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, Expr, CallNode)
Constant tensor.
Definition: expr.h:422
static void RegisterReflection()
Definition: expr.h:433
bool is_scalar() const
Definition: expr.h:431
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Constant", ConstantNode, LeafExprNode)
runtime::Tensor data
The data of the tensor.
Definition: expr.h:425
TensorType tensor_type() const
Definition: expr.h:440
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Constant, LeafExpr, ConstantNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode)
Constant(runtime::Tensor data, ffi::Optional< StructInfo > struct_info_annotation=std::nullopt, Span span=Span())
The constructor.
Represent a data type constant.
Definition: expr.h:534
DataType value
The data value.
Definition: expr.h:537
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataTypeImm", DataTypeImmNode, LeafExprNode)
static void RegisterReflection()
Definition: expr.h:539
Managed reference to DataTypeImm.
Definition: expr.h:550
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypeImm, LeafExpr, DataTypeImmNode)
DataTypeImm(DataType value, Span span=Span())
The constructor.
Definition: expr.h:683
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataflowBlock", DataflowBlockNode, BindingBlockNode)
static void RegisterReflection()
Definition: expr.h:685
Definition: expr.h:693
DataflowBlock(ffi::Array< Binding > bindings, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowBlock, BindingBlock, DataflowBlockNode)
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:393
static void RegisterReflection()
Definition: expr.h:395
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataflowVar", DataflowVarNode, VarNode)
Definition: expr.h:404
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode)
DataflowVar(Id vid, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowVar, Var, DataflowVarNode)
DataflowVar(ffi::String name_hint, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:406
The extern function, which can represent packed function.
Definition: expr.h:901
static void RegisterReflection()
Definition: expr.h:906
ffi::String global_symbol
The name of global symbol.
Definition: expr.h:904
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ExternFunc", ExternFuncNode, BaseFuncNode)
Definition: expr.h:913
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternFunc, BaseFunc, ExternFuncNode)
ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode)
ExternFunc(ffi::String global_symbol, Span span=Span())
A Relax function.
Definition: expr.h:806
static void RegisterReflection()
Definition: expr.h:817
SeqExpr body
The body of the function.
Definition: expr.h:811
ffi::Array< Var > params
The parameters to the function.
Definition: expr.h:809
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Function", FunctionNode, BaseFuncNode)
StructInfo ret_struct_info
The struct info of the function's return value.
Definition: expr.h:813
bool is_pure
Whether the function is annotated as pure or not.
Definition: expr.h:815
Definition: expr.h:830
static Function CreateEmpty(ffi::Array< Var > params, StructInfo ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Mimics the constructor but without body Expr.
Function(ffi::Array< Var > params, Expr body, ffi::Optional< StructInfo > ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Construct a Relax Function.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, BaseFunc, FunctionNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode)
The unique identifier of variables.
Definition: expr.h:49
ffi::String name_hint
The name of the variable, this only acts as a hint to the user, and is not used for equality.
Definition: expr.h:56
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:64
static void RegisterReflection()
Definition: expr.h:58
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.Id", IdNode, Object)
Definition: expr.h:68
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Id, ObjectRef, IdNode)
Id(ffi::String name_hint)
The constructor.
Condition expression.
Definition: expr.h:749
static void RegisterReflection()
Definition: expr.h:758
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.If", IfNode, ExprNode)
SeqExpr true_branch
The expression evaluated when condition is true.
Definition: expr.h:754
Expr cond
The condition.
Definition: expr.h:752
SeqExpr false_branch
The expression evaluated when condition is false.
Definition: expr.h:756
Definition: expr.h:770
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(If, Expr, IfNode)
If(Expr cond, Expr true_branch, Expr false_branch, Span span=Span())
The constructor.
Base type of all (non-function) leaf Exprs.
Definition: expr.h:303
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.LeafExpr", LeafExprNode, ExprNode)
Managed reference to LeafExprNode.
Definition: expr.h:313
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LeafExpr, Expr, LeafExprNode)
Runtime-match the value to the struct info.
Definition: expr.h:602
Expr value
The input value to match cast.
Definition: expr.h:605
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.MatchCast", MatchCastNode, BindingNode)
StructInfo struct_info
The struct info pattern to match to.
Definition: expr.h:607
static void RegisterReflection()
Definition: expr.h:609
Managed reference to MatchCastNode.
Definition: expr.h:622
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchCast, Binding, MatchCastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode)
MatchCast(Var var, Expr value, StructInfo struct_info, Span span=Span())
PrimValue.
Definition: expr.h:462
static void RegisterReflection()
Definition: expr.h:467
PrimExpr value
The prim expr representing the value.
Definition: expr.h:465
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.PrimValue", PrimValueNode, LeafExprNode)
Managed reference to PrimValueNode.
Definition: expr.h:478
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode)
static PrimValue Int64(int64_t value, Span span=Span())
Create an int64 prim value.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimValue, LeafExpr, PrimValueNode)
PrimValue(PrimExpr value, Span span=Span())
The constructor.
A sequence of blocks followed by an expression.
Definition: expr.h:704
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.SeqExpr", SeqExprNode, ExprNode)
Expr body
Definition: expr.h:707
static void RegisterReflection()
Definition: expr.h:709
ffi::Array< BindingBlock > blocks
Definition: expr.h:706
Definition: expr.h:718
SeqExpr(ffi::Array< BindingBlock > blocks, Expr body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SeqExpr, Expr, SeqExprNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode)
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:320
ffi::Array< PrimExpr > values
Definition: expr.h:323
static void RegisterReflection()
Definition: expr.h:325
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ShapeExpr", ShapeExprNode, LeafExprNode)
Definition: expr.h:332
ShapeExpr(ffi::Array< PrimExpr > values, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ShapeExpr, LeafExpr, ShapeExprNode)
Represent a string literal constant.
Definition: expr.h:502
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.StringImm", StringImmNode, LeafExprNode)
ffi::String value
The data value.
Definition: expr.h:505
static void RegisterReflection()
Definition: expr.h:507
Managed reference to StringImm.
Definition: expr.h:518
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, LeafExpr, StringImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
StringImm(ffi::String value, Span span=Span())
The constructor.
Base type of all structure information.
Definition: expr.h:108
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:114
static constexpr const uint32_t _type_child_slots
Definition: expr.h:124
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:122
static void RegisterReflection()
Definition: expr.h:116
TVM_FFI_DECLARE_OBJECT_INFO("ir.StructInfo", StructInfoNode, Object)
Managed reference to StructInfoNode.
Definition: expr.h:132
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StructInfo, ObjectRef, StructInfoNode)
Managed reference to TensorTypeNode.
Definition: type.h:94
Get index-th field out of a tuple.
Definition: expr.h:259
static void RegisterReflection()
Definition: expr.h:266
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.TupleGetItem", TupleGetItemNode, ExprNode)
int index
The index of the tuple field to get.
Definition: expr.h:264
Expr tuple
The tuple Expression.
Definition: expr.h:262
Definition: expr.h:275
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode)
TupleGetItem(Expr tuple, int index, Span span=Span())
The constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TupleGetItem, Expr, TupleGetItemNode)
Tuple container.
Definition: expr.h:206
static void RegisterReflection()
Definition: expr.h:211
tvm::ffi::Array< Expr > fields
the fields of the tuple
Definition: expr.h:209
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Tuple", TupleNode, ExprNode)
Definition: expr.h:218
Tuple(tvm::ffi::Array< Expr > fields, Span span=Span())
The constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tuple, Expr, TupleNode)
Tuple(tvm::ffi::Array< RelaxExpr > fields, Span span=Span())
Utility constructor to handle conversion to relax::Expr.
Definition: expr.h:242
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode)
Definition: expr.h:630
static void RegisterReflection()
Definition: expr.h:635
Expr value
The binding value.
Definition: expr.h:633
bool SEqual(const VarBindingNode *other, ffi::TypedFunction< bool(AnyView, AnyView, bool, AnyView)> equal) const
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.VarBinding", VarBindingNode, BindingNode)
uint64_t SHash(uint64_t init_hash, ffi::TypedFunction< uint64_t(AnyView, uint64_t, bool)> hash) const
Definition: expr.h:651
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VarBinding, Binding, VarBindingNode)
VarBinding(Var var, Expr value, Span span=Span())
The variable class for all Relax bindings.
Definition: expr.h:340
uint64_t SHash(uint64_t init_hash, ffi::TypedFunction< uint64_t(AnyView, uint64_t, bool)> hash) const
Definition: expr.h:364
const ffi::String & name_hint() const
Definition: expr.h:347
Id vid
The identifier of the variable, which is used for comparing stable equality across transformations.
Definition: expr.h:344
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Var", VarNode, LeafExprNode)
static void RegisterReflection()
Definition: expr.h:349
bool SEqual(const VarNode *other, ffi::TypedFunction< bool(AnyView, AnyView, bool, AnyView)> equal) const
Definition: expr.h:358
Definition: expr.h:377
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Var, LeafExpr, VarNode)
Var(Id vid, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
Var(ffi::String name_hint, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:379
VarNode * CopyOnWrite()
Runtime primitive data type.
Definition: data_type.h:47
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:53
Base expr nodes in TVM.
Function nodes.
Definition: repr_printer.h:91
constexpr const char * kForcePure
Override checking purity for this function and treat as pure (is_pure must be set to true)
Definition: expr.h:890
constexpr const char * kWorkspaceSize
The required workspace for an external function.
Definition: expr.h:884
constexpr const char * kNumInput
The number of inputs of a function. If a function has the num_input attribute, the last func->params....
Definition: expr.h:897
constexpr const char * kComposite
Treat the function as a composite operator.
Definition: expr.h:880
constexpr const char * kCodegen
Indicate the codegen that should be used for building this function. When this is unset or set to "de...
Definition: expr.h:878
constexpr const char * kPrimitive
Mark the function as a primitive function.
Definition: expr.h:873
constexpr const char * kPartitionedFromPattern
Indicate the function was created by the Pattern Partitioning Pass.
Definition: expr.h:882
Call WithFields(Call call, ffi::Optional< Expr > opt_op=ffi::Optional< Expr >(), ffi::Optional< ffi::Array< Expr >> opt_args=ffi::Optional< ffi::Array< Expr >>(), ffi::Optional< Attrs > opt_attrs=ffi::Optional< Attrs >(), ffi::Optional< ffi::Array< StructInfo >> opt_sinfo_args=ffi::Optional< ffi::Array< StructInfo >>(), ffi::Optional< Span > opt_span=ffi::Optional< Span >())
Returns call with the given properties. A null property denotes 'no change'. Returns call if all prop...
If WithFields(If if_expr, ffi::Optional< Expr > opt_cond=ffi::Optional< Expr >(), ffi::Optional< Expr > opt_true_branch=ffi::Optional< Expr >(), ffi::Optional< Expr > opt_false_branch=ffi::Optional< Expr >(), ffi::Optional< Span > opt_span=ffi::Optional< Span >())
Returns if_expr with the given properties. A null property denotes 'no change'. Returns if_expr if al...
Expr GetShapeOf(const Expr &expr)
Get the shape of Expr.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Definitions and helper macros for IR/AST nodes.
A managed object in the TVM runtime.
Relax Types.
A map from source names to source code.
TIR expressions.
Common operators defined for Expr.