Source code for tvm_ffi._dtype

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Lightweight dtype wrapper for TVM FFI."""

# pylint: disable=invalid-name
from __future__ import annotations

from enum import IntEnum
from typing import Any, ClassVar

from . import core


class DataTypeCode(IntEnum):
    """DLDataTypeCode code in DLTensor."""

    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    BFLOAT = 4
    Float8E3M4 = 7
    Float8E4M3 = 8
    Float8E4M3B11FNUZ = 9
    Float8E4M3FN = 10
    Float8E4M3FNUZ = 11
    Float8E5M2 = 12
    Float8E5M2FNUZ = 13
    Float8E8M0FNU = 14
    Float6E2M3FN = 15
    Float6E3M2FN = 16
    Float4E2M1FN = 17


[docs] class dtype(str): """Lightweight data type in TVM FFI. It behaves like a Python :class:`str` but also carries an internal FFI representation. You can construct it from strings, NumPy/ML dtypes, or via :py:meth:`from_dlpack_data_type`. Parameters ---------- dtype_str The string representation of the dtype. Examples -------- .. code-block:: python import tvm_ffi # Create from string f32 = tvm_ffi.dtype("float32") assert f32.bits == 32 assert f32.itemsize == 4 # Adjust lanes to create vector types v4f32 = f32.with_lanes(4) assert v4f32 == "float32x4" # Round-trip from a DLPack (code, bits, lanes) triple f16 = tvm_ffi.dtype.from_dlpack_data_type((2, 16, 1)) assert f16 == "float16" Note ---- This class subclasses str so it can be directly passed into other array api's dtype arguments. """ __slots__ = ["_tvm_ffi_dtype"] _tvm_ffi_dtype: core.DataType _NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {} def __new__(cls, content: Any) -> dtype: content = str(content) val = str.__new__(cls, content) val._tvm_ffi_dtype = core.DataType(content) return val
[docs] @staticmethod def from_dlpack_data_type(dltype_data_type: tuple[int, int, int]) -> dtype: """Create a dtype from a DLPack data type tuple. Parameters ---------- dltype_data_type The DLPack data type tuple ``(type_code, bits, lanes)``. Returns ------- dtype The created dtype. Examples -------- .. code-block:: python import tvm_ffi # Create float16 and int8 directly from DLPack triples f16 = tvm_ffi.dtype.from_dlpack_data_type((2, 16, 1)) i8 = tvm_ffi.dtype.from_dlpack_data_type((0, 8, 1)) assert f16 == "float16" assert i8 == "int8" See Also -------- :py:class:`tvm_ffi.dtype` User-facing dtype wrapper. :py:meth:`tvm_ffi.dtype.with_lanes` Create vector dtypes from a scalar base. """ cdtype = core._create_cdtype_from_tuple( core.DataType, dltype_data_type[0], dltype_data_type[1], dltype_data_type[2], ) val = str.__new__(dtype, str(cdtype)) val._tvm_ffi_dtype = cdtype return val
def __repr__(self) -> str: return f"dtype('{self}')"
[docs] def with_lanes(self, lanes: int) -> dtype: """Create a new dtype with the given number of lanes. Parameters ---------- lanes The number of lanes for the resulting vector type. Returns ------- dtype The new dtype with the given number of lanes. Examples -------- .. code-block:: python import tvm_ffi f32 = tvm_ffi.dtype("float32") v4f32 = f32.with_lanes(4) assert v4f32 == "float32x4" assert v4f32.bits == f32.bits and v4f32.lanes == 4 See Also -------- :py:meth:`tvm_ffi.dtype.from_dlpack_data_type` Construct from a DLPack ``(code, bits, lanes)`` triple. """ cdtype = core._create_cdtype_from_tuple( core.DataType, self._tvm_ffi_dtype.type_code, self._tvm_ffi_dtype.bits, lanes, ) val = str.__new__(dtype, str(cdtype)) val._tvm_ffi_dtype = cdtype return val
@property def itemsize(self) -> int: """Size of one element in bytes. The size is computed as ``bits * lanes // 8``. When the number of lanes is greater than 1, the ``itemsize`` represents the byte size of the vector element. Examples -------- .. code-block:: python import tvm_ffi assert tvm_ffi.dtype("float32").itemsize == 4 assert tvm_ffi.dtype("float32").with_lanes(4).itemsize == 16 See Also -------- :py:attr:`tvm_ffi.dtype.bits` Bit width of the scalar base type. :py:attr:`tvm_ffi.dtype.lanes` Number of lanes for vector types. :py:meth:`tvm_ffi.dtype.with_lanes` Create a vector dtype from a scalar base. """ return self._tvm_ffi_dtype.itemsize @property def type_code(self) -> int: """Integer DLDataTypeCode of the scalar base type. Examples -------- .. code-block:: python import tvm_ffi f32 = tvm_ffi.dtype("float32") # The type code is an integer following DLPack conventions assert isinstance(f32.type_code, int) # Consistent with constructing from an explicit (code, bits, lanes) assert f32.type_code == tvm_ffi.dtype.from_dlpack_data_type((2, 32, 1)).type_code See Also -------- :py:meth:`tvm_ffi.dtype.from_dlpack_data_type` Construct a dtype from a DLPack ``(code, bits, lanes)`` triple. """ return self._tvm_ffi_dtype.type_code @property def bits(self) -> int: """Number of bits of the scalar base type. Examples -------- .. code-block:: python import tvm_ffi assert tvm_ffi.dtype("int8").bits == 8 v4f32 = tvm_ffi.dtype("float32").with_lanes(4) assert v4f32.bits == 32 # per-lane bit width See Also -------- :py:attr:`tvm_ffi.dtype.itemsize` Byte size accounting for lanes. :py:attr:`tvm_ffi.dtype.lanes` Number of lanes for vector types. """ return self._tvm_ffi_dtype.bits @property def lanes(self) -> int: """Number of lanes (for vector types). Returns ``1`` for scalar dtypes and the lane count for vector dtypes created via :py:meth:`tvm_ffi.dtype.with_lanes`. Examples -------- .. code-block:: python import tvm_ffi assert tvm_ffi.dtype("float32").lanes == 1 assert tvm_ffi.dtype("float32").with_lanes(4).lanes == 4 See Also -------- :py:meth:`tvm_ffi.dtype.with_lanes` Create a vector dtype from a scalar base. :py:attr:`tvm_ffi.dtype.itemsize` Byte size accounting for lanes. """ return self._tvm_ffi_dtype.lanes
try: # this helps to make numpy as optional # although almost in all cases we want numpy import numpy as np dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.bool_)] = "bool" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int8)] = "int8" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int16)] = "int16" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int32)] = "int32" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int64)] = "int64" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint8)] = "uint8" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint16)] = "uint16" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint32)] = "uint32" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint64)] = "uint64" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float16)] = "float16" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64" if hasattr(np, "float_"): dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64" except ImportError: pass try: import ml_dtypes dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn" dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2" if hasattr(ml_dtypes, "float4_e2m1fn"): # ml_dtypes >= 0.5.0 dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn" except ImportError: pass core._set_class_dtype(dtype)