tvm
type.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_TYPE_H_
25 #define TVM_RELAX_TYPE_H_
26 
27 #include <tvm/ir/attrs.h>
28 #include <tvm/ir/env_func.h>
29 #include <tvm/ir/tensor_type.h>
30 #include <tvm/ir/type.h>
31 #include <tvm/ir/type_relation.h>
32 #include <tvm/runtime/registry.h>
33 #include <tvm/tir/expr.h>
34 
35 #include <string>
36 
37 namespace tvm {
38 namespace relax {
39 
/*!
 * \brief Sentinel `ndim` value (-1) meaning "number of dimensions unknown".
 *  Used as the default rank of ShapeType and tested by
 *  DynTensorTypeNode::IsUnknownNdim().
 */
static constexpr int kUnknownNDim = -1;
42 
43 class ShapeTypeNode : public TypeNode {
44  public:
46  int ndim;
47 
49  v->Visit("ndim", &ndim);
50  v->Visit("span", &span);
51  }
52 
53  bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
54  return equal(ndim, other->ndim);
55  }
56 
57  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(ndim); }
58 
59  static constexpr const char* _type_key = "relax.ShapeType";
61 };
62 
63 class ShapeType : public Type {
64  public:
65  // TODO(relax-team): remove the default value later.
66  TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span());
67 
69 };
70 
71 class ObjectTypeNode : public TypeNode {
72  public:
73  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
74 
75  bool SEqualReduce(const ObjectTypeNode* other, SEqualReducer equal) const { return true; }
76 
77  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }
78 
79  static constexpr const char* _type_key = "relax.ObjectType";
81 };
82 
83 class ObjectType : public Type {
84  public:
85  TVM_DLL ObjectType(Span span = Span());
86 
88 };
89 
91  public:
96  int ndim;
99 
101  v->Visit("ndim", &ndim);
102  v->Visit("dtype", &dtype);
103  v->Visit("span", &span);
104  }
105 
106  bool SEqualReduce(const DynTensorTypeNode* other, SEqualReducer equal) const {
107  return equal(ndim, other->ndim) && equal(dtype, other->dtype);
108  }
109 
110  void SHashReduce(SHashReducer hash_reduce) const {
111  hash_reduce(ndim);
112  hash_reduce(dtype);
113  }
114 
115  inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
116 
117  inline bool IsUnknownDtype() const { return dtype.is_void(); }
118 
119  static constexpr const char* _type_key = "relax.DynTensorType";
121 };
122 
127 class DynTensorType : public Type {
128  public:
135  TVM_DLL DynTensorType(int ndim, DataType dtype, Span span = Span());
136 
140  TVM_DLL static DynTensorType CreateUnknownNDim(DataType dtype, Span span = Span());
141 
143 };
144 
145 class PackedFuncTypeNode : public TypeNode {
146  public:
147  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
148 
149  bool SEqualReduce(const PackedFuncTypeNode* other, SEqualReducer equal) const { return true; }
150 
151  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }
152 
153  static constexpr const char* _type_key = "relax.PackedFuncType";
155 };
156 
157 class PackedFuncType : public Type {
158  public:
159  TVM_DLL PackedFuncType(Span span = Span());
160 
162 };
163 
164 } // namespace relax
165 } // namespace tvm
166 #endif // TVM_RELAX_TYPE_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each field.
Definition: reflection.h:52
Base of all Tensor types. This container can hold TensorType or GenericTensorType.
Definition: tensor_type.h:36
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Definition: source_map.h:120
Type is the base type of all types.
Definition: type.h:74
Span span
Span that points to the original source code. Reserved debug information.
Definition: type.h:80
Managed reference to TypeNode.
Definition: type.h:93
Definition: type.h:90
int ndim
The number of dimensions of the tensor; use -1 to denote a tensor with an unknown number of dimensions.
Definition: type.h:96
bool SEqualReduce(const DynTensorTypeNode *other, SEqualReducer equal) const
Definition: type.h:106
DataType dtype
The content data type; use void to denote that the dtype is unknown.
Definition: type.h:98
bool IsUnknownDtype() const
Definition: type.h:117
void SHashReduce(SHashReducer hash_reduce) const
Definition: type.h:110
void VisitAttrs(tvm::AttrVisitor *v)
Definition: type.h:100
bool IsUnknownNdim() const
Definition: type.h:115
TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode)
static constexpr const char * _type_key
Definition: type.h:119
Managed reference to DynTensorTypeNode.
Definition: type.h:127
static DynTensorType CreateUnknownNDim(DataType dtype, Span span=Span())
Create a DynTensorType with unknown ndim.
DynTensorType(int ndim, DataType dtype, Span span=Span())
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode)
Definition: type.h:71
bool SEqualReduce(const ObjectTypeNode *other, SEqualReducer equal) const
Definition: type.h:75
void SHashReduce(SHashReducer hash_reduce) const
Definition: type.h:77
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectTypeNode, TypeNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: type.h:73
static constexpr const char * _type_key
Definition: type.h:79
Definition: type.h:83
ObjectType(Span span=Span())
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectType, Type, ObjectTypeNode)
Definition: type.h:145
void SHashReduce(SHashReducer hash_reduce) const
Definition: type.h:151
bool SEqualReduce(const PackedFuncTypeNode *other, SEqualReducer equal) const
Definition: type.h:149
static constexpr const char * _type_key
Definition: type.h:153
TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncTypeNode, TypeNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: type.h:147
Definition: type.h:157
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PackedFuncType, Type, PackedFuncTypeNode)
PackedFuncType(Span span=Span())
Definition: type.h:43
int ndim
Size of the shape, i.e. its number of dimensions (kUnknownNDim when unknown).
Definition: type.h:46
void SHashReduce(SHashReducer hash_reduce) const
Definition: type.h:57
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode)
static constexpr const char * _type_key
Definition: type.h:59
void VisitAttrs(tvm::AttrVisitor *v)
Definition: type.h:48
bool SEqualReduce(const ShapeTypeNode *other, SEqualReducer equal) const
Definition: type.h:53
Definition: type.h:63
ShapeType(int ndim=kUnknownNDim, Span span=Span())
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode)
Runtime primitive data type.
Definition: data_type.h:43
bool is_void() const
Definition: data_type.h:156
Serializable global function used in IR.
Helpers for attribute objects.
IR/AST nodes for the unified type system in TVM.
tvm::Span Span
Definition: base.h:65
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
This file defines the TVM global function registry.
Polymorphic tensor types.
TIR expressions.
Type relation and function for type inference (checking).