tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
data_type.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*
20  * \file tvm/runtime/data_type.h
21  * \brief Primitive runtime data type.
22  */
23 // Acknowledgement: DataType structure design originates from Halide.
24 #ifndef TVM_RUNTIME_DATA_TYPE_H_
25 #define TVM_RUNTIME_DATA_TYPE_H_
26 
28 #include <tvm/runtime/logging.h>
29 
30 #include <cstring>
31 #include <string>
32 #include <type_traits>
33 
34 namespace tvm {
35 namespace runtime {
36 
43 class DataType {
44  public:
53  enum TypeCode {
54  kInt = kDLInt,
55  kUInt = kDLUInt,
56  kFloat = kDLFloat,
58  kBFloat = kDLBfloat,
62  kCustomBegin = 129
63  };
65  DataType() { data_ = DataType::Void(); }
70  explicit DataType(DLDataType dtype) : data_(dtype) {}
78  DataType(int code, int bits, int lanes, bool is_scalable = false) {
79  data_.code = static_cast<uint8_t>(code);
80  data_.bits = static_cast<uint8_t>(bits);
81  if (is_scalable) {
82  ICHECK(lanes > 1) << "Invalid value for vscale factor" << lanes;
83  }
84  data_.lanes = is_scalable ? static_cast<uint16_t>(-lanes) : static_cast<uint16_t>(lanes);
85  if (code == kBFloat) {
86  ICHECK_EQ(bits, 16);
87  }
88  if (code == kFloat8_e4m3fn || code == kFloat8_e5m2) {
89  ICHECK_EQ(bits, 8);
90  }
91  if (code == kFloat4_e2m1fn) {
92  ICHECK_EQ(bits, 4);
93  }
94  }
96  int code() const { return static_cast<int>(data_.code); }
98  int bits() const { return static_cast<int>(data_.bits); }
100  int bytes() const { return (bits() + 7) / 8; }
102  int lanes() const {
103  int lanes_as_int = static_cast<int16_t>(data_.lanes);
104  if (lanes_as_int < 0) {
105  LOG(FATAL) << "Can't fetch the lanes of a scalable vector at a compile time.";
106  }
107  return lanes_as_int;
108  }
110  int vscale_factor() const {
111  int lanes_as_int = static_cast<int16_t>(data_.lanes);
112  if (lanes_as_int >= -1) {
113  LOG(FATAL) << "A fixed length vector doesn't have a vscale factor.";
114  }
115  return -lanes_as_int;
116  }
119  return is_scalable_vector() ? vscale_factor() : lanes();
120  }
122  bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
124  bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
126  bool is_float() const { return code() == DataType::kFloat; }
128  bool is_bfloat() const { return code() == DataType::kBFloat; }
130  bool is_float8() const {
131  return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn ||
133  bits() == 8;
134  }
136  bool is_float4() const { return code() == DataType::kFloat4_e2m1fn && bits() == 4; }
137  bool is_float8_e4m3fn() const { return (code() == DataType::kFloat8_e4m3fn && bits() == 8); }
138  bool is_float8_e5m2() const { return (code() == DataType::kFloat8_e5m2 && bits() == 8); }
139  bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4_e2m1fn && bits() == 4); }
141  bool is_float16() const { return is_float() && bits() == 16; }
143  bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
145  bool is_int() const { return code() == DataType::kInt; }
147  bool is_uint() const { return code() == DataType::kUInt; }
149  bool is_handle() const { return code() == DataType::kHandle && !is_void(); }
152  int encoded_lanes = static_cast<int16_t>(data_.lanes);
153  return (encoded_lanes < -1) || (1 < encoded_lanes);
154  }
156  bool is_fixed_length_vector() const { return static_cast<int16_t>(data_.lanes) > 1; }
158  bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) < -1; }
160  bool is_vector() const { return lanes() > 1; }
162  bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
164  bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; }
170  DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); }
177  return DataType(data_.code, data_.bits, -vscale_factor);
178  }
184  DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); }
189  DataType element_of() const { return with_lanes(1); }
193  DataType& operator=(const DataType& rhs) {
194  if (this == &rhs) {
195  return *this;
196  }
197  data_ = rhs.data_;
198  return *this;
199  }
205  bool operator==(const DataType& other) const {
206  return data_.code == other.data_.code && data_.bits == other.data_.bits &&
207  data_.lanes == other.data_.lanes;
208  }
214  bool operator!=(const DataType& other) const { return !operator==(other); }
219  operator DLDataType() const { return data_; }
220 
227  static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); }
235  static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
236  return DataType(kDLUInt, bits, lanes, is_scalable);
237  }
244  static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
251  static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
257  static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); }
263  static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); }
269  static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); }
276  static DataType Bool(int lanes = 1, bool is_scalable = false) {
277  return DataType::UInt(1, lanes, is_scalable);
278  }
285  static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); }
290  static DataType Void() { return DataType(kHandle, 0, 0); }
295  static DataType ShapeIndex() {
296  if (std::is_signed<tvm_index_t>::value) {
297  return DataType::Int(sizeof(tvm_index_t) * 8);
298  } else {
299  return DataType::UInt(sizeof(tvm_index_t) * 8);
300  }
301  }
302 
303  private:
304  DLDataType data_;
305 };
306 
312 inline int GetVectorBytes(DataType dtype) {
313  int data_bits = dtype.bits() * dtype.lanes();
314  // allow bool to exist
315  if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
316  dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN()) {
317  return 1;
318  }
319  ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
320  return data_bits / 8;
321 }
322 
330 inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
331  return t.code == code && t.bits == bits && t.lanes == lanes;
332 }
338 inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
339  return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
340 }
341 
347 TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
348 
354 TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
355 
362 TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
363 
369 inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code);
370 
376 inline DLDataType String2DLDataType(std::string s);
377 
383 inline std::string DLDataType2String(DLDataType t);
384 
385 // implementation details
386 inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
387  switch (static_cast<int>(type_code)) {
388  case kDLInt:
389  return "int";
390  case kDLUInt:
391  return "uint";
392  case kDLFloat:
393  return "float";
394  case DataType::kHandle:
395  return "handle";
396  case kDLBfloat:
397  return "bfloat";
399  return "float8_e4m3fn";
401  return "float8_e5m2";
403  return "float4_e2m1fn";
404  default:
405  LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
406  }
407  throw;
408 }
409 
410 inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
411  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
412  os << "bool";
413  return os;
414  }
415  if (DataType(t).is_void()) {
416  return os << "void";
417  }
418  if (t.code < DataType::kCustomBegin) {
419  os << DLDataTypeCode2Str(static_cast<DLDataTypeCode>(t.code));
420  } else {
421  os << "custom[" << GetCustomTypeName(t.code) << "]";
422  }
423  if (t.code == kTVMOpaqueHandle) return os;
424  int16_t lanes = static_cast<int16_t>(t.lanes);
425  if (t.code != DataType::kFloat8_e4m3fn && t.code != DataType::kFloat8_e5m2 &&
426  t.code != DataType::kFloat4_e2m1fn) {
427  os << static_cast<int>(t.bits);
428  }
429  if (lanes > 1) {
430  os << 'x' << lanes;
431  } else if (lanes < -1) {
432  os << "xvscalex" << -lanes;
433  }
434  return os;
435 }
436 
437 inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
438  return os << dtype.operator DLDataType();
439 }
440 
441 inline std::string DLDataType2String(DLDataType t) {
442  if (t.bits == 0) return "";
443  std::ostringstream os;
444  os << t;
445  return os.str();
446 }
447 
448 inline DLDataType String2DLDataType(std::string s) {
449  DLDataType t;
450  // handle void type
451  if (s.length() == 0 || s == "void") {
452  t = DataType::Void();
453  return t;
454  }
455  t.bits = 32;
456  t.lanes = 1;
457  const char* scan;
458  if (s.substr(0, 3) == "int") {
459  t.code = kDLInt;
460  scan = s.c_str() + 3;
461  } else if (s.substr(0, 4) == "uint") {
462  t.code = kDLUInt;
463  scan = s.c_str() + 4;
464  } else if (s.substr(0, 13) == "float4_e2m1fn") {
465  // Avoid being treated as "float"
466  t.code = DataType::kFloat4_e2m1fn;
467  t.bits = 4;
468  scan = s.c_str() + 13;
469  char* endpt = nullptr;
470  if (*scan == 'x') {
471  t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
472  scan = endpt;
473  }
474  ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
475  return t;
476  } else if (s.substr(0, 13) == "float8_e4m3fn") {
477  // Avoid being treated as "float"
478  t.code = DataType::kFloat8_e4m3fn;
479  t.bits = 8;
480  scan = s.c_str() + 13;
481  char* endpt = nullptr;
482  if (*scan == 'x') {
483  t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
484  scan = endpt;
485  }
486  ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
487  return t;
488  } else if (s.substr(0, 11) == "float8_e5m2") {
489  // Avoid being treated as "float"
490  t.code = DataType::kFloat8_e5m2;
491  t.bits = 8;
492  scan = s.c_str() + 11;
493  char* endpt = nullptr;
494  if (*scan == 'x') {
495  t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
496  scan = endpt;
497  }
498  ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
499  return t;
500  } else if (s.substr(0, 5) == "float") {
501  t.code = kDLFloat;
502  scan = s.c_str() + 5;
503  } else if (s.substr(0, 6) == "handle") {
504  t.code = kTVMOpaqueHandle;
505  t.bits = 64; // handle uses 64 bit by default.
506  scan = s.c_str() + 6;
507  } else if (s == "bool") {
508  t.code = kDLUInt;
509  t.bits = 1;
510  t.lanes = 1;
511  return t;
512  } else if (s.substr(0, 6) == "bfloat") {
513  t.code = DataType::kBFloat;
514  t.bits = 16;
515  scan = s.c_str() + 6;
516  } else if (s.substr(0, 6) == "custom") {
517  t.code = ParseCustomDatatype(s, &scan);
518  } else {
519  scan = s.c_str();
520  LOG(FATAL) << "unknown type " << s;
521  }
522  char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
523  uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
524  if (bits != 0) t.bits = bits;
525  int scalable_multiplier = 1;
526  if (strncmp(xdelim, "xvscale", 7) == 0) {
527  scalable_multiplier = -1;
528  xdelim += 7;
529  }
530  char* endpt = xdelim;
531  if (*xdelim == 'x') {
532  t.lanes = static_cast<uint16_t>(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10));
533  }
534  ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
535  return t;
536 }
537 
538 } // namespace runtime
539 
541 
542 } // namespace tvm
543 
544 namespace std {
545 template <>
546 struct hash<tvm::DataType> {
547  inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; }
548  std::size_t operator()(tvm::DataType const& dtype) const {
549  int a = dtype.code();
550  int b = dtype.bits();
551  int c = dtype.lanes();
552  int d = cantor_pairing_function(a, b);
553  return cantor_pairing_function(c, d);
554  }
555 };
556 } // namespace std
557 
558 #endif // TVM_RUNTIME_DATA_TYPE_H_
@ kTVMOpaqueHandle
Definition: c_runtime_api.h:170
int64_t tvm_index_t
type of array index.
Definition: c_runtime_api.h:89
Runtime primitive data type.
Definition: data_type.h:43
static DataType ShapeIndex()
Get the corresponding type of TVMShapeIndex.
Definition: data_type.h:295
bool is_handle() const
Definition: data_type.h:149
bool is_uint() const
Definition: data_type.h:147
int get_lanes_or_vscale_factor() const
Definition: data_type.h:118
bool is_float8_e5m2() const
Definition: data_type.h:138
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:244
bool is_float4_e2m1fn() const
Definition: data_type.h:139
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:189
bool is_scalable_vector() const
Definition: data_type.h:158
DataType & operator=(const DataType &rhs)
Assignment operator.
Definition: data_type.h:193
int bytes() const
Definition: data_type.h:100
TypeCode
Type code for the DataType.
Definition: data_type.h:53
@ kHandle
Definition: data_type.h:57
@ kUInt
Definition: data_type.h:55
@ kBFloat
Definition: data_type.h:58
@ kFloat
Definition: data_type.h:56
@ kFloat8_e4m3fn
Definition: data_type.h:59
@ kCustomBegin
Definition: data_type.h:62
@ kFloat4_e2m1fn
Definition: data_type.h:61
@ kInt
Definition: data_type.h:54
@ kFloat8_e5m2
Definition: data_type.h:60
static DataType NVFloat8E4M3(int lanes=1)
Construct NV float8 e4m3 datatype.
Definition: data_type.h:257
bool is_bool() const
Definition: data_type.h:124
bool is_int() const
Definition: data_type.h:145
DataType with_bits(int bits) const
Create a new data type by changing the bit width to a specified value.
Definition: data_type.h:184
static DataType NVFloat8E5M2(int lanes=1)
Construct NV float8 e5m2 datatype.
Definition: data_type.h:263
bool operator!=(const DataType &other) const
NotEqual comparator.
Definition: data_type.h:214
DataType()
default constructor
Definition: data_type.h:65
bool is_scalable_or_fixed_length_vector() const
Definition: data_type.h:151
int code() const
Definition: data_type.h:96
static DataType NVFloat4E2M1FN(int lanes=1)
Construct NV float4_e2m1fn datatype.
Definition: data_type.h:269
int lanes() const
Definition: data_type.h:102
DataType(int code, int bits, int lanes, bool is_scalable=false)
Constructor.
Definition: data_type.h:78
bool operator==(const DataType &other) const
Equal comparator.
Definition: data_type.h:205
bool is_float16() const
Definition: data_type.h:141
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:276
DataType with_lanes(int lanes) const
Create a new data type by changing the number of lanes to a specified value.
Definition: data_type.h:170
int vscale_factor() const
Definition: data_type.h:110
DataType(DLDataType dtype)
Constructor.
Definition: data_type.h:70
DataType with_scalable_vscale_factor(int vscale_factor) const
Create a new scalable vector data type by changing the vscale multiplier to a specified value....
Definition: data_type.h:176
bool is_fixed_length_vector() const
Definition: data_type.h:156
bool is_bfloat() const
Definition: data_type.h:128
static DataType BFloat(int bits, int lanes=1)
Construct a bfloat type.
Definition: data_type.h:251
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:227
static DataType Void()
Construct a Void type.
Definition: data_type.h:290
bool is_vector() const
Definition: data_type.h:160
bool is_scalar() const
Definition: data_type.h:122
int bits() const
Definition: data_type.h:98
bool is_float8() const
Definition: data_type.h:130
bool is_bfloat16() const
Definition: data_type.h:143
bool is_float8_e4m3fn() const
Definition: data_type.h:137
bool is_vector_bool() const
Definition: data_type.h:162
bool is_float4() const
Definition: data_type.h:136
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:285
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:235
bool is_void() const
Definition: data_type.h:164
bool is_float() const
Definition: data_type.h:126
std::string GetCustomTypeName(uint8_t type_code)
Runtime utility for getting custom type name from code.
bool GetCustomTypeRegistered(uint8_t type_code)
Runtime utility for checking whether custom type is registered.
DLDataType String2DLDataType(std::string s)
Convert a string to a TVM type.
Definition: data_type.h:448
std::string DLDataType2String(DLDataType t)
Convert a TVM type to a string.
Definition: data_type.h:441
uint8_t ParseCustomDatatype(const std::string &s, const char **scan)
Runtime utility for parsing string of the form "custom[<typename>]".
int GetVectorBytes(DataType dtype)
Get the number of bytes needed in a vector.
Definition: data_type.h:312
bool TypeMatch(DLDataType t, int code, int bits, int lanes=1)
Check whether type matches the given spec.
Definition: data_type.h:330
bool TypeEqual(DLDataType lhs, DLDataType rhs)
Check whether two types are equal.
Definition: data_type.h:338
std::ostream & operator<<(std::ostream &os, const ObjectRef &n)
Definition: repr_printer.h:97
const char * DLDataTypeCode2Str(DLDataTypeCode type_code)
Convert type code to its name.
Definition: data_type.h:386
Array< Tensor > scan(Array< Tensor > init, Array< Tensor > update, Array< Tensor > state_placeholder, Array< Tensor > inputs=Array< Tensor >(), std::string name="scan", std::string tag="", Map< String, ObjectRef > attrs={})
Construct new tensors by scan.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
runtime::DataType DataType
Definition: data_type.h:540