Source code for tvm_ffi._tensor

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor related objects and functions."""

from __future__ import annotations

# This module is named _tensor.py (with a leading underscore) to keep the name
# free in case we later want to expose a ``tensor`` function in the root namespace.
from numbers import Integral
from typing import Any

from . import _ffi_api, core, registry
from .core import (
    Device,
    DLDeviceType,
    PyNativeObject,
    Tensor,
    _shape_obj_get_py_tuple,
    from_dlpack,
)


@registry.register_object("ffi.Shape")
class Shape(tuple, PyNativeObject):
    """Shape tuple that represents ``ffi::Shape`` returned by an FFI call.

    Parameters
    ----------
    content: tuple of int
        The dimensions of the shape. Any iterable of integers is accepted;
        it is materialized into a tuple before validation.

    Raises
    ------
    ValueError
        If any element of ``content`` is not an instance of
        :class:`numbers.Integral`.

    Notes
    -----
    This class subclasses :class:`tuple` so it can be used in most places where
    :class:`tuple` is used in Python array APIs.

    """

    # Handle to the underlying FFI object backing this tuple.
    __tvm_ffi_object__: Any

    def __new__(cls, content: tuple[int, ...]) -> Shape:
        """Validate ``content`` and build a new shape backed by an FFI object."""
        # Materialize first: the validation pass below iterates the input, which
        # would exhaust a one-shot iterable (e.g. a generator) and silently
        # produce an empty shape when the values are consumed a second time.
        dims = tuple(content)
        if not all(isinstance(dim, Integral) for dim in dims):
            raise ValueError("Shape must be a tuple of integers")
        val: Shape = tuple.__new__(cls, dims)
        # NOTE(review): presumably provided by PyNativeObject and attaches the
        # FFI-side object created by ``_ffi_api.Shape(*dims)`` -- confirm in core.
        val.__init_tvm_ffi_object_by_constructor__(_ffi_api.Shape, *dims)
        return val

    # The first parameter is the class, not an instance, hence the suppression.
    # pylint: disable=no-self-argument
    def __from_tvm_ffi_object__(cls, obj: Any) -> Shape:
        """Construct from a given tvm object.

        The dimensions are read from ``obj`` via ``_shape_obj_get_py_tuple`` and
        ``obj`` itself is stored on the new instance as ``__tvm_ffi_object__``.
        No integer validation is performed on this path.
        """
        content = _shape_obj_get_py_tuple(obj)
        val: Shape = tuple.__new__(cls, content)  # type: ignore[arg-type]
        val.__tvm_ffi_object__ = obj  # type: ignore[attr-defined]
        return val


def device(device_type: str | int | DLDeviceType, index: int | None = None) -> Device:
    """Create a TVM FFI device from a device type and an optional index.

    Parameters
    ----------
    device_type: str or int or DLDeviceType
        The device type or name.

    index: int, optional
        The device index.

    Returns
    -------
    device: tvm_ffi.Device
        The device built from ``device_type`` and ``index``.

    Examples
    --------
    A device can be described by the string form of its device type,
    with or without a separate index argument.

    .. code-block:: python

      assert tvm_ffi.device("cuda:0") == tvm_ffi.device("cuda", 0)
      assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0)

    """
    # Look the class up through ``core`` at call time (rather than using the
    # imported ``Device`` name) so that an override of ``core._CLASS_DEVICE``
    # is picked up here.
    device_cls = core._CLASS_DEVICE
    return device_cls(device_type, index)


__all__ = [
    "DLDeviceType",
    "Device",
    "Tensor",
    "device",
    "from_dlpack",
]