Source code for tvm_ffi._dtype
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""dtype class."""
# pylint: disable=invalid-name
from __future__ import annotations
from enum import IntEnum
from typing import Any, ClassVar
from . import core
class DataTypeCode(IntEnum):
"""DLDataTypeCode code in DLTensor."""
INT = 0
UINT = 1
FLOAT = 2
HANDLE = 3
BFLOAT = 4
Float8E3M4 = 7
Float8E4M3 = 8
Float8E4M3B11FNUZ = 9
Float8E4M3FN = 10
Float8E4M3FNUZ = 11
Float8E5M2 = 12
Float8E5M2FNUZ = 13
Float8E8M0FNU = 14
Float6E2M3FN = 15
Float6E3M2FN = 16
Float4E2M1FN = 17
[docs]
class dtype(str):
"""TVM FFI dtype class.
Parameters
----------
dtype_str
Note
----
This class subclasses str so it can be directly passed
into other array api's dtype arguments.
"""
__slots__ = ["_tvm_ffi_dtype"]
_tvm_ffi_dtype: core.DataType
_NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {}
def __new__(cls, content: Any) -> dtype:
content = str(content)
val = str.__new__(cls, content)
val._tvm_ffi_dtype = core.DataType(content)
return val
[docs]
@staticmethod
def from_dlpack_data_type(dltype_data_type: tuple[int, int, int]) -> dtype:
"""Create a dtype from a DLPack data type tuple.
Parameters
----------
dltype_data_type
The DLPack data type tuple (type_code, bits, lanes).
Returns
-------
The created dtype.
"""
cdtype = core._create_dtype_from_tuple(
core.DataType,
dltype_data_type[0],
dltype_data_type[1],
dltype_data_type[2],
)
val = str.__new__(dtype, str(cdtype))
val._tvm_ffi_dtype = cdtype
return val
def __repr__(self) -> str:
return f"dtype('{self}')"
[docs]
def with_lanes(self, lanes: int) -> dtype:
"""Create a new dtype with the given number of lanes.
Parameters
----------
lanes
The number of lanes.
Returns
-------
dtype
The new dtype with the given number of lanes.
"""
cdtype = core._create_dtype_from_tuple(
core.DataType,
self._tvm_ffi_dtype.type_code,
self._tvm_ffi_dtype.bits,
lanes,
)
val = str.__new__(dtype, str(cdtype))
val._tvm_ffi_dtype = cdtype
return val
@property
def itemsize(self) -> int:
return self._tvm_ffi_dtype.itemsize
@property
def type_code(self) -> int:
return self._tvm_ffi_dtype.type_code
@property
def bits(self) -> int:
return self._tvm_ffi_dtype.bits
@property
def lanes(self) -> int:
return self._tvm_ffi_dtype.lanes
try:
# this helps to make numpy as optional
# although almost in all cases we want numpy
import numpy as np
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.bool_)] = "bool"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int8)] = "int8"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int16)] = "int16"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int32)] = "int32"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int64)] = "int64"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint8)] = "uint8"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint16)] = "uint16"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint32)] = "uint32"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint64)] = "uint64"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float16)] = "float16"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64"
if hasattr(np, "float_"):
dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64"
except ImportError:
pass
try:
import ml_dtypes
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
except ImportError:
pass
core._set_class_dtype(dtype)