tvm
struct_info.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_RELAX_STRUCT_INFO_H_
20 #define TVM_RELAX_STRUCT_INFO_H_
21 
22 #include <tvm/ir/env_func.h>
23 #include <tvm/ir/source_map.h>
24 #include <tvm/node/node.h>
26 #include <tvm/relax/expr.h>
27 #include <tvm/relax/type.h>
28 
29 namespace tvm {
30 namespace relax {
31 
36  public:
37  void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
38 
39  bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; }
40 
41  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }
42 
43  static constexpr const char* _type_key = "relax.ObjectStructInfo";
45 };
46 
51 class ObjectStructInfo : public StructInfo {
52  public:
53  TVM_DLL ObjectStructInfo(Span span = Span());
54 
56 };
57 
62  public:
65 
68 
70  v->Visit("value", &value);
71  v->Visit("dtype", &dtype);
72  v->Visit("span", &span);
73  }
74 
75  bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const {
76  return equal(value, other->value) && equal(dtype, other->dtype);
77  }
78 
79  void SHashReduce(SHashReducer hash_reduce) const {
80  hash_reduce(value);
81  hash_reduce(dtype);
82  }
83 
84  static constexpr const char* _type_key = "relax.PrimStructInfo";
86 };
87 
92 class PrimStructInfo : public StructInfo {
93  public:
94  /* Construct a PrimStructInfo with a known dtype, but unknown value */
95  TVM_DLL PrimStructInfo(DataType dtype, Span span = Span());
96 
97  /* Construct a PrimStructInfo with a known value */
98  TVM_DLL PrimStructInfo(PrimExpr value, Span span = Span());
99 
101 };
102 
107  public:
114  int ndim;
115 
117  bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
118 
120  v->Visit("values", &values);
121  v->Visit("ndim", &ndim);
122  v->Visit("span", &span);
123  }
124 
126  return equal(values, other->values) && equal(ndim, other->ndim);
127  }
128 
129  void SHashReduce(SHashReducer hash_reduce) const {
130  hash_reduce(values);
131  hash_reduce(ndim);
132  }
133 
134  static constexpr const char* _type_key = "relax.ShapeStructInfo";
136 };
137 
142 class ShapeStructInfo : public StructInfo {
143  public:
149  TVM_DLL ShapeStructInfo(Array<PrimExpr> values, Span span = Span());
155  TVM_DLL ShapeStructInfo(int ndim, Span span = Span());
156 
158 };
159 
164  public:
180  int ndim;
181 
183  bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
184 
186  bool IsUnknownDtype() const { return dtype.is_void(); }
187 
190  if (!shape.defined()) return {};
191  ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(this->shape.value()->struct_info_);
192  return shape_sinfo->values;
193  }
194 
196  v->Visit("shape", &shape);
197  v->Visit("dtype", &dtype);
198  v->Visit("vdevice", &vdevice);
199  v->Visit("ndim", &ndim);
200  v->Visit("span", &span);
201  }
202 
204  return equal(shape, other->shape) && equal(ndim, other->ndim) &&
205  equal(vdevice, other->vdevice) && equal(dtype, other->dtype);
206  }
207 
208  void SHashReduce(SHashReducer hash_reduce) const {
209  hash_reduce(shape);
210  hash_reduce(dtype);
211  hash_reduce(vdevice);
212  hash_reduce(ndim);
213  }
214 
215  static constexpr const char* _type_key = "relax.TensorStructInfo";
217 };
218 
223 class TensorStructInfo : public StructInfo {
224  public:
235  Span span = Span());
236 
244  TVM_DLL TensorStructInfo(DataType dtype, int ndim, Optional<VDevice> vdevice = NullOpt,
245  Span span = Span());
246 
248 };
249 
254  public:
257 
259  v->Visit("fields", &fields);
260  v->Visit("span", &span);
261  }
262 
264  return equal(fields, other->fields);
265  }
266 
267  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); }
268 
269  static constexpr const char* _type_key = "relax.TupleStructInfo";
271 };
272 
277 class TupleStructInfo : public StructInfo {
278  public:
284  TVM_DLL TupleStructInfo(Array<StructInfo> fields, Span span = Span());
285 
287 };
288 
296 
304  public:
326  bool purity;
327 
332  bool IsOpaque() const { return !params.defined(); }
333 
335  v->Visit("params", &params);
336  v->Visit("ret", &ret);
337  v->Visit("derive_func", &derive_func);
338  v->Visit("span", &span);
339  v->Visit("purity", &purity);
340  }
341 
343  return equal.DefEqual(params, other->params) && equal(ret, other->ret) &&
344  equal(purity, other->purity) && equal(derive_func, other->derive_func);
345  }
346 
347  void SHashReduce(SHashReducer hash_reduce) const {
348  hash_reduce.DefHash(params);
349  hash_reduce(ret);
350  hash_reduce(purity);
351  hash_reduce(derive_func);
352  }
353 
354  static constexpr const char* _type_key = "relax.FuncStructInfo";
356 };
357 
362 class FuncStructInfo : public StructInfo {
363  public:
374  TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity = true,
375  Span span = Span());
376 
388  TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false,
389  Span span = Span());
390 
402  TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false,
403  Span span = Span());
404 
406 };
407 
415 template <typename T>
416 inline Optional<T> MatchStructInfo(const Expr& expr) {
417  using TNode = typename T::ContainerType;
418  if (const TNode* ptr = expr->struct_info_.as<TNode>()) {
419  return GetRef<T>(ptr);
420  } else {
421  return NullOpt;
422  }
423 }
424 
432 template <typename T>
433 inline const T* GetStructInfoAs(const Expr& expr) {
434  ICHECK(expr->struct_info_.defined())
435  << "The struct_info is not populated, check if you have normalized the expr";
436  return expr->struct_info_.as<T>();
437 }
438 
445 inline StructInfo GetStructInfo(const Expr& expr) {
446  auto* ptr = expr->struct_info_.as<StructInfoNode>();
447  ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr";
448  return GetRef<StructInfo>(ptr);
449 }
450 
457 inline bool HasVoidStructInfo(const Expr& expr) {
458  auto* ptr = expr->struct_info_.as<TupleStructInfoNode>();
459  return ptr != nullptr && ptr->fields.size() == 0;
460 }
461 
469 TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info);
470 
471 } // namespace relax
472 } // namespace tvm
473 #endif // TVM_RELAX_STRUCT_INFO_H_
The utility for constructing Relax binding blocks.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:115
Managed reference to RelayExprNode.
Definition: expr.h:442
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
Please refer to TypedEnvFunc<R(Args..)>.
Definition: env_func.h:104
Definition: block_builder.h:264
Definition: expr.h:190
Structure information about function.
Definition: struct_info.h:303
TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode)
bool IsOpaque() const
Definition: struct_info.h:332
static constexpr const char * _type_key
Definition: struct_info.h:354
StructInfo ret
The struct info of the function's return value.
Definition: struct_info.h:314
Optional< StructInfoDeriveFunc > derive_func
Derivation function of opaque functions that may take any number of parameters.
Definition: struct_info.h:320
void VisitAttrs(AttrVisitor *v)
Definition: struct_info.h:334
bool purity
Whether the function is pure.
Definition: struct_info.h:326
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:347
bool SEqualReduce(const FuncStructInfoNode *other, SEqualReducer equal) const
Definition: struct_info.h:342
Optional< Array< StructInfo > > params
The parameter struct info of the function.
Definition: struct_info.h:310
Managed reference to FuncStructInfoNode.
Definition: struct_info.h:362
FuncStructInfo(Array< StructInfo > params, StructInfo ret, bool purity=true, Span span=Span())
Constructor from parameter struct info and return value struct info.
static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity=false, Span span=Span())
Constructing an opaque function struct info using derive_func.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode)
static FuncStructInfo OpaqueFunc(StructInfo ret=ObjectStructInfo(), bool purity=false, Span span=Span())
Construct an opaque function from the return struct info.
Opaque object.
Definition: struct_info.h:35
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:41
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode)
static constexpr const char * _type_key
Definition: struct_info.h:43
void VisitAttrs(AttrVisitor *v)
Definition: struct_info.h:37
bool SEqualReduce(const ObjectStructInfoNode *other, SEqualReducer equal) const
Definition: struct_info.h:39
Managed reference to ObjectStructInfoNode.
Definition: struct_info.h:51
ObjectStructInfo(Span span=Span())
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode)
Primitive value.
Definition: struct_info.h:61
Optional< PrimExpr > value
Underlying primitive value, if known.
Definition: struct_info.h:64
void VisitAttrs(AttrVisitor *v)
Definition: struct_info.h:69
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:79
bool SEqualReduce(const PrimStructInfoNode *other, SEqualReducer equal) const
Definition: struct_info.h:75
static constexpr const char * _type_key
Definition: struct_info.h:84
DataType dtype
Underlying data type of the primitive value.
Definition: struct_info.h:67
TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode)
Managed reference to PrimStructInfoNode.
Definition: struct_info.h:92
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode)
PrimStructInfo(DataType dtype, Span span=Span())
PrimStructInfo(PrimExpr value, Span span=Span())
StructInfo of shape value.
Definition: struct_info.h:106
Optional< Array< PrimExpr > > values
Optionally stores the symbolic value patterns of the shape.
Definition: struct_info.h:109
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:129
void VisitAttrs(AttrVisitor *v)
Definition: struct_info.h:119
bool IsUnknownNdim() const
Definition: struct_info.h:117
int ndim
The number of dimensions of the shape; can be unknown.
Definition: struct_info.h:114
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode)
static constexpr const char * _type_key
Definition: struct_info.h:134
bool SEqualReduce(const ShapeStructInfoNode *other, SEqualReducer equal) const
Definition: struct_info.h:125
Managed reference to ShapeStructInfoNode.
Definition: struct_info.h:142
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode)
ShapeStructInfo(Array< PrimExpr > values, Span span=Span())
Construction with known symbolic shape patterns.
ShapeStructInfo(int ndim, Span span=Span())
Construction with unknown symbolic shape patterns; only the number of dimensions (which may itself be unknown) is given.
Base type of all structure information.
Definition: expr.h:110
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:116
Managed reference to StructInfoNode.
Definition: expr.h:129
StructInfo of Tensor.
Definition: struct_info.h:163
void VisitAttrs(AttrVisitor *v)
Definition: struct_info.h:195
Optional< VDevice > vdevice
The virtual device, indicates where the tensor is expected to be executed.
Definition: struct_info.h:173
TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode)
Optional< Array< PrimExpr > > GetShape() const
Definition: struct_info.h:189
DataType dtype
The content data type, use void to denote the dtype is unknown.
Definition: struct_info.h:175
static constexpr const char * _type_key
Definition: struct_info.h:215
bool SEqualReduce(const TensorStructInfoNode *other, SEqualReducer equal) const
Definition: struct_info.h:203
int ndim
The number of dimensions of the tensor; can be unknown.
Definition: struct_info.h:180
bool IsUnknownDtype() const
Definition: struct_info.h:186
Optional< Expr > shape
Optionally stores the shape expression of the tensor.
Definition: struct_info.h:169
bool IsUnknownNdim() const
Definition: struct_info.h:183
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:208
Managed reference to TensorStructInfoNode.
Definition: struct_info.h:223
TensorStructInfo(DataType dtype, int ndim, Optional< VDevice > vdevice=NullOpt, Span span=Span())
Construction with an unknown shape expression.
TensorStructInfo(Expr shape, DataType dtype, Optional< VDevice > vdevice=NullOpt, Span span=Span())
Construction with a known shape expression.
TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode)
StructInfo of Tuple.
Definition: struct_info.h:253
static constexpr const char * _type_key
Definition: struct_info.h:269
TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode)
void VisitAttrs(AttrVisitor *v)
Definition: struct_info.h:258
Array< StructInfo > fields
The struct info of tuple fields.
Definition: struct_info.h:256
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:267
bool SEqualReduce(const TupleStructInfoNode *other, SEqualReducer equal) const
Definition: struct_info.h:263
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:277
TupleStructInfo(Array< StructInfo > fields, Span span=Span())
Constructor.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode)
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
bool is_void() const
Definition: data_type.h:156
bool defined() const
Definition: object.h:552
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
Optional container that represents a nullable variant of T.
Definition: optional.h:51
T value() const
Definition: optional.h:92
Serializable global function used in IR.
const T * GetStructInfoAs(const Expr &expr)
Get the structure info of a given expr and try to cast it as const T*.
Definition: struct_info.h:433
Optional< T > MatchStructInfo(const Expr &expr)
Match and check if expr has StructInfo T and return it.
Definition: struct_info.h:416
void UpdateStructInfo(Expr expr, StructInfo struct_info)
Update the struct info of an Expr.
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:445
bool HasVoidStructInfo(const Expr &expr)
Whether the expr has void struct info.
Definition: struct_info.h:457
tvm::Span Span
Definition: base.h:65
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
Definitions and helper macros for IR/AST nodes.
Relax Types.
A map from source names to source code.