tvm
struct_info.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_RELAX_STRUCT_INFO_H_
20 #define TVM_RELAX_STRUCT_INFO_H_
21 
22 #include <tvm/ffi/reflection/registry.h>
23 #include <tvm/ir/env_func.h>
24 #include <tvm/ir/source_map.h>
25 #include <tvm/node/node.h>
27 #include <tvm/relax/expr.h>
28 #include <tvm/relax/type.h>
29 
30 namespace tvm {
31 namespace relax {
32 
37  public:
38  static void RegisterReflection() {
39  namespace refl = tvm::ffi::reflection;
40  refl::ObjectDef<ObjectStructInfoNode>();
41  }
42 
43  static constexpr const char* _type_key = "relax.ObjectStructInfo";
45 };
46 
51 class ObjectStructInfo : public StructInfo {
52  public:
53  TVM_DLL ObjectStructInfo(Span span = Span());
54 
56 };
57 
62  public:
64  Optional<PrimExpr> value;
65 
68 
69  static void RegisterReflection() {
70  namespace refl = tvm::ffi::reflection;
71  refl::ObjectDef<PrimStructInfoNode>()
72  .def_ro("value", &PrimStructInfoNode::value)
73  .def_ro("dtype", &PrimStructInfoNode::dtype);
74  }
75 
76  static constexpr const char* _type_key = "relax.PrimStructInfo";
78 };
79 
84 class PrimStructInfo : public StructInfo {
85  public:
86  /* Construct a PrimStructInfo with a known dtype, but unknown value */
87  TVM_DLL PrimStructInfo(DataType dtype, Span span = Span());
88 
89  /* Construct a PrimStructInfo with a known value */
90  TVM_DLL PrimStructInfo(PrimExpr value, Span span = Span());
91 
93 };
94 
99  public:
101  Optional<Array<PrimExpr>> values;
106  int ndim;
107 
109  bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
110 
111  static void RegisterReflection() {
112  namespace refl = tvm::ffi::reflection;
113  refl::ObjectDef<ShapeStructInfoNode>()
114  .def_ro("values", &ShapeStructInfoNode::values)
115  .def_ro("ndim", &ShapeStructInfoNode::ndim);
116  }
117 
118  static constexpr const char* _type_key = "relax.ShapeStructInfo";
120 };
121 
126 class ShapeStructInfo : public StructInfo {
127  public:
133  TVM_DLL ShapeStructInfo(Array<PrimExpr> values, Span span = Span());
139  TVM_DLL ShapeStructInfo(int ndim, Span span = Span());
140 
142 };
143 
148  public:
153  Optional<Expr> shape;
157  Optional<VDevice> vdevice;
164  int ndim;
165 
167  bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
168 
170  bool IsUnknownDtype() const { return dtype.is_void(); }
171 
173  Optional<Array<PrimExpr>> GetShape() const {
174  if (!shape.defined()) return {};
175  ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(this->shape.value()->struct_info_);
176  return shape_sinfo->values;
177  }
178 
179  static void RegisterReflection() {
180  namespace refl = tvm::ffi::reflection;
181  refl::ObjectDef<TensorStructInfoNode>()
182  .def_ro("shape", &TensorStructInfoNode::shape)
183  .def_ro("dtype", &TensorStructInfoNode::dtype)
184  .def_ro("vdevice", &TensorStructInfoNode::vdevice)
185  .def_ro("ndim", &TensorStructInfoNode::ndim);
186  }
187 
188  static constexpr const char* _type_key = "relax.TensorStructInfo";
190 };
191 
196 class TensorStructInfo : public StructInfo {
197  public:
207  TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Optional<VDevice> vdevice = std::nullopt,
208  Span span = Span());
209 
217  TVM_DLL TensorStructInfo(DataType dtype, int ndim, Optional<VDevice> vdevice = std::nullopt,
218  Span span = Span());
219 
221 };
222 
227  public:
229  Array<StructInfo> fields;
230 
231  static void RegisterReflection() {
232  namespace refl = tvm::ffi::reflection;
233  refl::ObjectDef<TupleStructInfoNode>().def_ro("fields", &TupleStructInfoNode::fields);
234  }
235 
236  static constexpr const char* _type_key = "relax.TupleStructInfo";
238 };
239 
244 class TupleStructInfo : public StructInfo {
245  public:
251  TVM_DLL TupleStructInfo(Array<StructInfo> fields, Span span = Span());
252 
254 };
255 
263 
271  public:
277  Optional<Array<StructInfo>> params;
287  Optional<StructInfoDeriveFunc> derive_func;
293  bool purity;
294 
299  bool IsOpaque() const { return !params.defined(); }
300 
301  static void RegisterReflection() {
302  namespace refl = tvm::ffi::reflection;
303  refl::ObjectDef<FuncStructInfoNode>()
304  .def_ro("params", &FuncStructInfoNode::params, refl::AttachFieldFlag::SEqHashDef())
305  .def_ro("ret", &FuncStructInfoNode::ret)
306  .def_ro("derive_func", &FuncStructInfoNode::derive_func)
307  .def_ro("purity", &FuncStructInfoNode::purity);
308  }
309 
310  static constexpr const char* _type_key = "relax.FuncStructInfo";
312 };
313 
318 class FuncStructInfo : public StructInfo {
319  public:
330  TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity = true,
331  Span span = Span());
332 
344  TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false,
345  Span span = Span());
346 
358  TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false,
359  Span span = Span());
360 
362 };
363 
371 template <typename T>
372 inline Optional<T> MatchStructInfo(const Expr& expr) {
373  using TNode = typename T::ContainerType;
374  if (const TNode* ptr = expr->struct_info_.as<TNode>()) {
375  return GetRef<T>(ptr);
376  } else {
377  return std::nullopt;
378  }
379 }
380 
388 template <typename T>
389 inline const T* GetStructInfoAs(const Expr& expr) {
390  ICHECK(expr->struct_info_.defined())
391  << "The struct_info is not populated, check if you have normalized the expr";
392  return expr->struct_info_.as<T>();
393 }
394 
401 inline StructInfo GetStructInfo(const Expr& expr) {
402  auto* ptr = expr->struct_info_.as<StructInfoNode>();
403  ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr";
404  return GetRef<StructInfo>(ptr);
405 }
406 
413 inline bool HasVoidStructInfo(const Expr& expr) {
414  auto* ptr = expr->struct_info_.as<TupleStructInfoNode>();
415  return ptr != nullptr && ptr->fields.size() == 0;
416 }
417 
425 TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info);
426 
427 } // namespace relax
428 } // namespace tvm
429 #endif // TVM_RELAX_STRUCT_INFO_H_
The utility for constructing Relax binding blocks.
Reference to PrimExprNode.
Definition: expr.h:129
Managed reference to RelaxExprNode.
Definition: expr.h:446
Definition: source_map.h:113
Please refer to TypedEnvFunc<R(Args..)>.
Definition: env_func.h:102
Definition: block_builder.h:264
Definition: expr.h:181
Structure information about function.
Definition: struct_info.h:270
TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode)
bool IsOpaque() const
Definition: struct_info.h:299
static constexpr const char * _type_key
Definition: struct_info.h:310
StructInfo ret
The struct info of the function's return value.
Definition: struct_info.h:281
Optional< StructInfoDeriveFunc > derive_func
Derivation function of opaque functions that may take any number of parameters.
Definition: struct_info.h:287
bool purity
Whether the function is pure.
Definition: struct_info.h:293
static void RegisterReflection()
Definition: struct_info.h:301
Optional< Array< StructInfo > > params
The parameter struct info of the function.
Definition: struct_info.h:277
Managed reference to FuncStructInfoNode.
Definition: struct_info.h:318
FuncStructInfo(Array< StructInfo > params, StructInfo ret, bool purity=true, Span span=Span())
Constructor from parameter struct info and return value struct info.
static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity=false, Span span=Span())
Construct an opaque function struct info using derive_func.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode)
static FuncStructInfo OpaqueFunc(StructInfo ret=ObjectStructInfo(), bool purity=false, Span span=Span())
Construct an opaque function struct info from the return struct info.
Opaque object.
Definition: struct_info.h:36
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode)
static constexpr const char * _type_key
Definition: struct_info.h:43
static void RegisterReflection()
Definition: struct_info.h:38
Managed reference to ObjectStructInfoNode.
Definition: struct_info.h:51
ObjectStructInfo(Span span=Span())
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode)
Primitive value.
Definition: struct_info.h:61
Optional< PrimExpr > value
Underlying primitive value, if known.
Definition: struct_info.h:64
static constexpr const char * _type_key
Definition: struct_info.h:76
DataType dtype
Underlying data type of the primitive value.
Definition: struct_info.h:67
TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode)
static void RegisterReflection()
Definition: struct_info.h:69
Managed reference to PrimStructInfoNode.
Definition: struct_info.h:84
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode)
PrimStructInfo(DataType dtype, Span span=Span())
PrimStructInfo(PrimExpr value, Span span=Span())
StructInfo of shape value.
Definition: struct_info.h:98
static void RegisterReflection()
Definition: struct_info.h:111
Optional< Array< PrimExpr > > values
Optionally stores the symbolic value patterns of the shape.
Definition: struct_info.h:101
bool IsUnknownNdim() const
Definition: struct_info.h:109
int ndim
The number of dimensions of the shape; may be unknown.
Definition: struct_info.h:106
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode)
static constexpr const char * _type_key
Definition: struct_info.h:118
Managed reference to ShapeStructInfoNode.
Definition: struct_info.h:126
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode)
ShapeStructInfo(Array< PrimExpr > values, Span span=Span())
Construction with known symbolic shape patterns.
ShapeStructInfo(int ndim, Span span=Span())
Construction with a known number of dimensions but unknown symbolic shape values.
Base type of all structure information.
Definition: expr.h:110
Managed reference to StructInfoNode.
Definition: expr.h:135
StructInfo of Tensor.
Definition: struct_info.h:147
Optional< VDevice > vdevice
The virtual device, indicating where the tensor is expected to be executed.
Definition: struct_info.h:157
TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode)
Optional< Array< PrimExpr > > GetShape() const
Definition: struct_info.h:173
DataType dtype
The content data type; void denotes that the dtype is unknown.
Definition: struct_info.h:159
static constexpr const char * _type_key
Definition: struct_info.h:188
static void RegisterReflection()
Definition: struct_info.h:179
int ndim
The number of dimensions of the tensor; may be unknown.
Definition: struct_info.h:164
bool IsUnknownDtype() const
Definition: struct_info.h:170
Optional< Expr > shape
Optionally stores the shape expression of the tensor.
Definition: struct_info.h:153
bool IsUnknownNdim() const
Definition: struct_info.h:167
Managed reference to TensorStructInfoNode.
Definition: struct_info.h:196
TensorStructInfo(Expr shape, DataType dtype, Optional< VDevice > vdevice=std::nullopt, Span span=Span())
Construction with a known shape expression.
TensorStructInfo(DataType dtype, int ndim, Optional< VDevice > vdevice=std::nullopt, Span span=Span())
Construction with an unknown shape expression.
TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode)
StructInfo of Tuple.
Definition: struct_info.h:226
static constexpr const char * _type_key
Definition: struct_info.h:236
static void RegisterReflection()
Definition: struct_info.h:231
TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode)
Array< StructInfo > fields
The struct info of tuple fields.
Definition: struct_info.h:229
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:244
TupleStructInfo(Array< StructInfo > fields, Span span=Span())
Constructor.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode)
Runtime primitive data type.
Definition: data_type.h:47
bool is_void() const
Definition: data_type.h:209
Serializable global function used in IR.
Definition: repr_printer.h:91
const T * GetStructInfoAs(const Expr &expr)
Get the structure info of a given expr and try to cast it as const T*.
Definition: struct_info.h:389
Optional< T > MatchStructInfo(const Expr &expr)
Match and check whether expr has StructInfo T, and return it if so.
Definition: struct_info.h:372
void UpdateStructInfo(Expr expr, StructInfo struct_info)
Update the struct info of an Expr.
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:401
bool HasVoidStructInfo(const Expr &expr)
Whether the expr has void struct info.
Definition: struct_info.h:413
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1945
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Definitions and helper macros for IR/AST nodes.
Relax Types.
A map from source names to source code.