tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_RELAX_EXPR_H_
20 #define TVM_RELAX_EXPR_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/container/map.h>
24 #include <tvm/ffi/reflection/registry.h>
25 #include <tvm/ir/expr.h>
26 #include <tvm/ir/function.h>
27 #include <tvm/ir/source_map.h>
28 #include <tvm/node/node.h>
29 #include <tvm/relax/type.h>
30 #include <tvm/runtime/object.h>
31 #include <tvm/tir/expr.h>
32 #include <tvm/tir/op.h>
33 
34 #include <functional>
35 
36 namespace tvm {
37 namespace relax {
38 
// Relax expressions share RelaxExpr (the managed reference to RelaxExprNode)
// as their common reference type; `Expr` is the short alias used below.
39 using Expr = RelaxExpr;
// IdNode: the unique identifier of variables.
//
// Structural equality/hashing treats an Id as a free variable
// (kTVMFFISEqHashKindFreeVar), and its name_hint is excluded from the
// comparison (SEqHashIgnore).
49 class IdNode : public Object {
50  public:
// The name of the variable. This only acts as a hint to the user and is not
// used for equality.
56  ffi::String name_hint;
57 
// Registers name_hint as a read-only reflected field that structural
// equality/hashing ignores.
58  static void RegisterReflection() {
59  namespace refl = tvm::ffi::reflection;
60  refl::ObjectDef<IdNode>().def_ro("name_hint", &IdNode::name_hint,
61  refl::AttachFieldFlag::SEqHashIgnore());
62  }
63 
64  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar;
66 };
67 
// Id: managed reference to IdNode.
68 class Id : public ObjectRef {
69  public:
// Constructs an Id carrying the given name hint.
74  TVM_DLL explicit Id(ffi::String name_hint);
75 
77 };
78 
// StructInfoNode: base type of all structure information
// (registered under the type key "ir.StructInfo").
108 class StructInfoNode : public Object {
109  public:
// Span that points to the original source code. Reserved debug information.
// Declared mutable; reflected read-only and ignored by structural
// equality/hashing.
114  mutable Span span;
115 
116  static void RegisterReflection() {
117  namespace refl = tvm::ffi::reflection;
118  refl::ObjectDef<StructInfoNode>().def_ro("span", &StructInfoNode::span,
119  refl::AttachFieldFlag::SEqHashIgnore());
120  }
121 
// Structural equality/hashing treats struct info as a tree node.
122  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
123 
// NOTE(review): presumably reserves type-index slots for the concrete
// struct-info subclasses; confirm against the TVM FFI type-system docs.
124  static constexpr const uint32_t _type_child_slots = 7;
125  TVM_FFI_DECLARE_OBJECT_INFO("ir.StructInfo", StructInfoNode, Object);
126 };
127 
// StructInfo: managed reference to StructInfoNode.
132 class StructInfo : public ObjectRef {
133  public:
135 };
136 
// CallNode: a call corresponds to a callable invocation; in computational-graph
// terminology, it corresponds to an operation.
// NOTE: the `op` (the operator/function being invoked) and `attrs` (the
// additional attributes) members are declared on lines elided from this
// listing.
141 class CallNode : public ExprNode {
142  public:
150 
// The arguments (inputs) of the call.
152  tvm::ffi::Array<Expr> args;
153 
156 
// The structure info arguments of a CallNode.
167  ffi::Array<StructInfo> sinfo_args;
168 
// Exposes op, args, attrs and sinfo_args as read-only reflected fields.
169  static void RegisterReflection() {
170  namespace refl = tvm::ffi::reflection;
171  refl::ObjectDef<CallNode>()
172  .def_ro("op", &CallNode::op)
173  .def_ro("args", &CallNode::args)
174  .def_ro("attrs", &CallNode::attrs)
175  .def_ro("sinfo_args", &CallNode::sinfo_args);
176  }
178 };
179 
// Call: managed reference to CallNode.
180 class Call : public Expr {
181  public:
// Constructs a call of `op` on `args`. `attrs` defaults to Attrs(),
// `sinfo_args` to an empty array and `span` to Span().
190  TVM_DLL Call(Expr op, ffi::Array<Expr> args, Attrs attrs = Attrs(),
191  ffi::Array<StructInfo> sinfo_args = ffi::Array<StructInfo>(), Span span = Span());
192 
195 };
196 
// Parameters of WithFields(Call, ...), whose opening line is elided from this
// listing. It returns `call` with the given properties; a null property
// denotes 'no change'.
203  Call call, ffi::Optional<Expr> opt_op = ffi::Optional<Expr>(),
204  ffi::Optional<ffi::Array<Expr>> opt_args = ffi::Optional<ffi::Array<Expr>>(),
205  ffi::Optional<Attrs> opt_attrs = ffi::Optional<Attrs>(),
206  ffi::Optional<ffi::Array<StructInfo>> opt_sinfo_args = ffi::Optional<ffi::Array<StructInfo>>(),
207  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
208 
// TupleNode: tuple container.
210 class TupleNode : public ExprNode {
211  public:
// The fields of the tuple.
213  tvm::ffi::Array<Expr> fields;
214 
// Exposes `fields` as a read-only reflected field.
215  static void RegisterReflection() {
216  namespace refl = tvm::ffi::reflection;
217  refl::ObjectDef<TupleNode>().def_ro("fields", &TupleNode::fields);
218  }
220 };
221 
// Tuple: managed reference to TupleNode.
222 class Tuple : public Expr {
223  public:
// Constructs a tuple from the given fields.
229  TVM_DLL explicit Tuple(tvm::ffi::Array<Expr> fields, Span span = Span());
230 
// Utility constructor to handle conversion to relax::Expr. It accepts an
// array of any type derived from Expr, maps each element to Expr, and
// delegates to the Array<Expr> constructor.
// NOTE(review): the template parameter is named `RelaxExpr`, which shadows
// tvm::RelaxExpr inside this template.
245  template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
246  TVM_DLL explicit Tuple(tvm::ffi::Array<RelaxExpr> fields, Span span = Span())
247  : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
248 
251 };
252 
// Trailing parameters of a WithFields-style overload for Tuple (its opening
// lines are elided from this listing). Presumably, as with the Call and If
// overloads, a null property denotes 'no change'; confirm.
259  ffi::Optional<ffi::Array<Expr>> opt_fields = ffi::Optional<ffi::Array<Expr>>(),
260  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
261 
// TupleGetItemNode: gets the index-th field out of a tuple.
// NOTE: the `tuple` member (the tuple expression) is declared on a line
// elided from this listing.
263 class TupleGetItemNode : public ExprNode {
264  public:
// The index of the tuple field to get.
268  int index;
269 
// Reflects the `tuple` member under the name "tuple_value" (not "tuple"),
// plus `index`.
270  static void RegisterReflection() {
271  namespace refl = tvm::ffi::reflection;
272  refl::ObjectDef<TupleGetItemNode>()
273  .def_ro("tuple_value", &TupleGetItemNode::tuple)
274  .def_ro("index", &TupleGetItemNode::index);
275  }
277 };
278 
// TupleGetItem: managed reference to TupleGetItemNode.
279 class TupleGetItem : public Expr {
280  public:
// Constructs an expression that selects field `index` of `tuple`.
287  TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
288 
291 };
292 
// Trailing parameters of a WithFields-style overload for TupleGetItem (its
// opening lines are elided from this listing). The optional index is an
// Integer, whereas TupleGetItemNode::index is an int.
299  ffi::Optional<Expr> opt_tuple = ffi::Optional<Expr>(),
300  ffi::Optional<Integer> opt_index = ffi::Optional<Integer>(),
301  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
302 
// LeafExprNode: base type of all (non-function) leaf Exprs
// (type key "relax.expr.LeafExpr").
307 class LeafExprNode : public ExprNode {
308  public:
// NOTE(review): presumably reserves type-index slots for the leaf
// expression subclasses; confirm.
309  static constexpr const uint32_t _type_child_slots = 7;
311 };
312 
// LeafExpr: managed reference to LeafExprNode.
317 class LeafExpr : public Expr {
318  public:
320 };
321 
// ShapeExprNode: a shape expression that lets users construct a shape
// containing PrimExpr.
324 class ShapeExprNode : public LeafExprNode {
325  public:
// The PrimExpr values that make up the shape.
327  ffi::Array<PrimExpr> values;
328 
// Exposes `values` as a read-only reflected field.
329  static void RegisterReflection() {
330  namespace refl = tvm::ffi::reflection;
331  refl::ObjectDef<ShapeExprNode>().def_ro("values", &ShapeExprNode::values);
332  }
334 };
335 
// ShapeExpr: managed reference to ShapeExprNode.
336 class ShapeExpr : public LeafExpr {
337  public:
// Constructs a shape expression from the given values.
338  TVM_DLL explicit ShapeExpr(ffi::Array<PrimExpr> values, Span span = Span());
341 };
342 
// VarNode: the variable class for all Relax bindings.
// NOTE: the `vid` member (the Id used for comparing stable equality across
// transformations) is declared on a line elided from this listing.
344 class VarNode : public LeafExprNode {
345  public:
349 
// Returns the name hint stored in this variable's Id.
351  const ffi::String& name_hint() const { return vid->name_hint; }
352 
// Exposes `vid` as a read-only reflected field and installs the custom
// structural equality/hash hooks (SEqual / SHash).
353  static void RegisterReflection() {
354  namespace refl = tvm::ffi::reflection;
355  refl::ObjectDef<VarNode>().def_ro("vid", &VarNode::vid);
356  // customize structural equal and hash to include struct_info_
357  refl::TypeAttrDef<VarNode>()
358  .def("__s_equal__", &VarNode::SEqual)
359  .def("__s_hash__", &VarNode::SHash);
360  }
361 
// Structural equality: compares `vid`, then `struct_info_`. The third
// argument (false) is presumably the map-free-vars flag, and the fourth is the
// field name used when reporting a mismatch; confirm against the FFI docs.
362  bool SEqual(const VarNode* other,
363  ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const {
364  return equal(vid, other->vid, false, "vid") &&
365  equal(struct_info_, other->struct_info_, false, "struct_info_");
366  }
367 
// Structural hash: folds `vid`, then `struct_info_`, into `init_hash`,
// mirroring the fields compared by SEqual.
368  uint64_t SHash(uint64_t init_hash,
369  ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const {
370  uint64_t hash_value = init_hash;
371  hash_value = hash(vid, hash_value, false);
372  hash_value = hash(struct_info_, hash_value, false);
373  return hash_value;
374  }
375 
// Structural equality/hashing treats variables as DAG nodes.
// _type_child_slots is 1 (presumably for DataflowVarNode, which derives from
// VarNode).
376  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
377  static constexpr const uint32_t _type_child_slots = 1;
379 };
380 
// Var: managed reference to VarNode.
381 class Var : public LeafExpr {
382  public:
// Convenience constructor: wraps `name_hint` in a new Id and delegates to the
// Id-based constructor.
383  TVM_DLL explicit Var(ffi::String name_hint, ffi::Optional<StructInfo> struct_info_annotation,
384  Span span = Span())
385  : Var(Id(name_hint), struct_info_annotation, span) {}
386 
// Constructs a variable from an existing Id and an optional struct info
// annotation.
387  TVM_DLL explicit Var(Id vid, ffi::Optional<StructInfo> struct_info_annotation,
388  Span span = Span());
390 
392 };
393 
// DataflowVarNode: a sub-type of the variable node, used to distinguish
// dataflow variables from normal visible "function local" ones
// (type key "relax.expr.DataflowVar").
397 class DataflowVarNode : public VarNode {
398  public:
// Registers the type with no additional reflected fields.
399  static void RegisterReflection() {
400  namespace refl = tvm::ffi::reflection;
401  refl::ObjectDef<DataflowVarNode>();
402  }
403 
404  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
406 };
407 
// DataflowVar: managed reference to DataflowVarNode.
408 class DataflowVar : public Var {
409  public:
// Convenience constructor: wraps `name_hint` in a new Id and delegates to the
// Id-based constructor.
410  TVM_DLL explicit DataflowVar(ffi::String name_hint,
411  ffi::Optional<StructInfo> struct_info_annotation, Span span = Span())
412  : DataflowVar(Id(name_hint), struct_info_annotation, span) {}
413 
// Constructs a dataflow variable from an existing Id and an optional struct
// info annotation.
414  TVM_DLL explicit DataflowVar(Id vid, ffi::Optional<StructInfo> struct_info_annotation,
415  Span span = Span());
416 
419 };
420 
// ConstantNode: constant tensor.
// NOTE: the `data` member (the runtime::Tensor holding the data) is declared
// on a line elided from this listing.
426 class ConstantNode : public LeafExprNode {
427  public:
430 
433 
// Returns true when the constant is a scalar, i.e. its tensor has ndim == 0.
435  bool is_scalar() const { return data->ndim == 0; }
436 
// Exposes `data` as a read-only reflected field.
437  static void RegisterReflection() {
438  namespace refl = tvm::ffi::reflection;
439  refl::ObjectDef<ConstantNode>().def_ro("data", &ConstantNode::data);
440  }
442 };
443 
// Constant: managed reference to ConstantNode.
444 class Constant : public LeafExpr {
445  public:
// Constructs a constant from a tensor. `struct_info_annotation` defaults to
// std::nullopt (no annotation).
453  TVM_DLL explicit Constant(runtime::Tensor data,
454  ffi::Optional<StructInfo> struct_info_annotation = std::nullopt,
455  Span span = Span());
456 
459 };
460 
// PrimValueNode: a leaf expression that wraps a PrimExpr.
// NOTE: the `value` member (the prim expr representing the value) is declared
// on a line elided from this listing.
466 class PrimValueNode : public LeafExprNode {
467  public:
470 
// Exposes `value` as a read-only reflected field.
471  static void RegisterReflection() {
472  namespace refl = tvm::ffi::reflection;
473  refl::ObjectDef<PrimValueNode>().def_ro("value", &PrimValueNode::value);
474  }
476 };
477 
// PrimValue: managed reference to PrimValueNode.
482 class PrimValue : public LeafExpr {
483  public:
// Constructs a PrimValue wrapping `value`.
489  TVM_DLL explicit PrimValue(PrimExpr value, Span span = Span());
490 
// Creates an int64 prim value.
497  TVM_DLL static PrimValue Int64(int64_t value, Span span = Span());
498 
501 };
502 
// StringImmNode: represents a string literal constant.
506 class StringImmNode : public LeafExprNode {
507  public:
// The string value.
509  ffi::String value;
510 
// Exposes `value` as a read-only reflected field.
511  static void RegisterReflection() {
512  namespace refl = tvm::ffi::reflection;
513  refl::ObjectDef<StringImmNode>().def_ro("value", &StringImmNode::value);
514  }
516 };
517 
// StringImm: managed reference to StringImmNode.
522 class StringImm : public LeafExpr {
523  public:
// Constructs a string literal constant from `value`.
529  TVM_DLL explicit StringImm(ffi::String value, Span span = Span());
530 
533 };
534 
// DataTypeImmNode (its class-head line is elided from this listing):
// represents a data type constant. Its `value` member (a DataType) is also
// declared on an elided line.
539  public:
542 
// Exposes `value` as a read-only reflected field.
543  static void RegisterReflection() {
544  namespace refl = tvm::ffi::reflection;
545  refl::ObjectDef<DataTypeImmNode>().def_ro("value", &DataTypeImmNode::value);
546  }
548 };
549 
// DataTypeImm: managed reference to DataTypeImmNode.
554 class DataTypeImm : public LeafExpr {
555  public:
// Constructs a data type constant from `value`.
561  TVM_DLL explicit DataTypeImm(DataType value, Span span = Span());
562 
565 };
566 
// BindingNode: the base class of a variable binding in Relax
// (type key "relax.expr.Binding").
// NOTE: the `var` member (the variable to be bound) is declared on a line
// elided from this listing.
568 class BindingNode : public Object {
569  public:
// Source span. Declared mutable; reflected read-only and ignored by
// structural equality/hashing.
570  mutable Span span;
573 
// Reflects `span` (SEqHashIgnore) and `var` (SEqHashDef, which presumably
// marks `var` as the definition point for structural comparison; confirm).
574  static void RegisterReflection() {
575  namespace refl = tvm::ffi::reflection;
576  refl::ObjectDef<BindingNode>()
577  .def_ro("span", &BindingNode::span, refl::AttachFieldFlag::SEqHashIgnore())
578  .def_ro("var", &BindingNode::var, refl::AttachFieldFlag::SEqHashDef());
579  }
580 
581  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
582 
583  TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Binding", BindingNode, Object);
584 };
585 
// Binding: managed reference to BindingNode.
586 class Binding : public ObjectRef {
587  protected:
// Default construction is available only to subclasses.
588  Binding() = default;
589 
590  public:
// Wraps an existing BindingNode pointer.
591  explicit Binding(ObjectPtr<BindingNode> n) : ObjectRef(n) {}
// Constructor taking the FFI's UnsafeInit tag; presumably leaves the
// reference uninitialized. Confirm.
592  explicit Binding(ffi::UnsafeInit tag) : ObjectRef(tag) {}
// Accessors that return the held object as a const BindingNode*.
594  const BindingNode* operator->() const { return static_cast<const BindingNode*>(data_.get()); }
595  const BindingNode* get() const { return operator->(); }
597 };
598 
// MatchCastNode: runtime-matches the value to the struct info.
// NOTE: the `value` (the input value to match cast) and `struct_info` (the
// struct info pattern to match to) members are declared on lines elided from
// this listing.
606 class MatchCastNode : public BindingNode {
607  public:
612 
// Reflects `value` and `struct_info`; `struct_info` carries SEqHashDef.
613  static void RegisterReflection() {
614  namespace refl = tvm::ffi::reflection;
615  refl::ObjectDef<MatchCastNode>()
616  .def_ro("value", &MatchCastNode::value)
617  .def_ro("struct_info", &MatchCastNode::struct_info, refl::AttachFieldFlag::SEqHashDef());
618  }
620 };
621 
// MatchCast: managed reference to MatchCastNode.
626 class MatchCast : public Binding {
627  public:
// Constructs a match-cast binding of `var` from `value` against `struct_info`.
628  TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span());
629 
632 };
633 
// VarBindingNode: a binding of a value to a variable.
// NOTE: the `value` member (the binding value) is declared on a line elided
// from this listing.
634 class VarBindingNode : public BindingNode {
635  public:
638 
// Reflects `value` and installs the custom structural equality/hash hooks.
639  static void RegisterReflection() {
640  namespace refl = tvm::ffi::reflection;
641  refl::ObjectDef<VarBindingNode>().def_ro("value", &VarBindingNode::value);
642  // customize the SEqual and SHash methods for better error messages
643  refl::TypeAttrDef<VarBindingNode>()
644  .def("__s_equal__", &VarBindingNode::SEqual)
645  .def("__s_hash__", &VarBindingNode::SHash);
646  }
647 
// Custom structural equality and hash; only declared here (defined out of
// line).
648  bool SEqual(const VarBindingNode* other,
649  ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const;
650  uint64_t SHash(uint64_t init_hash,
651  ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const;
653 };
654 
// VarBinding: managed reference to VarBindingNode.
655 class VarBinding : public Binding {
656  public:
// Constructs a binding of `value` to `var`.
657  TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
660 };
661 
// BindingBlockNode: a block of bindings (type key "relax.expr.BindingBlock").
662 class BindingBlockNode : public Object {
663  public:
// The bindings in this block.
664  ffi::Array<Binding> bindings;
// Source span. Declared mutable; ignored by structural equality/hashing and
// reflected with a default value of Span().
665  mutable Span span;
666 
667  static void RegisterReflection() {
668  namespace refl = tvm::ffi::reflection;
669  refl::ObjectDef<BindingBlockNode>()
670  .def_ro("bindings", &BindingBlockNode::bindings)
671  .def_ro("span", &BindingBlockNode::span, refl::AttachFieldFlag::SEqHashIgnore(),
672  refl::DefaultValue(Span()));
673  }
674 
675  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
676  TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.BindingBlock", BindingBlockNode, Object);
677 };
678 
// BindingBlock: managed reference to BindingBlockNode.
679 class BindingBlock : public ObjectRef {
680  public:
// Constructs a block from the given bindings.
681  TVM_DLL explicit BindingBlock(ffi::Array<Binding> bindings, Span span = Span());
683 
685 };
686 
// DataflowBlockNode (its class-head line is elided from this listing): a
// BindingBlockNode subtype with type key "relax.expr.DataflowBlock". It
// registers no additional reflected fields.
688  public:
689  static void RegisterReflection() {
690  namespace refl = tvm::ffi::reflection;
691  refl::ObjectDef<DataflowBlockNode>();
692  }
695 };
696 
// DataflowBlock: managed reference to DataflowBlockNode.
697 class DataflowBlock : public BindingBlock {
698  public:
// Constructs a dataflow block from the given bindings.
699  TVM_DLL explicit DataflowBlock(ffi::Array<Binding> bindings, Span span = Span());
702 };
703 
// SeqExprNode: a sequence of blocks followed by an expression.
// NOTE: the `body` member (an Expr) is declared on a line elided from this
// listing.
708 class SeqExprNode : public ExprNode {
709  public:
// The binding blocks that precede `body`.
710  ffi::Array<BindingBlock> blocks;
712 
// Exposes `blocks` and `body` as read-only reflected fields.
713  static void RegisterReflection() {
714  namespace refl = tvm::ffi::reflection;
715  refl::ObjectDef<SeqExprNode>()
716  .def_ro("blocks", &SeqExprNode::blocks)
717  .def_ro("body", &SeqExprNode::body);
718  }
720 };
721 
// SeqExpr: managed reference to SeqExprNode.
722 class SeqExpr : public Expr {
723  public:
724  /* \brief Implicit conversion constructor
725  *
726  * Relax nodes that introduce a new scope (e.g. `relax::Function`)
727  * are required to be held as SeqExpr. This implicit conversion
728  * allows callsites to assign to these member variables when the
729  * C++ compile-time type is a `relax::Expr`. For example,
730  * a transform may use `func.CopyOnWrite()->body = expr;`.
731  *
732  * If the expression is already a `relax::SeqExpr`, the same
733  * underlying `relax::SeqExprNode` is used, and no copies are made.
734  */
735  TVM_DLL SeqExpr(Expr body); // NOLINT(*)
736 
// Constructs a SeqExpr from explicit blocks and a body.
737  TVM_DLL explicit SeqExpr(ffi::Array<BindingBlock> blocks, Expr body, Span span = Span());
740 };
741 
// IfNode: condition expression (type key "relax.expr.If").
// NOTE: the `cond`, `true_branch` and `false_branch` members are declared on
// lines elided from this listing; both branches are held as SeqExpr.
753 class IfNode : public ExprNode {
754  public:
761 
// Exposes cond, true_branch and false_branch as read-only reflected fields.
762  static void RegisterReflection() {
763  namespace refl = tvm::ffi::reflection;
764  refl::ObjectDef<IfNode>()
765  .def_ro("cond", &IfNode::cond)
766  .def_ro("true_branch", &IfNode::true_branch)
767  .def_ro("false_branch", &IfNode::false_branch);
768  }
769 
770  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
772 };
773 
// If: managed reference to IfNode.
774 class If : public Expr {
775  public:
// Constructs an If. The branches are passed as Expr although IfNode stores
// them as SeqExpr; presumably they are converted through SeqExpr's implicit
// constructor. Confirm.
793  TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
794 
797 };
798 
// Returns `if_expr` with the given properties; a null property denotes
// 'no change'.
804 If WithFields(If if_expr, ffi::Optional<Expr> opt_cond = ffi::Optional<Expr>(),
805  ffi::Optional<Expr> opt_true_branch = ffi::Optional<Expr>(),
806  ffi::Optional<Expr> opt_false_branch = ffi::Optional<Expr>(),
807  ffi::Optional<Span> opt_span = ffi::Optional<Span>());
808 
// FunctionNode: a Relax function (type key "relax.expr.Function").
// NOTE: the `body` (SeqExpr) and `ret_struct_info` (StructInfo) members are
// declared on lines elided from this listing.
810 class FunctionNode : public BaseFuncNode {
811  public:
// The parameters to the function.
813  ffi::Array<Var> params;
// Whether the function is annotated as pure or not.
819  bool is_pure;
820 
// Reflects params (with SEqHashDef), body, ret_struct_info and is_pure.
821  static void RegisterReflection() {
822  namespace refl = tvm::ffi::reflection;
823  refl::ObjectDef<FunctionNode>()
824  .def_ro("params", &FunctionNode::params, refl::AttachFieldFlag::SEqHashDef())
825  .def_ro("body", &FunctionNode::body)
826  .def_ro("ret_struct_info", &FunctionNode::ret_struct_info)
827  .def_ro("is_pure", &FunctionNode::is_pure);
828  }
829 
830  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode;
832 };
833 
// Function: managed reference to FunctionNode.
834 class Function : public BaseFunc {
835  public:
// Constructs a Relax function. `ret_struct_info` is optional, and `is_pure`
// defaults to true.
857  TVM_DLL explicit Function(ffi::Array<Var> params, Expr body,
858  ffi::Optional<StructInfo> ret_struct_info, bool is_pure = true,
859  DictAttrs attrs = DictAttrs(), Span span = Span());
860 
// Mimics the constructor, but without a body Expr.
865  TVM_DLL static Function CreateEmpty(ffi::Array<Var> params, StructInfo ret_struct_info,
866  bool is_pure = true, DictAttrs attrs = DictAttrs(),
867  Span span = Span());
868 
871 };
872 
873 // TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and
874 // kPrimitive.
// Attribute keys for Relax functions.
875 namespace attr {
// Mark the function as a primitive function.
877 constexpr const char* kPrimitive = "Primitive";
// Indicate the codegen that should be used for building this function.
882 constexpr const char* kCodegen = "Codegen";
// Treat the function as a composite operator.
884 constexpr const char* kComposite = "Composite";
// Indicate the function was created by the Pattern Partitioning Pass.
886 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
// The required workspace for an external function.
888 constexpr const char* kWorkspaceSize = "WorkspaceSize";
889 
890 // Note: in the future, we prefer snake_case instead of CamelCase for attributes.
891 // Past ones will be kept for backwards compatibility.
// Override purity checking for this function and treat it as pure
// (is_pure must be set to true).
894 constexpr const char* kForcePure = "relax.force_pure";
895 
// The number of inputs of a function.
901 constexpr const char* kNumInput = "num_input";
902 } // namespace attr
903 
// ExternFuncNode: the extern function, which can represent a packed function
// (type key "relax.expr.ExternFunc").
905 class ExternFuncNode : public BaseFuncNode {
906  public:
// The name of the global symbol.
908  ffi::String global_symbol;
909 
// Exposes `global_symbol` as a read-only reflected field.
910  static void RegisterReflection() {
911  namespace refl = tvm::ffi::reflection;
912  refl::ObjectDef<ExternFuncNode>().def_ro("global_symbol", &ExternFuncNode::global_symbol);
913  }
915 };
916 
// ExternFunc: managed reference to ExternFuncNode.
917 class ExternFunc : public BaseFunc {
918  public:
// NOTE(review): neither constructor is `explicit`. The first can be called
// with a single ffi::String, so a string converts implicitly to an
// ExternFunc; confirm this is intended.
919  TVM_DLL ExternFunc(ffi::String global_symbol, Span span = Span());
920  TVM_DLL ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span = Span());
921 
924 };
925 
// Gets the shape of an Expr.
937 TVM_DLL Expr GetShapeOf(const Expr& expr);
938 
939 } // namespace relax
940 } // namespace tvm
941 
942 /* \brief Allow relax.Var as key in STL tables
943  *
944  * For most Relax expressions, it would be ambiguous whether the
945  * expression should follow reference equality or structural equality.
946  * This is not the case for variables, which do not contain nested
947  * internal structure, and are frequently used as keys in lookup
948  * tables.
949  *
950  * Providing `std::hash` and `std::equal_to` specializations for
951  * `relax::Var` allows it to be used as a key in STL tables. For
952  * `relax::Expr`, the user must specify the type of equality used
953  * (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
954  * `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
955  */
// Hashes a Var by reference (ObjectPtrHash), not by structure.
956 template <>
957 struct std::hash<tvm::relax::Var> {
958  std::size_t operator()(const tvm::relax::Var& var) const {
959  return tvm::runtime::ObjectPtrHash()(var);
960  }
961 };
962 
// Compares two Vars by reference (ObjectPtrEqual), consistent with the
// std::hash specialization for Var.
963 template <>
964 struct std::equal_to<tvm::relax::Var> {
965  bool operator()(const tvm::relax::Var& var_a, const tvm::relax::Var& var_b) const {
966  return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
967  }
968 };
969 
970 #endif // TVM_RELAX_EXPR_H_
Managed reference to BaseAttrsNode.
Definition: attrs.h:131
Base node of all functions.
Definition: function.h:139
Managed reference to BaseFuncNode.
Definition: function.h:233
Managed reference to DictAttrsNode.
Definition: attrs.h:162
Reference to PrimExprNode.
Definition: expr.h:124
Base node of all non-primitive expressions.
Definition: expr.h:416
ffi::Optional< ObjectRef > struct_info_
Stores the structure information of the expression, which encapsulates both static shape and r...
Definition: expr.h:423
Managed reference to RelaxExprNode.
Definition: expr.h:439
Definition: source_map.h:111
Definition: expr.h:662
ffi::Array< Binding > bindings
Definition: expr.h:664
Span span
Definition: expr.h:665
static void RegisterReflection()
Definition: expr.h:667
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.BindingBlock", BindingBlockNode, Object)
Definition: expr.h:679
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BindingBlock, ObjectRef, BindingBlockNode)
BindingBlockNode * CopyOnWrite()
BindingBlock(ffi::Array< Binding > bindings, Span span=Span())
The base class of a variable binding in Relax.
Definition: expr.h:568
Var var
The variable that the value is bound to.
Definition: expr.h:572
static void RegisterReflection()
Definition: expr.h:574
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Binding", BindingNode, Object)
Span span
Definition: expr.h:570
Definition: expr.h:586
const BindingNode * operator->() const
Definition: expr.h:594
Binding(ObjectPtr< BindingNode > n)
Definition: expr.h:591
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding)
const BindingNode * get() const
Definition: expr.h:595
Binding(ffi::UnsafeInit tag)
Definition: expr.h:592
Call corresponds to a callable invocation; in computational-graph terminology, it corresponds to an operation.
Definition: expr.h:141
ffi::Array< StructInfo > sinfo_args
The structure info arguments of a CallNode. sinfo_args is by default designed to be non-empty only fo...
Definition: expr.h:167
static void RegisterReflection()
Definition: expr.h:169
tvm::ffi::Array< Expr > args
The arguments (inputs) of the call.
Definition: expr.h:152
Expr op
The operator (function) being invoked.
Definition: expr.h:149
Attrs attrs
The additional attributes.
Definition: expr.h:155
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Call", CallNode, ExprNode)
Definition: expr.h:180
Call(Expr op, ffi::Array< Expr > args, Attrs attrs=Attrs(), ffi::Array< StructInfo > sinfo_args=ffi::Array< StructInfo >(), Span span=Span())
The constructor.
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, Expr, CallNode)
Constant tensor.
Definition: expr.h:426
static void RegisterReflection()
Definition: expr.h:437
bool is_scalar() const
Definition: expr.h:435
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Constant", ConstantNode, LeafExprNode)
runtime::Tensor data
The data of the tensor.
Definition: expr.h:429
TensorType tensor_type() const
Definition: expr.h:444
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Constant, LeafExpr, ConstantNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode)
Constant(runtime::Tensor data, ffi::Optional< StructInfo > struct_info_annotation=std::nullopt, Span span=Span())
The constructor.
Represent a data type constant.
Definition: expr.h:538
DataType value
The data value.
Definition: expr.h:541
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataTypeImm", DataTypeImmNode, LeafExprNode)
static void RegisterReflection()
Definition: expr.h:543
Managed reference to DataTypeImm.
Definition: expr.h:554
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypeImm, LeafExpr, DataTypeImmNode)
DataTypeImm(DataType value, Span span=Span())
The constructor.
Definition: expr.h:687
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataflowBlock", DataflowBlockNode, BindingBlockNode)
static void RegisterReflection()
Definition: expr.h:689
Definition: expr.h:697
DataflowBlock(ffi::Array< Binding > bindings, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowBlock, BindingBlock, DataflowBlockNode)
A sub-type of the variable node used to distinguish dataflow variables from normal visible "function local" ...
Definition: expr.h:397
static void RegisterReflection()
Definition: expr.h:399
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataflowVar", DataflowVarNode, VarNode)
Definition: expr.h:408
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode)
DataflowVar(Id vid, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowVar, Var, DataflowVarNode)
DataflowVar(ffi::String name_hint, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:410
The extern function, which can represent packed function.
Definition: expr.h:905
static void RegisterReflection()
Definition: expr.h:910
ffi::String global_symbol
The name of global symbol.
Definition: expr.h:908
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ExternFunc", ExternFuncNode, BaseFuncNode)
Definition: expr.h:917
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternFunc, BaseFunc, ExternFuncNode)
ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode)
ExternFunc(ffi::String global_symbol, Span span=Span())
A Relax function.
Definition: expr.h:810
static void RegisterReflection()
Definition: expr.h:821
SeqExpr body
The body of the function.
Definition: expr.h:815
ffi::Array< Var > params
The parameters to the function.
Definition: expr.h:813
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Function", FunctionNode, BaseFuncNode)
StructInfo ret_struct_info
The return structure info of the function.
Definition: expr.h:817
bool is_pure
Whether the function is annotated as pure or not.
Definition: expr.h:819
Definition: expr.h:834
static Function CreateEmpty(ffi::Array< Var > params, StructInfo ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Mimics the constructor but without body Expr.
Function(ffi::Array< Var > params, Expr body, ffi::Optional< StructInfo > ret_struct_info, bool is_pure=true, DictAttrs attrs=DictAttrs(), Span span=Span())
Construct a Relax Function.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, BaseFunc, FunctionNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode)
The unique identifier of variables.
Definition: expr.h:49
ffi::String name_hint
The name of the variable. This only acts as a hint to the user and is not used for equality.
Definition: expr.h:56
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:64
static void RegisterReflection()
Definition: expr.h:58
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.Id", IdNode, Object)
Definition: expr.h:68
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Id, ObjectRef, IdNode)
Id(ffi::String name_hint)
The constructor.
Condition expression.
Definition: expr.h:753
static void RegisterReflection()
Definition: expr.h:762
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.If", IfNode, ExprNode)
SeqExpr true_branch
The expression evaluated when condition is true.
Definition: expr.h:758
Expr cond
The condition.
Definition: expr.h:756
SeqExpr false_branch
The expression evaluated when condition is false.
Definition: expr.h:760
Definition: expr.h:774
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(If, Expr, IfNode)
If(Expr cond, Expr true_branch, Expr false_branch, Span span=Span())
The constructor.
Base type of all (non-function) leaf Exprs.
Definition: expr.h:307
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.LeafExpr", LeafExprNode, ExprNode)
Managed reference to LeafExprNode.
Definition: expr.h:317
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LeafExpr, Expr, LeafExprNode)
Runtime-match the value to the struct info.
Definition: expr.h:606
Expr value
The input value to match cast.
Definition: expr.h:609
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.MatchCast", MatchCastNode, BindingNode)
StructInfo struct_info
The struct info pattern to match to.
Definition: expr.h:611
static void RegisterReflection()
Definition: expr.h:613
Managed reference to MatchCastNode.
Definition: expr.h:626
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchCast, Binding, MatchCastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode)
MatchCast(Var var, Expr value, StructInfo struct_info, Span span=Span())
PrimValue.
Definition: expr.h:466
static void RegisterReflection()
Definition: expr.h:471
PrimExpr value
The prim expr representing the value.
Definition: expr.h:469
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.PrimValue", PrimValueNode, LeafExprNode)
Managed reference to PrimValueNode.
Definition: expr.h:482
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode)
static PrimValue Int64(int64_t value, Span span=Span())
Create an int64 prim value.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimValue, LeafExpr, PrimValueNode)
PrimValue(PrimExpr value, Span span=Span())
The constructor.
A sequence of blocks followed by an expression.
Definition: expr.h:708
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.SeqExpr", SeqExprNode, ExprNode)
Expr body
Definition: expr.h:711
static void RegisterReflection()
Definition: expr.h:713
ffi::Array< BindingBlock > blocks
Definition: expr.h:710
Definition: expr.h:722
SeqExpr(ffi::Array< BindingBlock > blocks, Expr body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SeqExpr, Expr, SeqExprNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode)
A shape expression which allows users to construct a shape containing PrimExpr.
Definition: expr.h:324
ffi::Array< PrimExpr > values
Definition: expr.h:327
static void RegisterReflection()
Definition: expr.h:329
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ShapeExpr", ShapeExprNode, LeafExprNode)
Definition: expr.h:336
ShapeExpr(ffi::Array< PrimExpr > values, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ShapeExpr, LeafExpr, ShapeExprNode)
Represent a string literal constant.
Definition: expr.h:506
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.StringImm", StringImmNode, LeafExprNode)
ffi::String value
The data value.
Definition: expr.h:509
static void RegisterReflection()
Definition: expr.h:511
Managed reference to StringImm.
Definition: expr.h:522
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, LeafExpr, StringImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
StringImm(ffi::String value, Span span=Span())
The constructor.
Base type of all structure information.
Definition: expr.h:108
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:114
static constexpr const uint32_t _type_child_slots
Definition: expr.h:124
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:122
static void RegisterReflection()
Definition: expr.h:116
TVM_FFI_DECLARE_OBJECT_INFO("ir.StructInfo", StructInfoNode, Object)
Managed reference to StructInfoNode.
Definition: expr.h:132
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StructInfo, ObjectRef, StructInfoNode)
Managed reference to TensorTypeNode.
Definition: type.h:94
Get index-th field out of a tuple.
Definition: expr.h:263
static void RegisterReflection()
Definition: expr.h:270
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.TupleGetItem", TupleGetItemNode, ExprNode)
int index
The index of the tuple field to get.
Definition: expr.h:268
Expr tuple
The tuple Expression.
Definition: expr.h:266
Definition: expr.h:279
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode)
TupleGetItem(Expr tuple, int index, Span span=Span())
The constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TupleGetItem, Expr, TupleGetItemNode)
Tuple container.
Definition: expr.h:210
static void RegisterReflection()
Definition: expr.h:215
tvm::ffi::Array< Expr > fields
The fields of the tuple.
Definition: expr.h:213
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Tuple", TupleNode, ExprNode)
Definition: expr.h:222
Tuple(tvm::ffi::Array< Expr > fields, Span span=Span())
The constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tuple, Expr, TupleNode)
Tuple(tvm::ffi::Array< RelaxExpr > fields, Span span=Span())
Utility constructor to handle conversion to relax::Expr.
Definition: expr.h:246
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode)
Definition: expr.h:634
static void RegisterReflection()
Definition: expr.h:639
Expr value
The binding value.
Definition: expr.h:637
bool SEqual(const VarBindingNode *other, ffi::TypedFunction< bool(AnyView, AnyView, bool, AnyView)> equal) const
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.VarBinding", VarBindingNode, BindingNode)
uint64_t SHash(uint64_t init_hash, ffi::TypedFunction< uint64_t(AnyView, uint64_t, bool)> hash) const
Definition: expr.h:655
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VarBinding, Binding, VarBindingNode)
VarBinding(Var var, Expr value, Span span=Span())
The variable class for all Relax bindings.
Definition: expr.h:344
uint64_t SHash(uint64_t init_hash, ffi::TypedFunction< uint64_t(AnyView, uint64_t, bool)> hash) const
Definition: expr.h:368
const ffi::String & name_hint() const
Definition: expr.h:351
Id vid
The identifier of the variable, which is used for comparing stable equality across transformations.
Definition: expr.h:348
TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Var", VarNode, LeafExprNode)
static void RegisterReflection()
Definition: expr.h:353
bool SEqual(const VarNode *other, ffi::TypedFunction< bool(AnyView, AnyView, bool, AnyView)> equal) const
Definition: expr.h:362
Definition: expr.h:381
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Var, LeafExpr, VarNode)
Var(Id vid, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
Var(ffi::String name_hint, ffi::Optional< StructInfo > struct_info_annotation, Span span=Span())
Definition: expr.h:383
VarNode * CopyOnWrite()
Runtime primitive data type.
Definition: data_type.h:47
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:53
Base expr nodes in TVM.
Function nodes.
Definition: repr_printer.h:91
constexpr const char * kForcePure
Override checking purity for this function and treat as pure (is_pure must be set to true)
Definition: expr.h:894
constexpr const char * kWorkspaceSize
The required workspace for an external function.
Definition: expr.h:888
constexpr const char * kNumInput
The number of inputs of a function. If a function has the num_input attribute, the last func->params....
Definition: expr.h:901
constexpr const char * kComposite
Treat the function as a composite operator.
Definition: expr.h:884
constexpr const char * kCodegen
Indicate the codegen that should be used for building this function. When this is unset or set to "de...
Definition: expr.h:882
constexpr const char * kPrimitive
Mark the function as a primitive function.
Definition: expr.h:877
constexpr const char * kPartitionedFromPattern
Indicate the function was created by the Pattern Partitioning Pass.
Definition: expr.h:886
Call WithFields(Call call, ffi::Optional< Expr > opt_op=ffi::Optional< Expr >(), ffi::Optional< ffi::Array< Expr >> opt_args=ffi::Optional< ffi::Array< Expr >>(), ffi::Optional< Attrs > opt_attrs=ffi::Optional< Attrs >(), ffi::Optional< ffi::Array< StructInfo >> opt_sinfo_args=ffi::Optional< ffi::Array< StructInfo >>(), ffi::Optional< Span > opt_span=ffi::Optional< Span >())
Returns call with the given properties. A null property denotes 'no change'. Returns call if all prop...
If WithFields(If if_expr, ffi::Optional< Expr > opt_cond=ffi::Optional< Expr >(), ffi::Optional< Expr > opt_true_branch=ffi::Optional< Expr >(), ffi::Optional< Expr > opt_false_branch=ffi::Optional< Expr >(), ffi::Optional< Span > opt_span=ffi::Optional< Span >())
Returns if_expr with the given properties. A null property denotes 'no change'. Returns if_expr if al...
Expr GetShapeOf(const Expr &expr)
Get the shape of Expr.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Definitions and helper macros for IR/AST nodes.
A managed object in the TVM runtime.
Relax Types.
A map from source names to source code.
TIR expressions.
Common operators defined for Expr.