tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_RELAX_EXPR_H_
20 #define TVM_RELAX_EXPR_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/container/map.h>
24 #include <tvm/ffi/reflection/registry.h>
25 #include <tvm/ir/expr.h>
26 #include <tvm/ir/function.h>
27 #include <tvm/ir/source_map.h>
28 #include <tvm/node/node.h>
29 #include <tvm/relax/type.h>
30 #include <tvm/runtime/object.h>
31 #include <tvm/tir/expr.h>
32 #include <tvm/tir/op.h>
33 
34 #include <functional>
35 
36 namespace tvm {
37 namespace relax {
38 
39 using Expr = RelaxExpr;
49 class IdNode : public Object {
50  public:
56  ffi::String name_hint;
57 
58  static void RegisterReflection() {
59  namespace refl = tvm::ffi::reflection;
60  refl::ObjectDef<IdNode>().def_ro("name_hint", &IdNode::name_hint,
61  refl::AttachFieldFlag::SEqHashIgnore());
62  }
63 
64  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar;
66 };
67 
68 class Id : public ObjectRef {
69  public:
74  TVM_DLL explicit Id(ffi::String name_hint);
75 
77 };
78 
108 class StructInfoNode : public Object {
109  public:
114  mutable Span span;
115 
116  static void RegisterReflection() {
117  namespace refl = tvm::ffi::reflection;
118  refl::ObjectDef<StructInfoNode>().def_ro("span", &StructInfoNode::span,
119  refl::AttachFieldFlag::SEqHashIgnore());
120  }
121 
122  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
123 
124  static constexpr const uint32_t _type_child_slots = 7;
125  TVM_FFI_DECLARE_OBJECT_INFO("ir.StructInfo", StructInfoNode, Object);
126 };
127 
132 class StructInfo : public ObjectRef {
133  public:
135 };
136 
141 class CallNode : public ExprNode {
142  public:
150 
152  tvm::ffi::Array<Expr> args;
153 
156 
167  ffi::Array<StructInfo> sinfo_args;
168 
169  static void RegisterReflection() {
170  namespace refl = tvm::ffi::reflection;
171  refl::ObjectDef<CallNode>()
172  .def_ro("op", &CallNode::op)
173  .def_ro("args", &CallNode::args)
174  .def_ro("attrs", &CallNode::attrs)
175  .def_ro("sinfo_args", &CallNode::sinfo_args);
176  }
178 };
179 
180 class Call : public Expr {
181  public:
190  TVM_DLL Call(Expr op, ffi::Array<Expr> args, Attrs attrs = Attrs(),
191  ffi::Array<StructInfo> sinfo_args = ffi::Array<StructInfo>(), Span span = Span());
192 
195 };
196 
203  Call call, ffi::Optional<Expr> opt_op = ffi::Optional<Expr>(),
204  ffi::Optional<ffi::Array<Expr>> opt_args = ffi::Optional<ffi::Array<Expr>>(),
205  ffi::Optional<Attrs> opt_attrs = ffi::Optional<Attrs>(),
206  ffi::Optional<ffi::Array<StructInfo>> opt_sinfo_args = ffi::Optional<ffi::Array<StructInfo>>(),
207  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
208 
210 class TupleNode : public ExprNode {
211  public:
213  tvm::ffi::Array<Expr> fields;
214 
215  static void RegisterReflection() {
216  namespace refl = tvm::ffi::reflection;
217  refl::ObjectDef<TupleNode>().def_ro("fields", &TupleNode::fields);
218  }
220 };
221 
222 class Tuple : public Expr {
223  public:
229  TVM_DLL explicit Tuple(tvm::ffi::Array<Expr> fields, Span span = Span());
230 
245  template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
246  TVM_DLL explicit Tuple(tvm::ffi::Array<RelaxExpr> fields, Span span = Span())
247  : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
248 
251 };
252 
259  ffi::Optional<ffi::Array<Expr>> opt_fields = ffi::Optional<ffi::Array<Expr>>(),
260  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
261 
263 class TupleGetItemNode : public ExprNode {
264  public:
268  int index;
269 
270  static void RegisterReflection() {
271  namespace refl = tvm::ffi::reflection;
272  refl::ObjectDef<TupleGetItemNode>()
273  .def_ro("tuple_value", &TupleGetItemNode::tuple)
274  .def_ro("index", &TupleGetItemNode::index);
275  }
277 };
278 
279 class TupleGetItem : public Expr {
280  public:
287  TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
288 
291 };
292 
299  ffi::Optional<Expr> opt_tuple = ffi::Optional<Expr>(),
300  ffi::Optional<Integer> opt_index = ffi::Optional<Integer>(),
301  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
302 
307 class LeafExprNode : public ExprNode {
308  public:
309  static constexpr const uint32_t _type_child_slots = 7;
311 };
312 
317 class LeafExpr : public Expr {
318  public:
320 };
321 
324 class ShapeExprNode : public LeafExprNode {
325  public:
327  ffi::Array<PrimExpr> values;
328 
329  static void RegisterReflection() {
330  namespace refl = tvm::ffi::reflection;
331  refl::ObjectDef<ShapeExprNode>().def_ro("values", &ShapeExprNode::values);
332  }
334 };
335 
336 class ShapeExpr : public LeafExpr {
337  public:
338  TVM_DLL explicit ShapeExpr(ffi::Array<PrimExpr> values, Span span = Span());
341 };
342 
344 class VarNode : public LeafExprNode {
345  public:
349 
351  const ffi::String& name_hint() const { return vid->name_hint; }
352 
353  static void RegisterReflection() {
354  namespace refl = tvm::ffi::reflection;
355  refl::ObjectDef<VarNode>().def_ro("vid", &VarNode::vid);
356  // customize structural equal and hash to include struct_info_
357  refl::TypeAttrDef<VarNode>()
358  .def("__s_equal__", &VarNode::SEqual)
359  .def("__s_hash__", &VarNode::SHash);
360  }
361 
362  bool SEqual(const VarNode* other,
363  ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const {
364  return equal(vid, other->vid, false, "vid") &&
365  equal(struct_info_, other->struct_info_, false, "struct_info_");
366  }
367 
368  int64_t SHash(int64_t init_hash, ffi::TypedFunction<int64_t(AnyView, int64_t, bool)> hash) const {
369  int64_t hash_value = init_hash;
370  hash_value = hash(vid, hash_value, false);
371  hash_value = hash(struct_info_, hash_value, false);
372  return hash_value;
373  }
374 
375  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
376  static constexpr const uint32_t _type_child_slots = 1;
378 };
379 
380 class Var : public LeafExpr {
381  public:
382  TVM_DLL explicit Var(ffi::String name_hint, ffi::Optional<StructInfo> struct_info_annotation,
383  Span span = Span())
384  : Var(Id(name_hint), struct_info_annotation, span) {}
385 
386  TVM_DLL explicit Var(Id vid, ffi::Optional<StructInfo> struct_info_annotation,
387  Span span = Span());
389 
391 };
392 
396 class DataflowVarNode : public VarNode {
397  public:
398  static void RegisterReflection() {
399  namespace refl = tvm::ffi::reflection;
400  refl::ObjectDef<DataflowVarNode>();
401  }
402 
403  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
405 };
406 
407 class DataflowVar : public Var {
408  public:
409  TVM_DLL explicit DataflowVar(ffi::String name_hint,
410  ffi::Optional<StructInfo> struct_info_annotation, Span span = Span())
411  : DataflowVar(Id(name_hint), struct_info_annotation, span) {}
412 
413  TVM_DLL explicit DataflowVar(Id vid, ffi::Optional<StructInfo> struct_info_annotation,
414  Span span = Span());
415 
418 };
419 
425 class ConstantNode : public LeafExprNode {
426  public:
429 
432 
434  bool is_scalar() const { return data->ndim == 0; }
435 
436  static void RegisterReflection() {
437  namespace refl = tvm::ffi::reflection;
438  refl::ObjectDef<ConstantNode>().def_ro("data", &ConstantNode::data);
439  }
441 };
442 
443 class Constant : public LeafExpr {
444  public:
452  TVM_DLL explicit Constant(runtime::Tensor data,
453  ffi::Optional<StructInfo> struct_info_annotation = std::nullopt,
454  Span span = Span());
455 
458 };
459 
465 class PrimValueNode : public LeafExprNode {
466  public:
469 
470  static void RegisterReflection() {
471  namespace refl = tvm::ffi::reflection;
472  refl::ObjectDef<PrimValueNode>().def_ro("value", &PrimValueNode::value);
473  }
475 };
476 
481 class PrimValue : public LeafExpr {
482  public:
488  TVM_DLL explicit PrimValue(PrimExpr value, Span span = Span());
489 
496  TVM_DLL static PrimValue Int64(int64_t value, Span span = Span());
497 
500 };
501 
505 class StringImmNode : public LeafExprNode {
506  public:
508  ffi::String value;
509 
510  static void RegisterReflection() {
511  namespace refl = tvm::ffi::reflection;
512  refl::ObjectDef<StringImmNode>().def_ro("value", &StringImmNode::value);
513  }
515 };
516 
521 class StringImm : public LeafExpr {
522  public:
528  TVM_DLL explicit StringImm(ffi::String value, Span span = Span());
529 
532 };
533 
538  public:
541 
542  static void RegisterReflection() {
543  namespace refl = tvm::ffi::reflection;
544  refl::ObjectDef<DataTypeImmNode>().def_ro("value", &DataTypeImmNode::value);
545  }
547 };
548 
553 class DataTypeImm : public LeafExpr {
554  public:
560  TVM_DLL explicit DataTypeImm(DataType value, Span span = Span());
561 
564 };
565 
567 class BindingNode : public Object {
568  public:
569  mutable Span span;
572 
573  static void RegisterReflection() {
574  namespace refl = tvm::ffi::reflection;
575  refl::ObjectDef<BindingNode>()
576  .def_ro("span", &BindingNode::span, refl::AttachFieldFlag::SEqHashIgnore())
577  .def_ro("var", &BindingNode::var, refl::AttachFieldFlag::SEqHashDef());
578  }
579 
580  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
581 
582  TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Binding", BindingNode, Object);
583 };
584 
585 class Binding : public ObjectRef {
586  protected:
587  Binding() = default;
588 
589  public:
590  explicit Binding(ObjectPtr<BindingNode> n) : ObjectRef(n) {}
591  explicit Binding(ffi::UnsafeInit tag) : ObjectRef(tag) {}
593  const BindingNode* operator->() const { return static_cast<const BindingNode*>(data_.get()); }
594  const BindingNode* get() const { return operator->(); }
596 };
597 
605 class MatchCastNode : public BindingNode {
606  public:
611 
612  static void RegisterReflection() {
613  namespace refl = tvm::ffi::reflection;
614  refl::ObjectDef<MatchCastNode>()
615  .def_ro("value", &MatchCastNode::value)
616  .def_ro("struct_info", &MatchCastNode::struct_info, refl::AttachFieldFlag::SEqHashDef());
617  }
619 };
620 
625 class MatchCast : public Binding {
626  public:
627  TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span());
628 
631 };
632 
633 class VarBindingNode : public BindingNode {
634  public:
637 
638  static void RegisterReflection() {
639  namespace refl = tvm::ffi::reflection;
640  refl::ObjectDef<VarBindingNode>().def_ro("value", &VarBindingNode::value);
641  // customize the SEqual and SHash methods for better error messages
642  refl::TypeAttrDef<VarBindingNode>()
643  .def("__s_equal__", &VarBindingNode::SEqual)
644  .def("__s_hash__", &VarBindingNode::SHash);
645  }
646 
647  bool SEqual(const VarBindingNode* other,
648  ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const;
649  int64_t SHash(int64_t init_hash, ffi::TypedFunction<int64_t(AnyView, int64_t, bool)> hash) const;
651 };
652 
653 class VarBinding : public Binding {
654  public:
655  TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
658 };
659 
660 class BindingBlockNode : public Object {
661  public:
662  ffi::Array<Binding> bindings;
663  mutable Span span;
664 
665  static void RegisterReflection() {
666  namespace refl = tvm::ffi::reflection;
667  refl::ObjectDef<BindingBlockNode>()
668  .def_ro("bindings", &BindingBlockNode::bindings)
669  .def_ro("span", &BindingBlockNode::span, refl::AttachFieldFlag::SEqHashIgnore(),
670  refl::DefaultValue(Span()));
671  }
672 
673  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
674  TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.BindingBlock", BindingBlockNode, Object);
675 };
676 
677 class BindingBlock : public ObjectRef {
678  public:
679  TVM_DLL explicit BindingBlock(ffi::Array<Binding> bindings, Span span = Span());
681 
683 };
684 
686  public:
687  static void RegisterReflection() {
688  namespace refl = tvm::ffi::reflection;
689  refl::ObjectDef<DataflowBlockNode>();
690  }
693 };
694 
695 class DataflowBlock : public BindingBlock {
696  public:
697  TVM_DLL explicit DataflowBlock(ffi::Array<Binding> bindings, Span span = Span());
700 };
701 
706 class SeqExprNode : public ExprNode {
707  public:
708  ffi::Array<BindingBlock> blocks;
710 
711  static void RegisterReflection() {
712  namespace refl = tvm::ffi::reflection;
713  refl::ObjectDef<SeqExprNode>()
714  .def_ro("blocks", &SeqExprNode::blocks)
715  .def_ro("body", &SeqExprNode::body);
716  }
718 };
719 
720 class SeqExpr : public Expr {
721  public:
722  /* \brief Implicit conversion constructor
723  *
724  * Relax nodes that introduce a new scope (e.g. `relax::Function`)
725  * are required to be held as SeqExpr. This implicit conversion
726  * provides allows callsites to use these member variables when the
727  * C++ compile-time type is a `relax::Expr`. For example,
728  * a transform may use `func.CopyOnWrite()->body = expr;`.
729  *
730  * If the expression is already a `relax::SeqExpr`, the same
731  * underlying `relax::SeqExprNode` is used, and no copies are made.
732  */
733  TVM_DLL SeqExpr(Expr body); // NOLINT(*)
734 
735  TVM_DLL explicit SeqExpr(ffi::Array<BindingBlock> blocks, Expr body, Span span = Span());
738 };
739 
751 class IfNode : public ExprNode {
752  public:
759 
760  static void RegisterReflection() {
761  namespace refl = tvm::ffi::reflection;
762  refl::ObjectDef<IfNode>()
763  .def_ro("cond", &IfNode::cond)
764  .def_ro("true_branch", &IfNode::true_branch)
765  .def_ro("false_branch", &IfNode::false_branch);
766  }
767 
768  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
770 };
771 
772 class If : public Expr {
773  public:
791  TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
792 
795 };
796 
802 If WithFields(If if_expr, ffi::Optional<Expr> opt_cond = ffi::Optional<Expr>(),
803  ffi::Optional<Expr> opt_true_branch = ffi::Optional<Expr>(),
804  ffi::Optional<Expr> opt_false_branch = ffi::Optional<Expr>(),
805  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
806 
808 class FunctionNode : public BaseFuncNode {
809  public:
811  ffi::Array<Var> params;
817  bool is_pure;
818 
819  static void RegisterReflection() {
820  namespace refl = tvm::ffi::reflection;
821  refl::ObjectDef<FunctionNode>()
822  .def_ro("params", &FunctionNode::params, refl::AttachFieldFlag::SEqHashDef())
823  .def_ro("body", &FunctionNode::body)
824  .def_ro("ret_struct_info", &FunctionNode::ret_struct_info)
825  .def_ro("is_pure", &FunctionNode::is_pure);
826  }
827 
828  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
830 };
831 
832 class Function : public BaseFunc {
833  public:
855  TVM_DLL explicit Function(ffi::Array<Var> params, Expr body,
856  ffi::Optional<StructInfo> ret_struct_info, bool is_pure = true,
857  DictAttrs attrs = DictAttrs(), Span span = Span());
858 
863  TVM_DLL static Function CreateEmpty(ffi::Array<Var> params, StructInfo ret_struct_info,
864  bool is_pure = true, DictAttrs attrs = DictAttrs(),
865  Span span = Span());
866 
869 };
870 
871 // TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and
872 // kPrimitive.
873 namespace attr {
875 constexpr const char* kPrimitive = "Primitive";
880 constexpr const char* kCodegen = "Codegen";
882 constexpr const char* kComposite = "Composite";
884 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
886 constexpr const char* kWorkspaceSize = "WorkspaceSize";
887 
888 // Note: in the future, we prefer snake_case instead of CamelCase for attributes.
889 // Past ones will be kept for backwards compatibility.
892 constexpr const char* kForcePure = "relax.force_pure";
893 
899 constexpr const char* kNumInput = "num_input";
900 } // namespace attr
901 
903 class ExternFuncNode : public BaseFuncNode {
904  public:
906  ffi::String global_symbol;
907 
908  static void RegisterReflection() {
909  namespace refl = tvm::ffi::reflection;
910  refl::ObjectDef<ExternFuncNode>().def_ro("global_symbol", &ExternFuncNode::global_symbol);
911  }
913 };
914 
915 class ExternFunc : public BaseFunc {
916  public:
917  TVM_DLL ExternFunc(ffi::String global_symbol, Span span = Span());
918  TVM_DLL ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span = Span());
919 
922 };
923 
935 TVM_DLL Expr GetShapeOf(const Expr& expr);
936 
937 } // namespace relax
938 } // namespace tvm
939 
940 /* \brief Allow relax.Var as key in STL tables
941  *
942  * For most Relax expressions, it would be ambiguous whether the
943  * expression should follow reference equality or structural equality.
944  * This is not the case for variables, which do not contain nested
945  * internal structure, and are frequently used as keys in lookup
946  * tables.
947  *
948  * Providing `std::hash` and `std::equal_to` specializations for
949  * `relax::Var` allows it to be used as a key in STL tables. For
950  * `relax::Expr`, the user must specify the type of equality used
951  * (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
952  * `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
953  */
954 template <>
955 struct std::hash<tvm::relax::Var> {
956  std::size_t operator()(const tvm::relax::Var& var) const {
957  return tvm::runtime::ObjectPtrHash()(var);
958  }
959 };
960 
961 template <>
962 struct std::equal_to<tvm::relax::Var> {
963  bool operator()(const tvm::relax::Var& var_a, const tvm::relax::Var& var_b) const {
964  return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
965  }
966 };
967 
968 #endif // TVM_RELAX_EXPR_H_
Managed reference to BaseAttrsNode.
Definition: attrs.h:131
Base node of all functions.
Definition: function.h:139
Managed reference to BaseFuncNode.
Definition: function.h:233
Managed reference to DictAttrsNode.
Definition: attrs.h:162
Reference to PrimExprNode.
Definition: expr.h:124
Base node of all non-primitive expressions.
Definition: expr.h:416
ffi::Optional< ObjectRef > struct_info_
Stores the result of structure information of the expression that encapsulate both static shape and r...
Definition: expr.h:423
Managed reference to RelaxExprNode.
Definition: expr.h:439
Definition: source_map.h:111
Definition: expr.h:660
ffi::Array< Binding > bindings
Definition: expr.h:662
Span span
Definition: expr.h:663
static void RegisterReflection()
Definition: expr.h:665
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.BindingBlock", BindingBlockNode, Object)
Definition: expr.h:677
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BindingBlock, ObjectRef, BindingBlockNode)
BindingBlockNode * CopyOnWrite()
BindingBlock(ffi::Array< Binding > bindings, Span span=Span())
The base class of a variable binding in Relax.
Definition: expr.h:567
Var var
The return variable to bound to.
Definition: expr.h:571
static void RegisterReflection()
Definition: expr.h:573
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Binding", BindingNode, Object)
Span span
Definition: expr.h:569
Definition: expr.h:585
const BindingNode * operator->() const
Definition: expr.h:593
Binding(ObjectPtr< BindingNode > n)
Definition: expr.h:590
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding)
const BindingNode * get() const
Definition: expr.h:594
Binding(ffi::UnsafeInit tag)
Definition: expr.h:591
Call corresponds to callable invocation. Corresponds to operation in computational graph terminology.
Definition: expr.h:141
ffi::Array< StructInfo > sinfo_args
The structure info arguments of a CallNode. sinfo_args is by default designed to be non-empty only fo...
Definition: expr.h:167
static void RegisterReflection()
Definition: expr.h:169
tvm::ffi::Array< Expr > args
The arguments(inputs) of the call.
Definition: expr.h:152
Expr op
The operator(function) being invoked.
Definition: expr.h:149
Attrs attrs
The additional attributes.
Definition: expr.h:155
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Call", CallNode, ExprNode)
Definition: expr.h:180
Call(Expr op, ffi::Array< Expr > args, Attrs attrs=Attrs(), ffi::Array< StructInfo > sinfo_args=ffi::Array< StructInfo >(), Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, Expr, CallNode)
Constant tensor.
Definition: expr.h:425
static void RegisterReflection()
Definition: expr.h:436
bool is_scalar() const
Definition: expr.h:434
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Constant", ConstantNode, LeafExprNode)
runtime::Tensor data
The data of the tensor.
Definition: expr.h:428
TensorType tensor_type() const
Definition: expr.h:443
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Constant, LeafExpr, ConstantNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode)
Constant(runtime::Tensor data, ffi::Optional< StructInfo > struct_info_annotation=std::nullopt, Span span=Span())
The constructor.
Represent a data type constant.
Definition: expr.h:537
DataType value
The data value.
Definition: expr.h:540
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataTypeImm", DataTypeImmNode, LeafExprNode)
static void RegisterReflection()
Definition: expr.h:542
Managed reference to DataTypeImm.
Definition: expr.h:553
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypeImm, LeafExpr, DataTypeImmNode)
DataTypeImm(DataType value, Span span=Span())
The constructor.
Definition: expr.h:685
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataflowBlock", DataflowBlockNode, BindingBlockNode)
static void RegisterReflection()
Definition: expr.h:687
Definition: expr.h:695
DataflowBlock(ffi::Array< Binding > bindings, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowBlock, BindingBlock, DataflowBlockNode)
A sub-type of the variable node used to mark dataflow variables from normal visible "function local" ...
Definition: expr.h:396
static void RegisterReflection()
Definition: expr.h:398
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataflowVar", DataflowVarNode, VarNode)
Definition: expr.h:407
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode)
DataflowVar(Id vid, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowVar, Var, DataflowVarNode)
DataflowVar(ffi::String name_hint, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:409
The extern function, which can represent packed function.
Definition: expr.h:903
static void RegisterReflection()
Definition: expr.h:908
ffi::String global_symbol
The name of global symbol.
Definition: expr.h:906
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ExternFunc", ExternFuncNode, BaseFuncNode)
Definition: expr.h:915
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternFunc, BaseFunc, ExternFuncNode)
ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode)
ExternFunc(ffi::String global_symbol, Span span=Span())
A Relax function.
Definition: expr.h:808
static void RegisterReflection()
Definition: expr.h:819
SeqExpr body
The body of the function.
Definition: expr.h:813
ffi::Array< Var > params
The parameters to the function.
Definition: expr.h:811
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Function", FunctionNode, BaseFuncNode)
StructInfo ret_struct_info
The return type of the function.
Definition: expr.h:815
bool is_pure
Whether the function is annotated as pure or not.
Definition: expr.h:817
Definition: expr.h:832
static Function CreateEmpty(ffi::Array< Var > params, StructInfo ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Mimics the constructor but without body Expr.
Function(ffi::Array< Var > params, Expr body, ffi::Optional< StructInfo > ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Construct a Relax Function.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, BaseFunc, FunctionNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode)
The unique identifier of variables.
Definition: expr.h:49
ffi::String name_hint
The name of the variable, this only acts as a hint to the user, and is not used for equality.
Definition: expr.h:56
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:64
static void RegisterReflection()
Definition: expr.h:58
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.Id", IdNode, Object)
Definition: expr.h:68
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Id, ObjectRef, IdNode)
Id(ffi::String name_hint)
The constructor.
Condition expression.
Definition: expr.h:751
static void RegisterReflection()
Definition: expr.h:760
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.If", IfNode, ExprNode)
SeqExpr true_branch
The expression evaluated when condition is true.
Definition: expr.h:756
Expr cond
The condition.
Definition: expr.h:754
SeqExpr false_branch
The expression evaluated when condition is false.
Definition: expr.h:758
Definition: expr.h:772
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(If, Expr, IfNode)
If(Expr cond, Expr true_branch, Expr false_branch, Span span=Span())
The constructor.
Base type of all (non-function) leaf Exprs.
Definition: expr.h:307
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.LeafExpr", LeafExprNode, ExprNode)
Managed reference to BaseExprNode.
Definition: expr.h:317
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LeafExpr, Expr, LeafExprNode)
Runtime-match the value to the struct info.
Definition: expr.h:605
Expr value
The input value to match cast.
Definition: expr.h:608
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.MatchCast", MatchCastNode, BindingNode)
StructInfo struct_info
The struct info pattern to match to.
Definition: expr.h:610
static void RegisterReflection()
Definition: expr.h:612
Managed reference to MatchCastNode.
Definition: expr.h:625
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchCast, Binding, MatchCastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode)
MatchCast(Var var, Expr value, StructInfo struct_info, Span span=Span())
PrimValue.
Definition: expr.h:465
static void RegisterReflection()
Definition: expr.h:470
PrimExpr value
The prim expr representing the value.
Definition: expr.h:468
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.PrimValue", PrimValueNode, LeafExprNode)
Managed reference to PrimValueNode.
Definition: expr.h:481
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode)
static PrimValue Int64(int64_t value, Span span=Span())
Create a int64 prim value.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimValue, LeafExpr, PrimValueNode)
PrimValue(PrimExpr value, Span span=Span())
The constructor.
A sequence of blocks followed by an expression.
Definition: expr.h:706
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.SeqExpr", SeqExprNode, ExprNode)
Expr body
Definition: expr.h:709
static void RegisterReflection()
Definition: expr.h:711
ffi::Array< BindingBlock > blocks
Definition: expr.h:708
Definition: expr.h:720
SeqExpr(ffi::Array< BindingBlock > blocks, Expr body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SeqExpr, Expr, SeqExprNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode)
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:324
ffi::Array< PrimExpr > values
Definition: expr.h:327
static void RegisterReflection()
Definition: expr.h:329
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ShapeExpr", ShapeExprNode, LeafExprNode)
Definition: expr.h:336
ShapeExpr(ffi::Array< PrimExpr > values, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ShapeExpr, LeafExpr, ShapeExprNode)
Represent a string literal constant.
Definition: expr.h:505
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.StringImm", StringImmNode, LeafExprNode)
ffi::String value
The data value.
Definition: expr.h:508
static void RegisterReflection()
Definition: expr.h:510
Managed reference to StringImm.
Definition: expr.h:521
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, LeafExpr, StringImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
StringImm(ffi::String value, Span span=Span())
The constructor.
Base type of all structure information.
Definition: expr.h:108
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:114
static constexpr const uint32_t _type_child_slots
Definition: expr.h:124
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:122
static void RegisterReflection()
Definition: expr.h:116
TVM_FFI_DECLARE_OBJECT_INFO("ir.StructInfo", StructInfoNode, Object)
Managed reference to StructInfoNode.
Definition: expr.h:132
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StructInfo, ObjectRef, StructInfoNode)
Managed reference to TensorTypeNode.
Definition: type.h:94
Get index-th field out of a tuple.
Definition: expr.h:263
static void RegisterReflection()
Definition: expr.h:270
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.TupleGetItem", TupleGetItemNode, ExprNode)
int index
which value to get
Definition: expr.h:268
Expr tuple
The tuple Expression.
Definition: expr.h:266
Definition: expr.h:279
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode)
TupleGetItem(Expr tuple, int index, Span span=Span())
The constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TupleGetItem, Expr, TupleGetItemNode)
Tuple container.
Definition: expr.h:210
static void RegisterReflection()
Definition: expr.h:215
tvm::ffi::Array< Expr > fields
the fields of the tuple
Definition: expr.h:213
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Tuple", TupleNode, ExprNode)
Definition: expr.h:222
Tuple(tvm::ffi::Array< Expr > fields, Span span=Span())
The constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tuple, Expr, TupleNode)
Tuple(tvm::ffi::Array< RelaxExpr > fields, Span span=Span())
Utility constructor to handle conversion to relax::Expr.
Definition: expr.h:246
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode)
Definition: expr.h:633
static void RegisterReflection()
Definition: expr.h:638
Expr value
The binding value.
Definition: expr.h:636
int64_t SHash(int64_t init_hash, ffi::TypedFunction< int64_t(AnyView, int64_t, bool)> hash) const
bool SEqual(const VarBindingNode *other, ffi::TypedFunction< bool(AnyView, AnyView, bool, AnyView)> equal) const
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.VarBinding", VarBindingNode, BindingNode)
Definition: expr.h:653
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VarBinding, Binding, VarBindingNode)
VarBinding(Var var, Expr value, Span span=Span())
The variable class for all Relax bindings.
Definition: expr.h:344
int64_t SHash(int64_t init_hash, ffi::TypedFunction< int64_t(AnyView, int64_t, bool)> hash) const
Definition: expr.h:368
const ffi::String & name_hint() const
Definition: expr.h:351
Id vid
The identifier of the variable, which is used for comparing stable equality across transformations.
Definition: expr.h:348
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Var", VarNode, LeafExprNode)
static void RegisterReflection()
Definition: expr.h:353
bool SEqual(const VarNode *other, ffi::TypedFunction< bool(AnyView, AnyView, bool, AnyView)> equal) const
Definition: expr.h:362
Definition: expr.h:380
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Var, LeafExpr, VarNode)
Var(Id vid, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
Var(ffi::String name_hint, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:382
VarNode * CopyOnWrite()
Runtime primitive data type.
Definition: data_type.h:47
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:53
Base expr nodes in TVM.
Function nodes.
Definition: repr_printer.h:91
constexpr const char * kForcePure
Override checking purity for this function and treat as pure (is_pure must be set to true)
Definition: expr.h:892
constexpr const char * kWorkspaceSize
The required workspace for an external function.
Definition: expr.h:886
constexpr const char * kNumInput
The number of inputs of a function. If a function has the num_input attribute, the last func->params....
Definition: expr.h:899
constexpr const char * kComposite
Treat the function as a composite operator.
Definition: expr.h:882
constexpr const char * kCodegen
Indicate the codegen that should be used for building this function. When this is unset or set to "de...
Definition: expr.h:880
constexpr const char * kPrimitive
Mark the function as a primitive function.
Definition: expr.h:875
constexpr const char * kPartitionedFromPattern
Indicate the function was created by the Pattern Partitioning Pass.
Definition: expr.h:884
Call WithFields(Call call, ffi::Optional< Expr > opt_op=ffi::Optional< Expr >(), ffi::Optional< ffi::Array< Expr >> opt_args=ffi::Optional< ffi::Array< Expr >>(), ffi::Optional< Attrs > opt_attrs=ffi::Optional< Attrs >(), ffi::Optional< ffi::Array< StructInfo >> opt_sinfo_args=ffi::Optional< ffi::Array< StructInfo >>(), ffi::Optional< Span > opt_span=ffi::Optional< Span >())
Returns call with the given properties. A null property denotes 'no change'. Returns call if all prop...
If WithFields(If if_expr, ffi::Optional< Expr > opt_cond=ffi::Optional< Expr >(), ffi::Optional< Expr > opt_true_branch=ffi::Optional< Expr >(), ffi::Optional< Expr > opt_false_branch=ffi::Optional< Expr >(), ffi::Optional< Span > opt_span=ffi::Optional< Span >())
Returns if_expr with the given properties. A null property denotes 'no change'. Returns if_expr if al...
Expr GetShapeOf(const Expr &expr)
Get the shape of Expr.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Definitions and helper macros for IR/AST nodes.
A managed object in the TVM runtime.
Relax Types.
A map from source names to source code.
TIR expressions.
Common operators defined for Expr.