tvm
data_type.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*
20  * \file tvm/runtime/data_type.h
21  * \brief Primitive runtime data type.
22  */
23 // Acknowledgement: DataType structure design originates from Halide.
24 #ifndef TVM_RUNTIME_DATA_TYPE_H_
25 #define TVM_RUNTIME_DATA_TYPE_H_
26 
27 #include <tvm/ffi/container/shape.h>
28 #include <tvm/ffi/dtype.h>
29 #include <tvm/runtime/base.h>
30 #include <tvm/runtime/logging.h>
31 
32 #include <cstring>
33 #include <string>
34 #include <type_traits>
35 
36 namespace tvm {
37 namespace runtime {
38 
39 using tvm_index_t = ffi::Shape::index_type;
40 
47 class DataType {
48  public:
57  enum TypeCode {
58  kInt = kDLInt,
59  kUInt = kDLUInt,
60  kFloat = kDLFloat,
61  kHandle = kDLOpaqueHandle,
62  kBFloat = kDLBfloat,
63  kBool = kDLBool,
64  kFloat8_e3m4 = kDLFloat8_e3m4,
65  kFloat8_e4m3 = kDLFloat8_e4m3,
66  kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz,
67  kFloat8_e4m3fn = kDLFloat8_e4m3fn,
68  kFloat8_e4m3fnuz = kDLFloat8_e4m3fnuz,
69  kFloat8_e5m2 = kDLFloat8_e5m2,
70  kFloat8_e5m2fnuz = kDLFloat8_e5m2fnuz,
71  kFloat8_e8m0fnu = kDLFloat8_e8m0fnu,
72  kFloat6_e2m3fn = kDLFloat6_e2m3fn,
73  kFloat6_e3m2fn = kDLFloat6_e3m2fn,
74  kFloat4_e2m1fn = kDLFloat4_e2m1fn,
75  kCustomBegin = 129
76  };
78  DataType() { data_ = DataType::Void(); }
83  explicit DataType(DLDataType dtype) : data_(dtype) {}
91  DataType(int code, int bits, int lanes, bool is_scalable = false) {
92  data_.code = static_cast<uint8_t>(code);
93  data_.bits = static_cast<uint8_t>(bits);
94  if (is_scalable) {
95  TVM_FFI_ICHECK(lanes > 1) << "Invalid value for vscale factor" << lanes;
96  }
97  data_.lanes = is_scalable ? static_cast<uint16_t>(-lanes) : static_cast<uint16_t>(lanes);
98  if (code == kBFloat) {
99  TVM_FFI_ICHECK_EQ(bits, 16);
100  }
104  TVM_FFI_ICHECK_EQ(bits, 8);
105  }
106  if (code == kFloat6_e2m3fn || code == kFloat6_e3m2fn) {
107  TVM_FFI_ICHECK_EQ(bits, 6);
108  }
109  if (code == kFloat4_e2m1fn) {
110  TVM_FFI_ICHECK_EQ(bits, 4);
111  }
112  }
114  int code() const { return static_cast<int>(data_.code); }
116  int bits() const { return static_cast<int>(data_.bits); }
118  int bytes() const { return (bits() + 7) / 8; }
120  int lanes() const {
121  int lanes_as_int = static_cast<int16_t>(data_.lanes);
122  if (lanes_as_int < 0) {
123  TVM_FFI_THROW(InternalError)
124  << "Can't fetch the lanes of a scalable vector at a compile time.";
125  }
126  return lanes_as_int;
127  }
129  int vscale_factor() const {
130  int lanes_as_int = static_cast<int16_t>(data_.lanes);
131  if (lanes_as_int >= -1) {
132  TVM_FFI_THROW(InternalError) << "A fixed length vector doesn't have a vscale factor.";
133  }
134  return -lanes_as_int;
135  }
138  return is_scalable_vector() ? vscale_factor() : lanes();
139  }
141  bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
143  bool is_bool() const { return code() == DataType::kBool; }
145  bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() == 1); }
147  bool is_float() const { return code() == DataType::kFloat; }
149  bool is_bfloat() const { return code() == DataType::kBFloat; }
151  bool is_float8() const {
152  return bits() == 8 &&
157  }
159  bool is_float6() const {
160  return bits() == 6 &&
162  }
164  bool is_float4() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; }
166  bool is_float8_e3m4() const { return bits() == 8 && code() == DataType::kFloat8_e3m4; }
168  bool is_float8_e4m3() const { return bits() == 8 && code() == DataType::kFloat8_e4m3; }
170  bool is_float8_e4m3b11fnuz() const {
171  return bits() == 8 && code() == DataType::kFloat8_e4m3b11fnuz;
172  }
174  bool is_float8_e4m3fn() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fn; }
176  bool is_float8_e4m3fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fnuz; }
178  bool is_float8_e5m2() const { return bits() == 8 && code() == DataType::kFloat8_e5m2; }
180  bool is_float8_e5m2fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e5m2fnuz; }
182  bool is_float8_e8m0fnu() const { return bits() == 8 && code() == DataType::kFloat8_e8m0fnu; }
184  bool is_float6_e2m3fn() const { return bits() == 6 && code() == DataType::kFloat6_e2m3fn; }
186  bool is_float6_e3m2fn() const { return bits() == 6 && code() == DataType::kFloat6_e3m2fn; }
188  bool is_float4_e2m1fn() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; }
190  bool is_float16() const { return is_float() && bits() == 16; }
192  bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
194  bool is_int() const { return code() == DataType::kInt; }
196  bool is_uint() const { return code() == DataType::kUInt; }
198  bool is_handle() const { return code() == DataType::kHandle && !is_void(); }
201  int encoded_lanes = static_cast<int16_t>(data_.lanes);
202  return (encoded_lanes < -1) || (1 < encoded_lanes);
203  }
205  bool is_fixed_length_vector() const { return static_cast<int16_t>(data_.lanes) > 1; }
207  bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) < -1; }
209  bool is_vector() const { return lanes() > 1; }
213  bool is_void() const {
214  return code() == DataType::kHandle && bits() == 0 && static_cast<int16_t>(data_.lanes) == 0;
215  }
221  DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); }
228  return DataType(data_.code, data_.bits, -vscale_factor);
229  }
235  DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); }
240  DataType element_of() const { return with_lanes(1); }
244  DataType& operator=(const DataType& rhs) {
245  if (this == &rhs) {
246  return *this;
247  }
248  data_ = rhs.data_;
249  return *this;
250  }
256  bool operator==(const DataType& other) const {
257  return data_.code == other.data_.code && data_.bits == other.data_.bits &&
258  data_.lanes == other.data_.lanes;
259  }
265  bool operator!=(const DataType& other) const { return !operator==(other); }
270  operator DLDataType() const { return data_; }
271 
278  static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); }
286  static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
287  return DataType(kDLUInt, bits, lanes, is_scalable);
288  }
295  static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
302  static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
308  static DataType Float8E3M4(int lanes = 1) { return DataType(kFloat8_e3m4, 8, lanes); }
309 
315  static DataType Float8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3, 8, lanes); }
316 
322  static DataType Float8E4M3B11FNUZ(int lanes = 1) {
323  return DataType(kFloat8_e4m3b11fnuz, 8, lanes);
324  }
325 
331  static DataType Float8E4M3FN(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); }
332 
338  static DataType Float8E4M3FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3fnuz, 8, lanes); }
339 
345  static DataType Float8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); }
346 
352  static DataType Float8E5M2FNUZ(int lanes = 1) { return DataType(kFloat8_e5m2fnuz, 8, lanes); }
353 
359  static DataType Float8E8M0FNU(int lanes = 1) { return DataType(kFloat8_e8m0fnu, 8, lanes); }
360 
366  static DataType Float6E2M3FN(int lanes = 1) { return DataType(kFloat6_e2m3fn, 6, lanes); }
367 
373  static DataType Float6E3M2FN(int lanes = 1) { return DataType(kFloat6_e3m2fn, 6, lanes); }
374 
380  static DataType Float4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); }
387  static DataType Bool(int lanes = 1, bool is_scalable = false) {
388  return DataType(kDLBool, 8, lanes, is_scalable);
389  }
396  static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); }
401  static DataType Void() { return DataType(kHandle, 0, 0); }
406  static DataType ShapeIndex() {
407  if (std::is_signed<tvm_index_t>::value) {
408  return DataType::Int(sizeof(tvm_index_t) * 8);
409  } else {
410  return DataType::UInt(sizeof(tvm_index_t) * 8);
411  }
412  }
413 
414  private:
415  DLDataType data_;
416 };
417 
423 inline int GetVectorBytes(DataType dtype) {
424  int data_bits = dtype.bits() * dtype.lanes();
425  // allow bool to exist
426  if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
427  dtype == DataType::Int(1) || dtype == DataType::Float4E2M1FN() ||
428  dtype == DataType::Float6E2M3FN() || dtype == DataType::Float6E3M2FN()) {
429  return 1;
430  }
431  TVM_FFI_ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
432  return data_bits / 8;
433 }
434 
442 inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
443  return t.code == code && t.bits == bits && t.lanes == lanes;
444 }
450 inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
451  return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
452 }
453 
454 using ffi::DLDataTypeToString;
455 using ffi::StringToDLDataType;
456 
457 inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
458  return os << dtype.operator DLDataType();
459 }
460 } // namespace runtime
461 
463 
464 namespace ffi {
465 
466 // runtime::DataType
467 template <>
468 struct TypeTraits<runtime::DataType> : public TypeTraitsBase {
469  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType;
470 
471  TVM_FFI_INLINE static void CopyToAnyView(const runtime::DataType& src, TVMFFIAny* result) {
472  // clear padding part to ensure the equality check can always check the v_uint64 part
473  result->v_uint64 = 0;
474  result->zero_padding = 0;
475  result->type_index = TypeIndex::kTVMFFIDataType;
476  result->v_dtype = src;
477  }
478 
479  TVM_FFI_INLINE static void MoveToAny(runtime::DataType src, TVMFFIAny* result) {
480  // clear padding part to ensure the equality check can always check the v_uint64 part
481  result->v_uint64 = 0;
482  result->zero_padding = 0;
483  result->type_index = TypeIndex::kTVMFFIDataType;
484  result->v_dtype = src;
485  }
486 
487  TVM_FFI_INLINE static std::optional<runtime::DataType> TryCastFromAnyView(const TVMFFIAny* src) {
488  auto opt_dtype = TypeTraits<DLDataType>::TryCastFromAnyView(src);
489  if (opt_dtype) {
490  return runtime::DataType(opt_dtype.value());
491  }
492  return std::nullopt;
493  }
494 
495  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
496  return TypeTraits<DLDataType>::CheckAnyStrict(src);
497  }
498 
499  TVM_FFI_INLINE static runtime::DataType CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
500  return runtime::DataType(TypeTraits<DLDataType>::CopyFromAnyViewAfterCheck(src));
501  }
502 
503  TVM_FFI_INLINE static std::string TypeStr() { return ffi::StaticTypeKey::kTVMFFIDataType; }
504 
505  TVM_FFI_INLINE static std::string TypeSchema() {
506  return R"({"type":")" + std::string(ffi::StaticTypeKey::kTVMFFIDataType) + R"("})";
507  }
508 };
509 
510 } // namespace ffi
511 } // namespace tvm
512 
513 namespace std {
514 template <>
515 struct hash<tvm::DataType> {
516  inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; }
517  std::size_t operator()(tvm::DataType const& dtype) const {
518  int a = dtype.code();
519  int b = dtype.bits();
520  int c = dtype.lanes();
521  int d = cantor_pairing_function(a, b);
522  return cantor_pairing_function(c, d);
523  }
524 };
525 } // namespace std
526 
527 #endif // TVM_RUNTIME_DATA_TYPE_H_
Runtime primitive data type.
Definition: data_type.h:47
static DataType ShapeIndex()
Get the corresponding type of TVMShapeIndex.
Definition: data_type.h:406
bool is_handle() const
Definition: data_type.h:198
static DataType Float8E4M3FNUZ(int lanes=1)
Construct float8 e4m3fnuz datatype.
Definition: data_type.h:338
bool is_uint() const
Definition: data_type.h:196
int get_lanes_or_vscale_factor() const
Definition: data_type.h:137
bool is_float8_e5m2() const
Definition: data_type.h:178
static DataType Float4E2M1FN(int lanes=1)
Construct float4 e2m1fn datatype.
Definition: data_type.h:380
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:295
bool is_float4_e2m1fn() const
Definition: data_type.h:188
bool is_float6() const
Definition: data_type.h:159
static DataType Float8E8M0FNU(int lanes=1)
Construct float8 e8m0fnu datatype.
Definition: data_type.h:359
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:240
bool is_scalable_vector() const
Definition: data_type.h:207
bool is_float8_e4m3() const
Definition: data_type.h:168
DataType & operator=(const DataType &rhs)
Assignment operator.
Definition: data_type.h:244
static DataType Float8E4M3FN(int lanes=1)
Construct float8 e4m3fn datatype.
Definition: data_type.h:331
int bytes() const
Definition: data_type.h:118
TypeCode
Type code for the DataType.
Definition: data_type.h:57
@ kFloat8_e4m3b11fnuz
Definition: data_type.h:66
@ kBool
Definition: data_type.h:63
@ kHandle
Definition: data_type.h:61
@ kUInt
Definition: data_type.h:59
@ kFloat8_e4m3
Definition: data_type.h:65
@ kFloat6_e3m2fn
Definition: data_type.h:73
@ kFloat6_e2m3fn
Definition: data_type.h:72
@ kBFloat
Definition: data_type.h:62
@ kFloat
Definition: data_type.h:60
@ kFloat8_e4m3fn
Definition: data_type.h:67
@ kFloat8_e4m3fnuz
Definition: data_type.h:68
@ kCustomBegin
Definition: data_type.h:75
@ kFloat8_e8m0fnu
Definition: data_type.h:71
@ kFloat4_e2m1fn
Definition: data_type.h:74
@ kInt
Definition: data_type.h:58
@ kFloat8_e3m4
Definition: data_type.h:64
@ kFloat8_e5m2fnuz
Definition: data_type.h:70
@ kFloat8_e5m2
Definition: data_type.h:69
bool is_float6_e3m2fn() const
Definition: data_type.h:186
static DataType Float8E5M2FNUZ(int lanes=1)
Construct float8 e5m2fnuz datatype.
Definition: data_type.h:352
bool is_float8_e3m4() const
Definition: data_type.h:166
bool is_bool() const
Definition: data_type.h:143
static DataType Float8E4M3B11FNUZ(int lanes=1)
Construct float8 e4m3b11fnuz datatype.
Definition: data_type.h:322
bool is_int() const
Definition: data_type.h:194
DataType with_bits(int bits) const
Create a new data type by changing bits to a specified value.
Definition: data_type.h:235
bool operator!=(const DataType &other) const
NotEqual comparator.
Definition: data_type.h:265
DataType()
default constructor
Definition: data_type.h:78
bool is_scalable_or_fixed_length_vector() const
Definition: data_type.h:200
int code() const
Definition: data_type.h:114
int lanes() const
Definition: data_type.h:120
static DataType Float8E5M2(int lanes=1)
Construct float8 e5m2 datatype.
Definition: data_type.h:345
DataType(int code, int bits, int lanes, bool is_scalable=false)
Constructor.
Definition: data_type.h:91
bool is_float8_e8m0fnu() const
Definition: data_type.h:182
static DataType Float6E2M3FN(int lanes=1)
Construct float6 e2m3fn datatype.
Definition: data_type.h:366
bool operator==(const DataType &other) const
Equal comparator.
Definition: data_type.h:256
bool is_float16() const
Definition: data_type.h:190
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:387
DataType with_lanes(int lanes) const
Create a new data type by changing lanes to a specified value.
Definition: data_type.h:221
static DataType Float6E3M2FN(int lanes=1)
Construct float6 e3m2fn datatype.
Definition: data_type.h:373
int vscale_factor() const
Definition: data_type.h:129
DataType(DLDataType dtype)
Constructor.
Definition: data_type.h:83
bool is_float8_e4m3b11fnuz() const
Definition: data_type.h:170
DataType with_scalable_vscale_factor(int vscale_factor) const
Create a new scalable vector data type by changing the vscale multiplier to a specified value.
Definition: data_type.h:227
bool is_float8_e4m3fnuz() const
Definition: data_type.h:176
bool is_fixed_length_vector() const
Definition: data_type.h:205
bool is_predicate_dtype() const
Definition: data_type.h:145
bool is_bfloat() const
Definition: data_type.h:149
static DataType BFloat(int bits, int lanes=1)
Construct a bfloat type.
Definition: data_type.h:302
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
static DataType Void()
Construct a Void type.
Definition: data_type.h:401
bool is_vector() const
Definition: data_type.h:209
bool is_float6_e2m3fn() const
Definition: data_type.h:184
bool is_scalar() const
Definition: data_type.h:141
int bits() const
Definition: data_type.h:116
bool is_float8() const
Definition: data_type.h:151
bool is_bfloat16() const
Definition: data_type.h:192
bool is_float8_e4m3fn() const
Definition: data_type.h:174
bool is_vector_bool() const
Definition: data_type.h:211
bool is_float4() const
Definition: data_type.h:164
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:396
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:286
static DataType Float8E4M3(int lanes=1)
Construct float8 e4m3 datatype.
Definition: data_type.h:315
static DataType Float8E3M4(int lanes=1)
Construct float8 e3m4 datatype.
Definition: data_type.h:308
bool is_void() const
Definition: data_type.h:213
bool is_float8_e5m2fnuz() const
Definition: data_type.h:180
bool is_float() const
Definition: data_type.h:147
std::ostream & operator<<(std::ostream &os, const DataType &dtype)
Definition: data_type.h:457
ffi::Shape::index_type tvm_index_t
Definition: data_type.h:39
int GetVectorBytes(DataType dtype)
Get the number of bytes needed in a vector.
Definition: data_type.h:423
bool TypeMatch(DLDataType t, int code, int bits, int lanes=1)
Check whether type matches the given spec.
Definition: data_type.h:442
bool TypeEqual(DLDataType lhs, DLDataType rhs)
Check whether two types are equal.
Definition: data_type.h:450
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
runtime::DataType DataType
Definition: data_type.h:462
static TVM_FFI_INLINE void MoveToAny(runtime::DataType src, TVMFFIAny *result)
Definition: data_type.h:479
static TVM_FFI_INLINE void CopyToAnyView(const runtime::DataType &src, TVMFFIAny *result)
Definition: data_type.h:471
static TVM_FFI_INLINE std::string TypeSchema()
Definition: data_type.h:505
static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny *src)
Definition: data_type.h:495
static TVM_FFI_INLINE runtime::DataType CopyFromAnyViewAfterCheck(const TVMFFIAny *src)
Definition: data_type.h:499
static TVM_FFI_INLINE std::string TypeStr()
Definition: data_type.h:503
static TVM_FFI_INLINE std::optional< runtime::DataType > TryCastFromAnyView(const TVMFFIAny *src)
Definition: data_type.h:487