tvm
struct_info.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_RELAX_STRUCT_INFO_H_
20 #define TVM_RELAX_STRUCT_INFO_H_
21 
22 #include <tvm/ffi/reflection/registry.h>
23 #include <tvm/ir/env_func.h>
24 #include <tvm/ir/source_map.h>
25 #include <tvm/node/cast.h>
27 #include <tvm/relax/expr.h>
28 #include <tvm/relax/type.h>
29 #include <tvm/runtime/object.h>
30 
31 #include <utility>
32 
33 namespace tvm {
34 namespace relax {
35 
40  public:
41  static void RegisterReflection() {
42  namespace refl = tvm::ffi::reflection;
43  refl::ObjectDef<ObjectStructInfoNode>();
44  }
46 };
47 
52 class ObjectStructInfo : public StructInfo {
53  public:
54  TVM_DLL ObjectStructInfo(Span span = Span());
55 
57 };
58 
63  public:
65  ffi::Optional<PrimExpr> value;
66 
69 
70  static void RegisterReflection() {
71  namespace refl = tvm::ffi::reflection;
72  refl::ObjectDef<PrimStructInfoNode>()
73  .def_ro("value", &PrimStructInfoNode::value)
74  .def_ro("dtype", &PrimStructInfoNode::dtype);
75  }
77 };
78 
83 class PrimStructInfo : public StructInfo {
84  public:
85  /* Construct a PrimStructInfo with a known dtype, but unknown value */
86  TVM_DLL PrimStructInfo(DataType dtype, Span span = Span());
87 
88  /* Construct a PrimStructInfo with a known value */
89  TVM_DLL PrimStructInfo(PrimExpr value, Span span = Span());
90 
92 };
93 
98  public:
100  ffi::Optional<ffi::Array<PrimExpr>> values;
105  int ndim;
106 
108  bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
109 
110  static void RegisterReflection() {
111  namespace refl = tvm::ffi::reflection;
112  refl::ObjectDef<ShapeStructInfoNode>()
113  .def_ro("values", &ShapeStructInfoNode::values)
114  .def_ro("ndim", &ShapeStructInfoNode::ndim);
115  }
117 };
118 
123 class ShapeStructInfo : public StructInfo {
124  public:
130  TVM_DLL ShapeStructInfo(ffi::Array<PrimExpr> values, Span span = Span());
136  TVM_DLL ShapeStructInfo(int ndim, Span span = Span());
137 
139 };
140 
145  public:
150  ffi::Optional<Expr> shape;
154  ffi::Optional<VDevice> vdevice;
161  int ndim;
162 
164  bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
165 
167  bool IsUnknownDtype() const { return dtype.is_void(); }
168 
170  ffi::Optional<ffi::Array<PrimExpr>> GetShape() const {
171  if (!shape.defined()) return {};
172  ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(this->shape.value()->struct_info_);
173  return shape_sinfo->values;
174  }
175 
176  static void RegisterReflection() {
177  namespace refl = tvm::ffi::reflection;
178  refl::ObjectDef<TensorStructInfoNode>()
179  .def_ro("shape", &TensorStructInfoNode::shape)
180  .def_ro("dtype", &TensorStructInfoNode::dtype)
181  .def_ro("vdevice", &TensorStructInfoNode::vdevice)
182  .def_ro("ndim", &TensorStructInfoNode::ndim);
183  }
185 };
186 
191 class TensorStructInfo : public StructInfo {
192  public:
203  ffi::Optional<VDevice> vdevice = std::nullopt, Span span = Span());
204 
212  TVM_DLL TensorStructInfo(DataType dtype, int ndim, ffi::Optional<VDevice> vdevice = std::nullopt,
213  Span span = Span());
214 
216 };
217 
222  public:
224  ffi::Array<StructInfo> fields;
225 
226  static void RegisterReflection() {
227  namespace refl = tvm::ffi::reflection;
228  refl::ObjectDef<TupleStructInfoNode>().def_ro("fields", &TupleStructInfoNode::fields);
229  }
231 };
232 
237 class TupleStructInfo : public StructInfo {
238  public:
244  TVM_DLL TupleStructInfo(ffi::Array<StructInfo> fields, Span span = Span());
245 
247 };
248 
256 
264  public:
270  ffi::Optional<ffi::Array<StructInfo>> params;
280  ffi::Optional<StructInfoDeriveFunc> derive_func;
286  bool purity;
287 
292  bool IsOpaque() const { return !params.defined(); }
293 
294  static void RegisterReflection() {
295  namespace refl = tvm::ffi::reflection;
296  refl::ObjectDef<FuncStructInfoNode>()
297  .def_ro("params", &FuncStructInfoNode::params, refl::AttachFieldFlag::SEqHashDef())
298  .def_ro("ret", &FuncStructInfoNode::ret)
299  .def_ro("derive_func", &FuncStructInfoNode::derive_func)
300  .def_ro("purity", &FuncStructInfoNode::purity);
301  }
303 };
304 
309 class FuncStructInfo : public StructInfo {
310  public:
311  explicit FuncStructInfo(ObjectPtr<FuncStructInfoNode> data) : StructInfo(ffi::UnsafeInit{}) {
312  TVM_FFI_ICHECK(data != nullptr);
313  data_ = std::move(data);
314  }
325  TVM_DLL FuncStructInfo(ffi::Array<StructInfo> params, StructInfo ret, bool purity = true,
326  Span span = Span());
327 
339  TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false,
340  Span span = Span());
341 
353  TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false,
354  Span span = Span());
355 
357 };
358 
366 template <typename T>
367 inline ffi::Optional<T> MatchStructInfo(const Expr& expr) {
368  using TNode = typename T::ContainerType;
369  if (const TNode* ptr = expr->struct_info_.as<TNode>()) {
370  return ffi::GetRef<T>(ptr);
371  } else {
372  return std::nullopt;
373  }
374 }
375 
383 template <typename T>
384 inline const T* GetStructInfoAs(const Expr& expr) {
385  TVM_FFI_ICHECK(expr->struct_info_.defined())
386  << "The struct_info is not populated, check if you have normalized the expr";
387  return expr->struct_info_.as<T>();
388 }
389 
396 inline StructInfo GetStructInfo(const Expr& expr) {
397  auto* ptr = expr->struct_info_.as<StructInfoNode>();
398  TVM_FFI_ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr";
399  return ffi::GetRef<StructInfo>(ptr);
400 }
401 
408 inline bool HasVoidStructInfo(const Expr& expr) {
409  auto* ptr = expr->struct_info_.as<TupleStructInfoNode>();
410  return ptr != nullptr && ptr->fields.size() == 0;
411 }
412 
420 TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info);
421 
422 } // namespace relax
423 } // namespace tvm
424 #endif // TVM_RELAX_STRUCT_INFO_H_
The utility for constructing Relax binding blocks.
Value casting helpers.
Reference to PrimExprNode.
Definition: expr.h:126
Managed reference to RelaxExprNode.
Definition: expr.h:441
Definition: source_map.h:111
Please refer to TypedEnvFunc<R(Args..)>.
Definition: env_func.h:104
Definition: block_builder.h:264
Definition: expr.h:180
Structure information about function.
Definition: struct_info.h:263
bool IsOpaque() const
Definition: struct_info.h:292
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.FuncStructInfo", FuncStructInfoNode, StructInfoNode)
ffi::Optional< ffi::Array< StructInfo > > params
The parameter struct info of the function.
Definition: struct_info.h:270
StructInfo ret
The struct info of the function's return value.
Definition: struct_info.h:274
bool purity
Whether the function is pure.
Definition: struct_info.h:286
static void RegisterReflection()
Definition: struct_info.h:294
ffi::Optional< StructInfoDeriveFunc > derive_func
Derivation function of opaque functions that may take any number of parameters.
Definition: struct_info.h:280
Managed reference to FuncStructInfoNode.
Definition: struct_info.h:309
FuncStructInfo(ObjectPtr< FuncStructInfoNode > data)
Definition: struct_info.h:311
FuncStructInfo(ffi::Array< StructInfo > params, StructInfo ret, bool purity=true, Span span=Span())
Constructor from parameter struct info and return value struct info.
static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity=false, Span span=Span())
Constructing an opaque function struct info using derive_func.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FuncStructInfo, StructInfo, FuncStructInfoNode)
static FuncStructInfo OpaqueFunc(StructInfo ret=ObjectStructInfo(), bool purity=false, Span span=Span())
Construct an opaque function from the return struct info.
Opaque object.
Definition: struct_info.h:39
static void RegisterReflection()
Definition: struct_info.h:41
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ObjectStructInfo", ObjectStructInfoNode, StructInfoNode)
Managed reference to ObjectStructInfoNode.
Definition: struct_info.h:52
ObjectStructInfo(Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ObjectStructInfo, StructInfo, ObjectStructInfoNode)
Primitive value.
Definition: struct_info.h:62
ffi::Optional< PrimExpr > value
Underlying primitive value, if known.
Definition: struct_info.h:65
DataType dtype
Underlying data type of the primitive value.
Definition: struct_info.h:68
static void RegisterReflection()
Definition: struct_info.h:70
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PrimStructInfo", PrimStructInfoNode, StructInfoNode)
Managed reference to PrimStructInfoNode.
Definition: struct_info.h:83
PrimStructInfo(DataType dtype, Span span=Span())
PrimStructInfo(PrimExpr value, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimStructInfo, StructInfo, PrimStructInfoNode)
StructInfo of shape value.
Definition: struct_info.h:97
static void RegisterReflection()
Definition: struct_info.h:110
bool IsUnknownNdim() const
Definition: struct_info.h:108
int ndim
The number of dimensions of the shape; may be unknown.
Definition: struct_info.h:105
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ShapeStructInfo", ShapeStructInfoNode, StructInfoNode)
ffi::Optional< ffi::Array< PrimExpr > > values
Optionally stores the symbolic value patterns of the shape.
Definition: struct_info.h:100
Managed reference to ShapeStructInfoNode.
Definition: struct_info.h:123
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ShapeStructInfo, StructInfo, ShapeStructInfoNode)
ShapeStructInfo(int ndim, Span span=Span())
Construction with unknown symbolic shape patterns (only the number of dimensions is known).
ShapeStructInfo(ffi::Array< PrimExpr > values, Span span=Span())
Construction with known symbolic shape patterns.
Base type of all structure information.
Definition: expr.h:108
Managed reference to StructInfoNode.
Definition: expr.h:132
StructInfo of Tensor.
Definition: struct_info.h:144
ffi::Optional< ffi::Array< PrimExpr > > GetShape() const
Definition: struct_info.h:170
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TensorStructInfo", TensorStructInfoNode, StructInfoNode)
DataType dtype
The content data type; void denotes that the dtype is unknown.
Definition: struct_info.h:156
ffi::Optional< VDevice > vdevice
The virtual device, which indicates where the tensor is expected to be executed.
Definition: struct_info.h:154
ffi::Optional< Expr > shape
Optionally stores the shape expression of the tensor.
Definition: struct_info.h:150
static void RegisterReflection()
Definition: struct_info.h:176
int ndim
The number of dimensions of the tensor; may be unknown.
Definition: struct_info.h:161
bool IsUnknownDtype() const
Definition: struct_info.h:167
bool IsUnknownNdim() const
Definition: struct_info.h:164
Managed reference to TensorStructInfoNode.
Definition: struct_info.h:191
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorStructInfo, StructInfo, TensorStructInfoNode)
TensorStructInfo(Expr shape, DataType dtype, ffi::Optional< VDevice > vdevice=std::nullopt, Span span=Span())
Construction with a known shape expression.
TensorStructInfo(DataType dtype, int ndim, ffi::Optional< VDevice > vdevice=std::nullopt, Span span=Span())
Construction with an unknown shape expression.
StructInfo of Tuple.
Definition: struct_info.h:221
ffi::Array< StructInfo > fields
The struct info of tuple fields.
Definition: struct_info.h:224
static void RegisterReflection()
Definition: struct_info.h:226
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TupleStructInfo", TupleStructInfoNode, StructInfoNode)
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:237
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TupleStructInfo, StructInfo, TupleStructInfoNode)
TupleStructInfo(ffi::Array< StructInfo > fields, Span span=Span())
Constructor.
Runtime primitive data type.
Definition: data_type.h:47
bool is_void() const
Definition: data_type.h:213
Serializable global function used in IR.
Definition: repr_printer.h:91
const T * GetStructInfoAs(const Expr &expr)
Get the structure info of a given expr and try to cast it as const T*.
Definition: struct_info.h:384
void UpdateStructInfo(Expr expr, StructInfo struct_info)
Update the struct info of an Expr.
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:396
bool HasVoidStructInfo(const Expr &expr)
Whether the expr has void struct info.
Definition: struct_info.h:408
ffi::Optional< T > MatchStructInfo(const Expr &expr)
Match and check whether expr has StructInfo T, and return it if so.
Definition: struct_info.h:367
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
A managed object in the TVM runtime.
Relax Types.
A map from source names to source code.