tvm
data_type.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*
20  * \file tvm/runtime/data_type.h
21  * \brief Primitive runtime data type.
22  */
23 // Acknowledgement: DataType structure design originates from Halide.
24 #ifndef TVM_RUNTIME_DATA_TYPE_H_
25 #define TVM_RUNTIME_DATA_TYPE_H_
26 
28 #include <tvm/runtime/logging.h>
29 
30 #include <cstring>
31 #include <string>
32 #include <type_traits>
33 
34 namespace tvm {
35 namespace runtime {
36 
43 class DataType {
44  public:
53  enum TypeCode {
54  kInt = kDLInt,
55  kUInt = kDLUInt,
56  kFloat = kDLFloat,
58  kBFloat = kDLBfloat,
59  kE4M3Float = 6U,
60  kE5M2Float = 7U,
61  kCustomBegin = 129
62  };
64  DataType() { data_ = DataType::Void(); }
69  explicit DataType(DLDataType dtype) : data_(dtype) {}
77  DataType(int code, int bits, int lanes, bool is_scalable = false) {
78  data_.code = static_cast<uint8_t>(code);
79  data_.bits = static_cast<uint8_t>(bits);
80  if (is_scalable) {
81  ICHECK(lanes > 1) << "Invalid value for vscale factor" << lanes;
82  }
83  data_.lanes = is_scalable ? static_cast<uint16_t>(-lanes) : static_cast<uint16_t>(lanes);
84  if (code == kBFloat) {
85  ICHECK_EQ(bits, 16);
86  }
87  if (code == kE4M3Float || code == kE5M2Float) {
88  ICHECK_EQ(bits, 8);
89  }
90  }
92  int code() const { return static_cast<int>(data_.code); }
94  int bits() const { return static_cast<int>(data_.bits); }
96  int bytes() const { return (bits() + 7) / 8; }
98  int lanes() const {
99  int lanes_as_int = static_cast<int16_t>(data_.lanes);
100  if (lanes_as_int < 0) {
101  LOG(FATAL) << "Can't fetch the lanes of a scalable vector at a compile time.";
102  }
103  return lanes_as_int;
104  }
106  int vscale_factor() const {
107  int lanes_as_int = static_cast<int16_t>(data_.lanes);
108  if (lanes_as_int >= -1) {
109  LOG(FATAL) << "A fixed length vector doesn't have a vscale factor.";
110  }
111  return -lanes_as_int;
112  }
115  return is_scalable_vector() ? vscale_factor() : lanes();
116  }
118  bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
120  bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
122  bool is_float() const { return code() == DataType::kFloat; }
124  bool is_float8() const {
125  return (code() == DataType::kFloat || code() == DataType::kE4M3Float ||
126  code() == DataType::kE5M2Float) &&
127  bits() == 8;
128  }
129  bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }
130 
131  bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
133  bool is_float16() const { return is_float() && bits() == 16; }
135  bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
137  bool is_int() const { return code() == DataType::kInt; }
139  bool is_uint() const { return code() == DataType::kUInt; }
141  bool is_handle() const { return code() == DataType::kHandle && !is_void(); }
144  int encoded_lanes = static_cast<int16_t>(data_.lanes);
145  return (encoded_lanes < -1) || (1 < encoded_lanes);
146  }
148  bool is_fixed_length_vector() const { return static_cast<int16_t>(data_.lanes) > 1; }
150  bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) < -1; }
152  bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
154  bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; }
160  DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); }
167  return DataType(data_.code, data_.bits, -vscale_factor);
168  }
174  DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); }
179  DataType element_of() const { return with_lanes(1); }
183  DataType& operator=(const DataType& rhs) {
184  if (this == &rhs) {
185  return *this;
186  }
187  data_ = rhs.data_;
188  return *this;
189  }
195  bool operator==(const DataType& other) const {
196  return data_.code == other.data_.code && data_.bits == other.data_.bits &&
197  data_.lanes == other.data_.lanes;
198  }
204  bool operator!=(const DataType& other) const { return !operator==(other); }
209  operator DLDataType() const { return data_; }
210 
217  static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); }
225  static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
226  return DataType(kDLUInt, bits, lanes, is_scalable);
227  }
234  static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
241  static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
247  static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, lanes); }
253  static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
260  static DataType Bool(int lanes = 1, bool is_scalable = false) {
261  return DataType::UInt(1, lanes, is_scalable);
262  }
269  static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); }
274  static DataType Void() { return DataType(kHandle, 0, 0); }
279  static DataType ShapeIndex() {
280  if (std::is_signed<tvm_index_t>::value) {
281  return DataType::Int(sizeof(tvm_index_t) * 8);
282  } else {
283  return DataType::UInt(sizeof(tvm_index_t) * 8);
284  }
285  }
286 
287  private:
288  DLDataType data_;
289 };
290 
296 inline int GetVectorBytes(DataType dtype) {
297  int data_bits = dtype.bits() * dtype.lanes();
298  // allow bool to exist
299  if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
300  dtype == DataType::Int(1)) {
301  return 1;
302  }
303  ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
304  return data_bits / 8;
305 }
306 
314 inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
315  return t.code == code && t.bits == bits && t.lanes == lanes;
316 }
322 inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
323  return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
324 }
325 
331 TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
332 
338 TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
339 
346 TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
347 
353 inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code);
354 
360 inline DLDataType String2DLDataType(std::string s);
361 
367 inline std::string DLDataType2String(DLDataType t);
368 
369 // implementation details
370 inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
371  switch (static_cast<int>(type_code)) {
372  case kDLInt:
373  return "int";
374  case kDLUInt:
375  return "uint";
376  case kDLFloat:
377  return "float";
378  case DataType::kHandle:
379  return "handle";
380  case kDLBfloat:
381  return "bfloat";
383  return "e4m3_float";
385  return "e5m2_float";
386  default:
387  LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
388  }
389  throw;
390 }
391 
392 inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
393  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
394  os << "bool";
395  return os;
396  }
397  if (DataType(t).is_void()) {
398  return os << "void";
399  }
400  if (t.code < DataType::kCustomBegin) {
401  os << DLDataTypeCode2Str(static_cast<DLDataTypeCode>(t.code));
402  } else {
403  os << "custom[" << GetCustomTypeName(t.code) << "]";
404  }
405  if (t.code == kTVMOpaqueHandle) return os;
406  int16_t lanes = static_cast<int16_t>(t.lanes);
407  os << static_cast<int>(t.bits);
408  if (lanes > 1) {
409  os << 'x' << lanes;
410  } else if (lanes < -1) {
411  os << "xvscalex" << -lanes;
412  }
413  return os;
414 }
415 
416 inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
417  return os << dtype.operator DLDataType();
418 }
419 
420 inline std::string DLDataType2String(DLDataType t) {
421  if (t.bits == 0) return "";
422  std::ostringstream os;
423  os << t;
424  return os.str();
425 }
426 
427 inline DLDataType String2DLDataType(std::string s) {
428  DLDataType t;
429  // handle void type
430  if (s.length() == 0 || s == "void") {
431  t = DataType::Void();
432  return t;
433  }
434  t.bits = 32;
435  t.lanes = 1;
436  const char* scan;
437  if (s.substr(0, 3) == "int") {
438  t.code = kDLInt;
439  scan = s.c_str() + 3;
440  } else if (s.substr(0, 4) == "uint") {
441  t.code = kDLUInt;
442  scan = s.c_str() + 4;
443  } else if (s.substr(0, 5) == "float") {
444  t.code = kDLFloat;
445  scan = s.c_str() + 5;
446  } else if (s.substr(0, 6) == "handle") {
447  t.code = kTVMOpaqueHandle;
448  t.bits = 64; // handle uses 64 bit by default.
449  scan = s.c_str() + 6;
450  } else if (s == "bool") {
451  t.code = kDLUInt;
452  t.bits = 1;
453  t.lanes = 1;
454  return t;
455  } else if (s.substr(0, 6) == "bfloat") {
456  t.code = DataType::kBFloat;
457  t.bits = 16;
458  scan = s.c_str() + 6;
459  } else if (s.substr(0, 10) == "e4m3_float") {
460  t.code = DataType::kE4M3Float;
461  t.bits = 8;
462  scan = s.c_str() + 10;
463  } else if (s.substr(0, 10) == "e5m2_float") {
464  t.code = DataType::kE5M2Float;
465  t.bits = 8;
466  scan = s.c_str() + 10;
467  } else if (s.substr(0, 6) == "custom") {
468  t.code = ParseCustomDatatype(s, &scan);
469  } else {
470  scan = s.c_str();
471  LOG(FATAL) << "unknown type " << s;
472  }
473  char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
474  uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
475  if (bits != 0) t.bits = bits;
476  int scalable_multiplier = 1;
477  if (strncmp(xdelim, "xvscale", 7) == 0) {
478  scalable_multiplier = -1;
479  xdelim += 7;
480  }
481  char* endpt = xdelim;
482  if (*xdelim == 'x') {
483  t.lanes = static_cast<uint16_t>(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10));
484  }
485  ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
486  return t;
487 }
488 
489 } // namespace runtime
490 
492 
493 } // namespace tvm
494 
495 namespace std {
496 template <>
497 struct hash<tvm::DataType> {
498  inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; }
499  std::size_t operator()(tvm::DataType const& dtype) const {
500  int a = dtype.code();
501  int b = dtype.bits();
502  int c = dtype.lanes();
503  int d = cantor_pairing_function(a, b);
504  return cantor_pairing_function(c, d);
505  }
506 };
507 } // namespace std
508 
509 #endif // TVM_RUNTIME_DATA_TYPE_H_
@ kTVMOpaqueHandle
Definition: c_runtime_api.h:177
int64_t tvm_index_t
type of array index.
Definition: c_runtime_api.h:88
Runtime primitive data type.
Definition: data_type.h:43
static DataType ShapeIndex()
Get the corresponding type of TVMShapeIndex.
Definition: data_type.h:279
bool is_handle() const
Definition: data_type.h:141
bool is_uint() const
Definition: data_type.h:139
int get_lanes_or_vscale_factor() const
Definition: data_type.h:114
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:234
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:179
bool is_scalable_vector() const
Definition: data_type.h:150
DataType & operator=(const DataType &rhs)
Assignment operator.
Definition: data_type.h:183
int bytes() const
Definition: data_type.h:96
TypeCode
Type code for the DataType.
Definition: data_type.h:53
@ kHandle
Definition: data_type.h:57
@ kUInt
Definition: data_type.h:55
@ kBFloat
Definition: data_type.h:58
@ kFloat
Definition: data_type.h:56
@ kE4M3Float
Definition: data_type.h:59
@ kCustomBegin
Definition: data_type.h:61
@ kE5M2Float
Definition: data_type.h:60
@ kInt
Definition: data_type.h:54
static DataType NVFloat8E4M3(int lanes=1)
Construct NV float8 e4m3 datatype.
Definition: data_type.h:247
bool is_bool() const
Definition: data_type.h:120
bool is_int() const
Definition: data_type.h:137
bool is_e4m3_float8() const
Definition: data_type.h:129
DataType with_bits(int bits) const
Create a new data type by changing the bit width to a specified value.
Definition: data_type.h:174
static DataType NVFloat8E5M2(int lanes=1)
Construct NV float8 e5m2 datatype.
Definition: data_type.h:253
bool operator!=(const DataType &other) const
NotEqual comparator.
Definition: data_type.h:204
DataType()
default constructor
Definition: data_type.h:64
bool is_scalable_or_fixed_length_vector() const
Definition: data_type.h:143
int code() const
Definition: data_type.h:92
int lanes() const
Definition: data_type.h:98
DataType(int code, int bits, int lanes, bool is_scalable=false)
Constructor.
Definition: data_type.h:77
bool operator==(const DataType &other) const
Equal comparator.
Definition: data_type.h:195
bool is_float16() const
Definition: data_type.h:133
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:260
DataType with_lanes(int lanes) const
Create a new data type by changing the number of lanes to a specified value.
Definition: data_type.h:160
int vscale_factor() const
Definition: data_type.h:106
DataType(DLDataType dtype)
Constructor.
Definition: data_type.h:69
bool is_e5m2_float8() const
Definition: data_type.h:131
DataType with_scalable_vscale_factor(int vscale_factor) const
Create a new scalable vector data type by changing the vscale multiplier to a specified value....
Definition: data_type.h:166
bool is_fixed_length_vector() const
Definition: data_type.h:148
static DataType BFloat(int bits, int lanes=1)
Construct a bfloat type.
Definition: data_type.h:241
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:217
static DataType Void()
Construct a Void type.
Definition: data_type.h:274
bool is_scalar() const
Definition: data_type.h:118
int bits() const
Definition: data_type.h:94
bool is_float8() const
Definition: data_type.h:124
bool is_bfloat16() const
Definition: data_type.h:135
bool is_vector_bool() const
Definition: data_type.h:152
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:269
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:225
bool is_void() const
Definition: data_type.h:154
bool is_float() const
Definition: data_type.h:122
std::string GetCustomTypeName(uint8_t type_code)
Runtime utility for getting custom type name from code.
bool GetCustomTypeRegistered(uint8_t type_code)
Runtime utility for checking whether custom type is registered.
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: data_type.h:427
std::string DLDataType2String(DLDataType t)
convert a TVM type to string.
Definition: data_type.h:420
uint8_t ParseCustomDatatype(const std::string &s, const char **scan)
Runtime utility for parsing string of the form "custom[<typename>]".
int GetVectorBytes(DataType dtype)
Get the number of bytes needed in a vector.
Definition: data_type.h:296
bool TypeMatch(DLDataType t, int code, int bits, int lanes=1)
Check whether type matches the given spec.
Definition: data_type.h:314
bool TypeEqual(DLDataType lhs, DLDataType rhs)
Check whether two types are equal.
Definition: data_type.h:322
std::ostream & operator<<(std::ostream &os, const ObjectRef &n)
Definition: repr_printer.h:97
const char * DLDataTypeCode2Str(DLDataTypeCode type_code)
Convert type code to its name.
Definition: data_type.h:370
Array< Tensor > scan(Array< Tensor > init, Array< Tensor > update, Array< Tensor > state_placeholder, Array< Tensor > inputs=Array< Tensor >(), std::string name="scan", std::string tag="", Map< String, ObjectRef > attrs={})
Construct new tensors by scan.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
runtime::DataType DataType
Definition: data_type.h:491