24 #ifndef TVM_RUNTIME_DATA_TYPE_H_
25 #define TVM_RUNTIME_DATA_TYPE_H_
27 #include <tvm/ffi/container/shape.h>
28 #include <tvm/ffi/dtype.h>
30 #include <tvm/runtime/logging.h>
34 #include <type_traits>
83 explicit DataType(DLDataType dtype) : data_(dtype) {}
92 data_.code =
static_cast<uint8_t
>(
code);
93 data_.bits =
static_cast<uint8_t
>(
bits);
95 TVM_FFI_ICHECK(
lanes > 1) <<
"Invalid value for vscale factor" <<
lanes;
97 data_.lanes = is_scalable ?
static_cast<uint16_t
>(-
lanes) :
static_cast<uint16_t
>(
lanes);
99 TVM_FFI_ICHECK_EQ(
bits, 16);
104 TVM_FFI_ICHECK_EQ(
bits, 8);
107 TVM_FFI_ICHECK_EQ(
bits, 6);
110 TVM_FFI_ICHECK_EQ(
bits, 4);
114 int code()
const {
return static_cast<int>(data_.code); }
116 int bits()
const {
return static_cast<int>(data_.bits); }
121 int lanes_as_int =
static_cast<int16_t
>(data_.lanes);
122 if (lanes_as_int < 0) {
123 TVM_FFI_THROW(InternalError)
124 <<
"Can't fetch the lanes of a scalable vector at a compile time.";
130 int lanes_as_int =
static_cast<int16_t
>(data_.lanes);
131 if (lanes_as_int >= -1) {
132 TVM_FFI_THROW(InternalError) <<
"A fixed length vector doesn't have a vscale factor.";
134 return -lanes_as_int;
152 return bits() == 8 &&
160 return bits() == 6 &&
201 int encoded_lanes =
static_cast<int16_t
>(data_.lanes);
202 return (encoded_lanes < -1) || (1 < encoded_lanes);
257 return data_.code == other.data_.code && data_.bits == other.data_.bits &&
258 data_.lanes == other.data_.lanes;
270 operator DLDataType()
const {
return data_; }
407 if (std::is_signed<tvm_index_t>::value) {
424 int data_bits = dtype.
bits() * dtype.
lanes();
431 TVM_FFI_ICHECK_EQ(data_bits % 8, 0U) <<
"Need to load/store by multiple of bytes";
432 return data_bits / 8;
442 inline bool TypeMatch(DLDataType t,
int code,
int bits,
int lanes = 1) {
443 return t.code == code && t.bits == bits && t.lanes == lanes;
451 return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
454 using ffi::DLDataTypeToString;
455 using ffi::StringToDLDataType;
458 return os << dtype.operator DLDataType();
468 struct TypeTraits<runtime::
DataType> :
public TypeTraitsBase {
469 static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType;
473 result->v_uint64 = 0;
474 result->zero_padding = 0;
475 result->type_index = TypeIndex::kTVMFFIDataType;
476 result->v_dtype = src;
481 result->v_uint64 = 0;
482 result->zero_padding = 0;
483 result->type_index = TypeIndex::kTVMFFIDataType;
484 result->v_dtype = src;
488 auto opt_dtype = TypeTraits<DLDataType>::TryCastFromAnyView(src);
496 return TypeTraits<DLDataType>::CheckAnyStrict(src);
500 return runtime::DataType(TypeTraits<DLDataType>::CopyFromAnyViewAfterCheck(src));
503 TVM_FFI_INLINE
static std::string
TypeStr() {
return ffi::StaticTypeKey::kTVMFFIDataType; }
506 return R
"({"type":")" + std::string(ffi::StaticTypeKey::kTVMFFIDataType) + R"("})";
516 inline int cantor_pairing_function(
int a,
int b)
const {
return (a + b) * (a + b + 1) / 2 + b; }
518 int a = dtype.
code();
519 int b = dtype.
bits();
520 int c = dtype.
lanes();
521 int d = cantor_pairing_function(a, b);
522 return cantor_pairing_function(c, d);
Runtime primitive data type.
Definition: data_type.h:47
static DataType ShapeIndex()
Get the corresponding type of TVMShapeIndex.
Definition: data_type.h:406
bool is_handle() const
Definition: data_type.h:198
static DataType Float8E4M3FNUZ(int lanes=1)
Construct float8 e4m3fnuz datatype.
Definition: data_type.h:338
bool is_uint() const
Definition: data_type.h:196
int get_lanes_or_vscale_factor() const
Definition: data_type.h:137
bool is_float8_e5m2() const
Definition: data_type.h:178
static DataType Float4E2M1FN(int lanes=1)
Construct float4 e2m1fn datatype.
Definition: data_type.h:380
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:295
bool is_float4_e2m1fn() const
Definition: data_type.h:188
bool is_float6() const
Definition: data_type.h:159
static DataType Float8E8M0FNU(int lanes=1)
Construct float8 e8m0fnu datatype.
Definition: data_type.h:359
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:240
bool is_scalable_vector() const
Definition: data_type.h:207
bool is_float8_e4m3() const
Definition: data_type.h:168
DataType & operator=(const DataType &rhs)
Assignment operator.
Definition: data_type.h:244
static DataType Float8E4M3FN(int lanes=1)
Construct float8 e4m3fn datatype.
Definition: data_type.h:331
int bytes() const
Definition: data_type.h:118
TypeCode
Type code for the DataType.
Definition: data_type.h:57
@ kFloat8_e4m3b11fnuz
Definition: data_type.h:66
@ kBool
Definition: data_type.h:63
@ kHandle
Definition: data_type.h:61
@ kUInt
Definition: data_type.h:59
@ kFloat8_e4m3
Definition: data_type.h:65
@ kFloat6_e3m2fn
Definition: data_type.h:73
@ kFloat6_e2m3fn
Definition: data_type.h:72
@ kBFloat
Definition: data_type.h:62
@ kFloat
Definition: data_type.h:60
@ kFloat8_e4m3fn
Definition: data_type.h:67
@ kFloat8_e4m3fnuz
Definition: data_type.h:68
@ kCustomBegin
Definition: data_type.h:75
@ kFloat8_e8m0fnu
Definition: data_type.h:71
@ kFloat4_e2m1fn
Definition: data_type.h:74
@ kInt
Definition: data_type.h:58
@ kFloat8_e3m4
Definition: data_type.h:64
@ kFloat8_e5m2fnuz
Definition: data_type.h:70
@ kFloat8_e5m2
Definition: data_type.h:69
bool is_float6_e3m2fn() const
Definition: data_type.h:186
static DataType Float8E5M2FNUZ(int lanes=1)
Construct float8 e5m2fnuz datatype.
Definition: data_type.h:352
bool is_float8_e3m4() const
Definition: data_type.h:166
bool is_bool() const
Definition: data_type.h:143
static DataType Float8E4M3B11FNUZ(int lanes=1)
Construct float8 e4m3b11fnuz datatype.
Definition: data_type.h:322
bool is_int() const
Definition: data_type.h:194
DataType with_bits(int bits) const
Create a new data type by changing bits to a specified value.
Definition: data_type.h:235
bool operator!=(const DataType &other) const
NotEqual comparator.
Definition: data_type.h:265
DataType()
Default constructor.
Definition: data_type.h:78
bool is_scalable_or_fixed_length_vector() const
Definition: data_type.h:200
int code() const
Definition: data_type.h:114
int lanes() const
Definition: data_type.h:120
static DataType Float8E5M2(int lanes=1)
Construct float8 e5m2 datatype.
Definition: data_type.h:345
DataType(int code, int bits, int lanes, bool is_scalable=false)
Constructor.
Definition: data_type.h:91
bool is_float8_e8m0fnu() const
Definition: data_type.h:182
static DataType Float6E2M3FN(int lanes=1)
Construct float6 e2m3fn datatype.
Definition: data_type.h:366
bool operator==(const DataType &other) const
Equal comparator.
Definition: data_type.h:256
bool is_float16() const
Definition: data_type.h:190
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:387
DataType with_lanes(int lanes) const
Create a new data type by changing lanes to a specified value.
Definition: data_type.h:221
static DataType Float6E3M2FN(int lanes=1)
Construct float6 e3m2fn datatype.
Definition: data_type.h:373
int vscale_factor() const
Definition: data_type.h:129
DataType(DLDataType dtype)
Constructor.
Definition: data_type.h:83
bool is_float8_e4m3b11fnuz() const
Definition: data_type.h:170
DataType with_scalable_vscale_factor(int vscale_factor) const
Create a new scalable vector data type by changing the vscale multiplier to a specified value....
Definition: data_type.h:227
bool is_float8_e4m3fnuz() const
Definition: data_type.h:176
bool is_fixed_length_vector() const
Definition: data_type.h:205
bool is_predicate_dtype() const
Definition: data_type.h:145
bool is_bfloat() const
Definition: data_type.h:149
static DataType BFloat(int bits, int lanes=1)
Construct a bfloat type.
Definition: data_type.h:302
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
static DataType Void()
Construct a Void type.
Definition: data_type.h:401
bool is_vector() const
Definition: data_type.h:209
bool is_float6_e2m3fn() const
Definition: data_type.h:184
bool is_scalar() const
Definition: data_type.h:141
int bits() const
Definition: data_type.h:116
bool is_float8() const
Definition: data_type.h:151
bool is_bfloat16() const
Definition: data_type.h:192
bool is_float8_e4m3fn() const
Definition: data_type.h:174
bool is_vector_bool() const
Definition: data_type.h:211
bool is_float4() const
Definition: data_type.h:164
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:396
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:286
static DataType Float8E4M3(int lanes=1)
Construct float8 e4m3 datatype.
Definition: data_type.h:315
static DataType Float8E3M4(int lanes=1)
Construct float8 e3m4 datatype.
Definition: data_type.h:308
bool is_void() const
Definition: data_type.h:213
bool is_float8_e5m2fnuz() const
Definition: data_type.h:180
bool is_float() const
Definition: data_type.h:147
std::ostream & operator<<(std::ostream &os, const DataType &dtype)
Definition: data_type.h:457
ffi::Shape::index_type tvm_index_t
Definition: data_type.h:39
int GetVectorBytes(DataType dtype)
Get the number of bytes needed in a vector.
Definition: data_type.h:423
bool TypeMatch(DLDataType t, int code, int bits, int lanes=1)
Check whether type matches the given spec.
Definition: data_type.h:442
bool TypeEqual(DLDataType lhs, DLDataType rhs)
Check whether two types are equal.
Definition: data_type.h:450
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
runtime::DataType DataType
Definition: data_type.h:462
static TVM_FFI_INLINE void MoveToAny(runtime::DataType src, TVMFFIAny *result)
Definition: data_type.h:479
static TVM_FFI_INLINE void CopyToAnyView(const runtime::DataType &src, TVMFFIAny *result)
Definition: data_type.h:471
static TVM_FFI_INLINE std::string TypeSchema()
Definition: data_type.h:505
static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny *src)
Definition: data_type.h:495
static TVM_FFI_INLINE runtime::DataType CopyFromAnyViewAfterCheck(const TVMFFIAny *src)
Definition: data_type.h:499
static TVM_FFI_INLINE std::string TypeStr()
Definition: data_type.h:503
static TVM_FFI_INLINE std::optional< runtime::DataType > TryCastFromAnyView(const TVMFFIAny *src)
Definition: data_type.h:487