Source code for tvm_ffi.container

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Container classes."""

from __future__ import annotations

import itertools
import operator
import sys
from typing import (
    Any,
    Callable,
    SupportsIndex,
    TypeVar,
    cast,
    overload,
)

from . import _ffi_api, core
from .registry import register_object

if sys.version_info >= (3, 9):
    # PEP 585 generics
    from collections.abc import (
        ItemsView as ItemsViewBase,
    )
    from collections.abc import (
        Iterable,
        Iterator,
        Mapping,
        Sequence,
    )
    from collections.abc import (
        KeysView as KeysViewBase,
    )
    from collections.abc import (
        ValuesView as ValuesViewBase,
    )
else:  # Python 3.8
    # workarounds for python 3.8
    # typing-module generics (subscriptable on 3.8)
    from typing import (
        ItemsView as ItemsViewBase,
    )
    from typing import (
        Iterable,
        Iterator,
        Mapping,
        Sequence,
    )
    from typing import (
        KeysView as KeysViewBase,
    )
    from typing import (
        ValuesView as ValuesViewBase,
    )

__all__ = ["Array", "Map"]


T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
_DefaultT = TypeVar("_DefaultT")


def getitem_helper(
    obj: Any,
    elem_getter: Callable[[Any, int], T],
    length: int,
    idx: SupportsIndex | slice,
) -> T | list[T]:
    """Implement a pythonic __getitem__ helper.

    Parameters
    ----------
    obj
        The original object

    elem_getter
        A simple function that takes index and return a single element.

    length
        The size of the array

    idx
        The argument passed to getitem

    Returns
    -------
    result
        The element for integer indices or a ``list`` for slices.

    """
    if isinstance(idx, slice):
        start, stop, step = idx.indices(length)
        return [elem_getter(obj, i) for i in range(start, stop, step)]

    try:
        index = operator.index(idx)
    except TypeError as exc:  # pragma: no cover - defensive, matches list behaviour
        raise TypeError(f"indices must be integers or slices, not {type(idx).__name__}") from exc

    if index < -length or index >= length:
        raise IndexError(f"Index out of range. size: {length}, got index {index}")
    if index < 0:
        index += length
    return elem_getter(obj, index)


[docs] @register_object("ffi.Array") class Array(core.Object, Sequence[T]): """Array container that represents a sequence of values in the FFI. :py:func:`tvm_ffi.convert` will map python list/tuple to this class. Parameters ---------- input_list The list of values to be stored in the array. Examples -------- .. code-block:: python import tvm_ffi a = tvm_ffi.Array([1, 2, 3]) assert tuple(a) == (1, 2, 3) See Also -------- :py:func:`tvm_ffi.convert` """
[docs] def __init__(self, input_list: Iterable[T]) -> None: """Construct an Array from a Python sequence.""" self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
@overload def __getitem__(self, idx: SupportsIndex, /) -> T: ... @overload def __getitem__(self, idx: slice, /) -> list[T]: ... def __getitem__(self, idx: SupportsIndex | slice, /) -> T | list[T]: """Return one element or a list for a slice.""" length = len(self) result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx) return result def __len__(self) -> int: """Return the number of elements in the array.""" return _ffi_api.ArraySize(self) def __iter__(self) -> Iterator[T]: """Iterate over the elements in the array.""" length = len(self) for i in range(length): yield self[i] def __repr__(self) -> str: """Return a string representation of the array.""" # exception safety handling for chandle=None if self.__chandle__() == 0: return type(self).__name__ + "(chandle=None)" return "[" + ", ".join([x.__repr__() for x in self]) + "]" def __add__(self, other: Iterable[T]) -> Array[T]: """Concatenate two arrays.""" return type(self)(itertools.chain(self, other)) def __radd__(self, other: Iterable[T]) -> Array[T]: """Concatenate two arrays.""" return type(self)(itertools.chain(other, self))
class KeysView(KeysViewBase[K]): """Helper class to return keys view.""" def __init__(self, backend_map: Map[K, V]) -> None: self._backend_map = backend_map def __len__(self) -> int: return len(self._backend_map) def __iter__(self) -> Iterator[K]: size = len(self._backend_map) functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map) for _ in range(size): key = cast(K, functor(0)) yield key if not functor(2): break def __contains__(self, k: object) -> bool: return k in self._backend_map class ValuesView(ValuesViewBase[V]): """Helper class to return values view.""" def __init__(self, backend_map: Map[K, V]) -> None: self._backend_map = backend_map def __len__(self) -> int: return len(self._backend_map) def __iter__(self) -> Iterator[V]: size = len(self._backend_map) functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map) for _ in range(size): value = cast(V, functor(1)) yield value if not functor(2): break class ItemsView(ItemsViewBase[K, V]): """Helper class to return items view.""" def __init__(self, backend_map: Map[K, V]) -> None: self._backend_map = backend_map def __len__(self) -> int: return len(self._backend_map) def __iter__(self) -> Iterator[tuple[K, V]]: size = len(self._backend_map) functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map) for _ in range(size): key = cast(K, functor(0)) value = cast(V, functor(1)) yield (key, value) if not functor(2): break def __contains__(self, item: object) -> bool: if not isinstance(item, tuple) or len(item) != 2: return False key, value = item try: existing_value = self._backend_map[key] except KeyError: return False else: return existing_value == value
[docs] @register_object("ffi.Map") class Map(core.Object, Mapping[K, V]): """Map container. :py:func:`tvm_ffi.convert` will map python dict to this class. Parameters ---------- input_dict The dictionary of values to be stored in the map. Examples -------- .. code-block:: python import tvm_ffi amap = tvm_ffi.Map({"a": 1, "b": 2}) assert len(amap) == 2 assert amap["a"] == 1 assert amap["b"] == 2 See Also -------- :py:func:`tvm_ffi.convert` """
[docs] def __init__(self, input_dict: Mapping[K, V]) -> None: """Construct a Map from a Python mapping.""" list_kvs: list[Any] = [] for k, v in input_dict.items(): list_kvs.append(k) list_kvs.append(v) self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs)
def __getitem__(self, k: K) -> V: """Return the value for key `k` or raise KeyError.""" return cast(V, _ffi_api.MapGetItem(self, k)) def __contains__(self, k: object) -> bool: """Return True if the map contains key `k`.""" return _ffi_api.MapCount(self, k) != 0
[docs] def keys(self) -> KeysView[K]: """Return a dynamic view of the map's keys.""" return KeysView(self)
[docs] def values(self) -> ValuesView[V]: """Return a dynamic view of the map's values.""" return ValuesView(self)
[docs] def items(self) -> ItemsView[K, V]: """Get the items from the map.""" return ItemsView(self)
def __len__(self) -> int: """Return the number of items in the map.""" return _ffi_api.MapSize(self) def __iter__(self) -> Iterator[K]: """Iterate over the map's keys.""" return iter(self.keys()) @overload def get(self, key: K) -> V | None: ... @overload def get(self, key: K, default: V | _DefaultT) -> V | _DefaultT: ...
[docs] def get(self, key: K, default: V | _DefaultT | None = None) -> V | _DefaultT | None: """Get an element with a default value. Parameters ---------- key The attribute key. default The default object. Returns ------- value The result value. """ try: return self[key] except KeyError: return default
def __repr__(self) -> str: """Return a string representation of the map.""" # exception safety handling for chandle=None if self.__chandle__() == 0: return type(self).__name__ + "(chandle=None)" return "{" + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items()]) + "}"