tvm
data_type.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*
20  * \file tvm/runtime/data_type.h
21  * \brief Primitive runtime data type.
22  */
23 // Acknowledgement: DataType structure design originates from Halide.
24 #ifndef TVM_RUNTIME_DATA_TYPE_H_
25 #define TVM_RUNTIME_DATA_TYPE_H_
26 
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>

#include <cstdlib>
#include <cstring>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
33 
34 namespace tvm {
35 namespace runtime {
36 
43 class DataType {
44  public:
53  enum TypeCode {
54  kInt = kDLInt,
55  kUInt = kDLUInt,
56  kFloat = kDLFloat,
58  kBFloat = kDLBfloat,
59  kE4M3Float = 6U,
60  kE5M2Float = 7U,
61  kCustomBegin = 129
62  };
64  DataType() { data_ = DataType::Void(); }
69  explicit DataType(DLDataType dtype) : data_(dtype) {}
77  DataType(int code, int bits, int lanes, bool is_scalable = false) {
78  data_.code = static_cast<uint8_t>(code);
79  data_.bits = static_cast<uint8_t>(bits);
80  if (is_scalable) {
81  ICHECK(lanes > 1) << "Invalid value for vscale factor" << lanes;
82  }
83  data_.lanes = is_scalable ? static_cast<uint16_t>(-lanes) : static_cast<uint16_t>(lanes);
84  if (code == kBFloat) {
85  ICHECK_EQ(bits, 16);
86  }
87  if (code == kE4M3Float || code == kE5M2Float) {
88  ICHECK_EQ(bits, 8);
89  }
90  }
92  int code() const { return static_cast<int>(data_.code); }
94  int bits() const { return static_cast<int>(data_.bits); }
96  int bytes() const { return (bits() + 7) / 8; }
98  int lanes() const {
99  int lanes_as_int = static_cast<int16_t>(data_.lanes);
100  if (lanes_as_int < 0) {
101  LOG(FATAL) << "Can't fetch the lanes of a scalable vector at a compile time.";
102  }
103  return lanes_as_int;
104  }
106  int vscale_factor() const {
107  int lanes_as_int = static_cast<int16_t>(data_.lanes);
108  if (lanes_as_int >= -1) {
109  LOG(FATAL) << "A fixed length vector doesn't have a vscale factor.";
110  }
111  return -lanes_as_int;
112  }
115  return is_scalable_vector() ? vscale_factor() : lanes();
116  }
118  bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
120  bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
122  bool is_float() const { return code() == DataType::kFloat; }
124  bool is_float8() const {
125  return (code() == DataType::kFloat || code() == DataType::kE4M3Float ||
126  code() == DataType::kE5M2Float) &&
127  bits() == 8;
128  }
129  bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }
130 
131  bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
133  bool is_float16() const { return is_float() && bits() == 16; }
135  bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
137  bool is_int() const { return code() == DataType::kInt; }
139  bool is_uint() const { return code() == DataType::kUInt; }
141  bool is_handle() const { return code() == DataType::kHandle && !is_void(); }
144  int encoded_lanes = static_cast<int16_t>(data_.lanes);
145  return (encoded_lanes < -1) || (1 < encoded_lanes);
146  }
148  bool is_fixed_length_vector() const { return static_cast<int16_t>(data_.lanes) > 1; }
150  bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) < -1; }
152  bool is_vector() const { return lanes() > 1; }
154  bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
156  bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; }
162  DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); }
169  return DataType(data_.code, data_.bits, -vscale_factor);
170  }
176  DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); }
181  DataType element_of() const { return with_lanes(1); }
185  DataType& operator=(const DataType& rhs) {
186  if (this == &rhs) {
187  return *this;
188  }
189  data_ = rhs.data_;
190  return *this;
191  }
197  bool operator==(const DataType& other) const {
198  return data_.code == other.data_.code && data_.bits == other.data_.bits &&
199  data_.lanes == other.data_.lanes;
200  }
206  bool operator!=(const DataType& other) const { return !operator==(other); }
211  operator DLDataType() const { return data_; }
212 
219  static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); }
227  static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
228  return DataType(kDLUInt, bits, lanes, is_scalable);
229  }
236  static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
243  static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
249  static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, lanes); }
255  static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
262  static DataType Bool(int lanes = 1, bool is_scalable = false) {
263  return DataType::UInt(1, lanes, is_scalable);
264  }
271  static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); }
276  static DataType Void() { return DataType(kHandle, 0, 0); }
281  static DataType ShapeIndex() {
282  if (std::is_signed<tvm_index_t>::value) {
283  return DataType::Int(sizeof(tvm_index_t) * 8);
284  } else {
285  return DataType::UInt(sizeof(tvm_index_t) * 8);
286  }
287  }
288 
289  private:
290  DLDataType data_;
291 };
292 
298 inline int GetVectorBytes(DataType dtype) {
299  int data_bits = dtype.bits() * dtype.lanes();
300  // allow bool to exist
301  if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
302  dtype == DataType::Int(1)) {
303  return 1;
304  }
305  ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
306  return data_bits / 8;
307 }
308 
316 inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
317  return t.code == code && t.bits == bits && t.lanes == lanes;
318 }
324 inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
325  return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
326 }
327 
333 TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
334 
340 TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
341 
348 TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
349 
355 inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code);
356 
362 inline DLDataType String2DLDataType(std::string s);
363 
369 inline std::string DLDataType2String(DLDataType t);
370 
371 // implementation details
372 inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
373  switch (static_cast<int>(type_code)) {
374  case kDLInt:
375  return "int";
376  case kDLUInt:
377  return "uint";
378  case kDLFloat:
379  return "float";
380  case DataType::kHandle:
381  return "handle";
382  case kDLBfloat:
383  return "bfloat";
385  return "e4m3_float";
387  return "e5m2_float";
388  default:
389  LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
390  }
391  throw;
392 }
393 
394 inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
395  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
396  os << "bool";
397  return os;
398  }
399  if (DataType(t).is_void()) {
400  return os << "void";
401  }
402  if (t.code < DataType::kCustomBegin) {
403  os << DLDataTypeCode2Str(static_cast<DLDataTypeCode>(t.code));
404  } else {
405  os << "custom[" << GetCustomTypeName(t.code) << "]";
406  }
407  if (t.code == kTVMOpaqueHandle) return os;
408  int16_t lanes = static_cast<int16_t>(t.lanes);
409  os << static_cast<int>(t.bits);
410  if (lanes > 1) {
411  os << 'x' << lanes;
412  } else if (lanes < -1) {
413  os << "xvscalex" << -lanes;
414  }
415  return os;
416 }
417 
418 inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
419  return os << dtype.operator DLDataType();
420 }
421 
422 inline std::string DLDataType2String(DLDataType t) {
423  if (t.bits == 0) return "";
424  std::ostringstream os;
425  os << t;
426  return os.str();
427 }
428 
429 inline DLDataType String2DLDataType(std::string s) {
430  DLDataType t;
431  // handle void type
432  if (s.length() == 0 || s == "void") {
433  t = DataType::Void();
434  return t;
435  }
436  t.bits = 32;
437  t.lanes = 1;
438  const char* scan;
439  if (s.substr(0, 3) == "int") {
440  t.code = kDLInt;
441  scan = s.c_str() + 3;
442  } else if (s.substr(0, 4) == "uint") {
443  t.code = kDLUInt;
444  scan = s.c_str() + 4;
445  } else if (s.substr(0, 5) == "float") {
446  t.code = kDLFloat;
447  scan = s.c_str() + 5;
448  } else if (s.substr(0, 6) == "handle") {
449  t.code = kTVMOpaqueHandle;
450  t.bits = 64; // handle uses 64 bit by default.
451  scan = s.c_str() + 6;
452  } else if (s == "bool") {
453  t.code = kDLUInt;
454  t.bits = 1;
455  t.lanes = 1;
456  return t;
457  } else if (s.substr(0, 6) == "bfloat") {
458  t.code = DataType::kBFloat;
459  t.bits = 16;
460  scan = s.c_str() + 6;
461  } else if (s.substr(0, 10) == "e4m3_float") {
462  t.code = DataType::kE4M3Float;
463  t.bits = 8;
464  scan = s.c_str() + 10;
465  } else if (s.substr(0, 10) == "e5m2_float") {
466  t.code = DataType::kE5M2Float;
467  t.bits = 8;
468  scan = s.c_str() + 10;
469  } else if (s.substr(0, 6) == "custom") {
470  t.code = ParseCustomDatatype(s, &scan);
471  } else {
472  scan = s.c_str();
473  LOG(FATAL) << "unknown type " << s;
474  }
475  char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
476  uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
477  if (bits != 0) t.bits = bits;
478  int scalable_multiplier = 1;
479  if (strncmp(xdelim, "xvscale", 7) == 0) {
480  scalable_multiplier = -1;
481  xdelim += 7;
482  }
483  char* endpt = xdelim;
484  if (*xdelim == 'x') {
485  t.lanes = static_cast<uint16_t>(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10));
486  }
487  ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
488  return t;
489 }
490 
491 } // namespace runtime
492 
494 
495 } // namespace tvm
496 
497 namespace std {
498 template <>
499 struct hash<tvm::DataType> {
500  inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; }
501  std::size_t operator()(tvm::DataType const& dtype) const {
502  int a = dtype.code();
503  int b = dtype.bits();
504  int c = dtype.lanes();
505  int d = cantor_pairing_function(a, b);
506  return cantor_pairing_function(c, d);
507  }
508 };
509 } // namespace std
510 
511 #endif // TVM_RUNTIME_DATA_TYPE_H_
@ kTVMOpaqueHandle
Definition: c_runtime_api.h:178
int64_t tvm_index_t
type of array index.
Definition: c_runtime_api.h:89
Runtime primitive data type.
Definition: data_type.h:43
static DataType ShapeIndex()
Get the corresponding type of TVMShapeIndex.
Definition: data_type.h:281
bool is_handle() const
Definition: data_type.h:141
bool is_uint() const
Definition: data_type.h:139
int get_lanes_or_vscale_factor() const
Definition: data_type.h:114
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:236
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:181
bool is_scalable_vector() const
Definition: data_type.h:150
DataType & operator=(const DataType &rhs)
Assignment operator.
Definition: data_type.h:185
int bytes() const
Definition: data_type.h:96
TypeCode
Type code for the DataType.
Definition: data_type.h:53
@ kHandle
Definition: data_type.h:57
@ kUInt
Definition: data_type.h:55
@ kBFloat
Definition: data_type.h:58
@ kFloat
Definition: data_type.h:56
@ kE4M3Float
Definition: data_type.h:59
@ kCustomBegin
Definition: data_type.h:61
@ kE5M2Float
Definition: data_type.h:60
@ kInt
Definition: data_type.h:54
static DataType NVFloat8E4M3(int lanes=1)
Construct NV float8 e4m3 datatype.
Definition: data_type.h:249
bool is_bool() const
Definition: data_type.h:120
bool is_int() const
Definition: data_type.h:137
bool is_e4m3_float8() const
Definition: data_type.h:129
DataType with_bits(int bits) const
Create a new data type by changing the bits to a specified value.
Definition: data_type.h:176
static DataType NVFloat8E5M2(int lanes=1)
Construct NV float8 e5m2 datatype.
Definition: data_type.h:255
bool operator!=(const DataType &other) const
NotEqual comparator.
Definition: data_type.h:206
DataType()
default constructor
Definition: data_type.h:64
bool is_scalable_or_fixed_length_vector() const
Definition: data_type.h:143
int code() const
Definition: data_type.h:92
int lanes() const
Definition: data_type.h:98
DataType(int code, int bits, int lanes, bool is_scalable=false)
Constructor.
Definition: data_type.h:77
bool operator==(const DataType &other) const
Equal comparator.
Definition: data_type.h:197
bool is_float16() const
Definition: data_type.h:133
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:262
DataType with_lanes(int lanes) const
Create a new data type by changing the lanes to a specified value.
Definition: data_type.h:162
int vscale_factor() const
Definition: data_type.h:106
DataType(DLDataType dtype)
Constructor.
Definition: data_type.h:69
bool is_e5m2_float8() const
Definition: data_type.h:131
DataType with_scalable_vscale_factor(int vscale_factor) const
Create a new scalable vector data type by changing the vscale multiplier to a specified value....
Definition: data_type.h:168
bool is_fixed_length_vector() const
Definition: data_type.h:148
static DataType BFloat(int bits, int lanes=1)
Construct a bfloat type.
Definition: data_type.h:243
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:219
static DataType Void()
Construct a Void type.
Definition: data_type.h:276
bool is_vector() const
Definition: data_type.h:152
bool is_scalar() const
Definition: data_type.h:118
int bits() const
Definition: data_type.h:94
bool is_float8() const
Definition: data_type.h:124
bool is_bfloat16() const
Definition: data_type.h:135
bool is_vector_bool() const
Definition: data_type.h:154
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:271
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:227
bool is_void() const
Definition: data_type.h:156
bool is_float() const
Definition: data_type.h:122
std::string GetCustomTypeName(uint8_t type_code)
Runtime utility for getting custom type name from code.
bool GetCustomTypeRegistered(uint8_t type_code)
Runtime utility for checking whether custom type is registered.
DLDataType String2DLDataType(std::string s)
Convert a string to a TVM type.
Definition: data_type.h:429
std::string DLDataType2String(DLDataType t)
Convert a TVM type to a string.
Definition: data_type.h:422
uint8_t ParseCustomDatatype(const std::string &s, const char **scan)
Runtime utility for parsing string of the form "custom[<typename>]".
int GetVectorBytes(DataType dtype)
Get the number of bytes needed in a vector.
Definition: data_type.h:298
bool TypeMatch(DLDataType t, int code, int bits, int lanes=1)
Check whether type matches the given spec.
Definition: data_type.h:316
bool TypeEqual(DLDataType lhs, DLDataType rhs)
Check whether two types are equal.
Definition: data_type.h:324
std::ostream & operator<<(std::ostream &os, const ObjectRef &n)
Definition: repr_printer.h:97
const char * DLDataTypeCode2Str(DLDataTypeCode type_code)
Convert type code to its name.
Definition: data_type.h:372
Array< Tensor > scan(Array< Tensor > init, Array< Tensor > update, Array< Tensor > state_placeholder, Array< Tensor > inputs=Array< Tensor >(), std::string name="scan", std::string tag="", Map< String, ObjectRef > attrs={})
Construct new tensors by scan.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
runtime::DataType DataType
Definition: data_type.h:493