24 #ifndef TVM_RUNTIME_DATA_TYPE_H_
25 #define TVM_RUNTIME_DATA_TYPE_H_
28 #include <tvm/runtime/logging.h>
32 #include <type_traits>
70 explicit DataType(DLDataType dtype) : data_(dtype) {}
79 data_.code =
static_cast<uint8_t
>(
code);
80 data_.bits =
static_cast<uint8_t
>(
bits);
82 ICHECK(
lanes > 1) <<
"Invalid value for vscale factor" <<
lanes;
84 data_.lanes = is_scalable ?
static_cast<uint16_t
>(-
lanes) :
static_cast<uint16_t
>(
lanes);
96 int code()
const {
return static_cast<int>(data_.code); }
98 int bits()
const {
return static_cast<int>(data_.bits); }
103 int lanes_as_int =
static_cast<int16_t
>(data_.lanes);
104 if (lanes_as_int < 0) {
105 LOG(FATAL) <<
"Can't fetch the lanes of a scalable vector at a compile time.";
111 int lanes_as_int =
static_cast<int16_t
>(data_.lanes);
112 if (lanes_as_int >= -1) {
113 LOG(FATAL) <<
"A fixed length vector doesn't have a vscale factor.";
115 return -lanes_as_int;
152 int encoded_lanes =
static_cast<int16_t
>(data_.lanes);
153 return (encoded_lanes < -1) || (1 < encoded_lanes);
206 return data_.code == other.data_.code && data_.bits == other.data_.bits &&
207 data_.lanes == other.data_.lanes;
219 operator DLDataType()
const {
return data_; }
296 if (std::is_signed<tvm_index_t>::value) {
313 int data_bits = dtype.
bits() * dtype.
lanes();
319 ICHECK_EQ(data_bits % 8, 0U) <<
"Need to load/store by multiple of bytes";
320 return data_bits / 8;
330 inline bool TypeMatch(DLDataType t,
int code,
int bits,
int lanes = 1) {
331 return t.code == code && t.bits == bits && t.lanes == lanes;
339 return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
387 switch (
static_cast<int>(type_code)) {
399 return "float8_e4m3fn";
401 return "float8_e5m2";
403 return "float4_e2m1fn";
405 LOG(FATAL) <<
"unknown type_code=" <<
static_cast<int>(type_code);
410 inline std::ostream&
operator<<(std::ostream& os, DLDataType t) {
411 if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
419 os << DLDataTypeCode2Str(static_cast<DLDataTypeCode>(t.code));
424 int16_t lanes =
static_cast<int16_t
>(t.lanes);
427 os << static_cast<int>(t.bits);
431 }
else if (lanes < -1) {
432 os <<
"xvscalex" << -lanes;
438 return os << dtype.operator DLDataType();
442 if (t.bits == 0)
return "";
443 std::ostringstream os;
451 if (s.length() == 0 || s ==
"void") {
458 if (s.substr(0, 3) ==
"int") {
460 scan = s.c_str() + 3;
461 }
else if (s.substr(0, 4) ==
"uint") {
463 scan = s.c_str() + 4;
464 }
else if (s.substr(0, 13) ==
"float4_e2m1fn") {
468 scan = s.c_str() + 13;
469 char* endpt =
nullptr;
471 t.lanes =
static_cast<uint16_t
>(strtoul(
scan + 1, &endpt, 10));
474 ICHECK(
scan == s.c_str() + s.length()) <<
"unknown type " << s;
476 }
else if (s.substr(0, 13) ==
"float8_e4m3fn") {
480 scan = s.c_str() + 13;
481 char* endpt =
nullptr;
483 t.lanes =
static_cast<uint16_t
>(strtoul(
scan + 1, &endpt, 10));
486 ICHECK(
scan == s.c_str() + s.length()) <<
"unknown type " << s;
488 }
else if (s.substr(0, 11) ==
"float8_e5m2") {
492 scan = s.c_str() + 11;
493 char* endpt =
nullptr;
495 t.lanes =
static_cast<uint16_t
>(strtoul(
scan + 1, &endpt, 10));
498 ICHECK(
scan == s.c_str() + s.length()) <<
"unknown type " << s;
500 }
else if (s.substr(0, 5) ==
"float") {
502 scan = s.c_str() + 5;
503 }
else if (s.substr(0, 6) ==
"handle") {
506 scan = s.c_str() + 6;
507 }
else if (s ==
"bool") {
512 }
else if (s.substr(0, 6) ==
"bfloat") {
515 scan = s.c_str() + 6;
516 }
else if (s.substr(0, 6) ==
"custom") {
520 LOG(FATAL) <<
"unknown type " << s;
523 uint8_t bits =
static_cast<uint8_t
>(strtoul(
scan, &xdelim, 10));
524 if (bits != 0) t.bits = bits;
525 int scalable_multiplier = 1;
526 if (strncmp(xdelim,
"xvscale", 7) == 0) {
527 scalable_multiplier = -1;
530 char* endpt = xdelim;
531 if (*xdelim ==
'x') {
532 t.lanes =
static_cast<uint16_t
>(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10));
534 ICHECK(endpt == s.c_str() + s.length()) <<
"unknown type " << s;
547 inline int cantor_pairing_function(
int a,
int b)
const {
return (a + b) * (a + b + 1) / 2 + b; }
549 int a = dtype.
code();
550 int b = dtype.
bits();
551 int c = dtype.
lanes();
552 int d = cantor_pairing_function(a, b);
553 return cantor_pairing_function(c, d);
@ kTVMOpaqueHandle
Definition: c_runtime_api.h:170
int64_t tvm_index_t
type of array index.
Definition: c_runtime_api.h:89
Runtime primitive data type.
Definition: data_type.h:43
static DataType ShapeIndex()
Get the corresponding type of TVMShapeIndex.
Definition: data_type.h:295
bool is_handle() const
Definition: data_type.h:149
bool is_uint() const
Definition: data_type.h:147
int get_lanes_or_vscale_factor() const
Definition: data_type.h:118
bool is_float8_e5m2() const
Definition: data_type.h:138
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:244
bool is_float4_e2m1fn() const
Definition: data_type.h:139
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:189
bool is_scalable_vector() const
Definition: data_type.h:158
DataType & operator=(const DataType &rhs)
Assignment operator.
Definition: data_type.h:193
int bytes() const
Definition: data_type.h:100
TypeCode
Type code for the DataType.
Definition: data_type.h:53
@ kHandle
Definition: data_type.h:57
@ kUInt
Definition: data_type.h:55
@ kBFloat
Definition: data_type.h:58
@ kFloat
Definition: data_type.h:56
@ kFloat8_e4m3fn
Definition: data_type.h:59
@ kCustomBegin
Definition: data_type.h:62
@ kFloat4_e2m1fn
Definition: data_type.h:61
@ kInt
Definition: data_type.h:54
@ kFloat8_e5m2
Definition: data_type.h:60
static DataType NVFloat8E4M3(int lanes=1)
Construct NV float8 e4m3 datatype.
Definition: data_type.h:257
bool is_bool() const
Definition: data_type.h:124
bool is_int() const
Definition: data_type.h:145
DataType with_bits(int bits) const
Create a new data type by changing bits to a specified value.
Definition: data_type.h:184
static DataType NVFloat8E5M2(int lanes=1)
Construct NV float8 e5m2 datatype.
Definition: data_type.h:263
bool operator!=(const DataType &other) const
NotEqual comparator.
Definition: data_type.h:214
DataType()
default constructor
Definition: data_type.h:65
bool is_scalable_or_fixed_length_vector() const
Definition: data_type.h:151
int code() const
Definition: data_type.h:96
static DataType NVFloat4E2M1FN(int lanes=1)
Construct NV float4_e2m1fn datatype.
Definition: data_type.h:269
int lanes() const
Definition: data_type.h:102
DataType(int code, int bits, int lanes, bool is_scalable=false)
Constructor.
Definition: data_type.h:78
bool operator==(const DataType &other) const
Equal comparator.
Definition: data_type.h:205
bool is_float16() const
Definition: data_type.h:141
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:276
DataType with_lanes(int lanes) const
Create a new data type by changing lanes to a specified value.
Definition: data_type.h:170
int vscale_factor() const
Definition: data_type.h:110
DataType(DLDataType dtype)
Constructor.
Definition: data_type.h:70
DataType with_scalable_vscale_factor(int vscale_factor) const
Create a new scalable vector data type by changing the vscale multiplier to a specified value....
Definition: data_type.h:176
bool is_fixed_length_vector() const
Definition: data_type.h:156
bool is_bfloat() const
Definition: data_type.h:128
static DataType BFloat(int bits, int lanes=1)
Construct a bfloat type.
Definition: data_type.h:251
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:227
static DataType Void()
Construct a Void type.
Definition: data_type.h:290
bool is_vector() const
Definition: data_type.h:160
bool is_scalar() const
Definition: data_type.h:122
int bits() const
Definition: data_type.h:98
bool is_float8() const
Definition: data_type.h:130
bool is_bfloat16() const
Definition: data_type.h:143
bool is_float8_e4m3fn() const
Definition: data_type.h:137
bool is_vector_bool() const
Definition: data_type.h:162
bool is_float4() const
Definition: data_type.h:136
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:285
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:235
bool is_void() const
Definition: data_type.h:164
bool is_float() const
Definition: data_type.h:126
std::string GetCustomTypeName(uint8_t type_code)
Runtime utility for getting custom type name from code.
bool GetCustomTypeRegistered(uint8_t type_code)
Runtime utility for checking whether custom type is registered.
DLDataType String2DLDataType(std::string s)
Convert a string to a TVM type.
Definition: data_type.h:448
std::string DLDataType2String(DLDataType t)
Convert a TVM type to a string.
Definition: data_type.h:441
uint8_t ParseCustomDatatype(const std::string &s, const char **scan)
Runtime utility for parsing string of the form "custom[<typename>]".
int GetVectorBytes(DataType dtype)
Get the number of bytes needed in a vector.
Definition: data_type.h:312
bool TypeMatch(DLDataType t, int code, int bits, int lanes=1)
Check whether type matches the given spec.
Definition: data_type.h:330
bool TypeEqual(DLDataType lhs, DLDataType rhs)
Check whether two types are equal.
Definition: data_type.h:338
std::ostream & operator<<(std::ostream &os, const ObjectRef &n)
Definition: repr_printer.h:97
const char * DLDataTypeCode2Str(DLDataTypeCode type_code)
Convert type code to its name.
Definition: data_type.h:386
Array< Tensor > scan(Array< Tensor > init, Array< Tensor > update, Array< Tensor > state_placeholder, Array< Tensor > inputs=Array< Tensor >(), std::string name="scan", std::string tag="", Map< String, ObjectRef > attrs={})
Construct new tensors by scan.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
runtime::DataType DataType
Definition: data_type.h:540