Source code for tvm_ffi.container

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Container classes."""

from __future__ import annotations

import itertools
import operator
from collections.abc import ItemsView as ItemsViewBase
from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import KeysView as KeysViewBase
from collections.abc import ValuesView as ValuesViewBase
from typing import Any, Callable, SupportsIndex, TypeVar, cast, overload

from . import _ffi_api, core
from .registry import register_object

# Names exported by ``from tvm_ffi.container import *``. The view classes
# (KeysView, ValuesView, ItemsView) and ``getitem_helper`` are not listed.
__all__ = ["Array", "Map"]


# Element type of ``Array``; also the element type produced by the
# ``elem_getter`` callable passed to ``getitem_helper``.
T = TypeVar("T")
# Key type of ``Map`` and of its key/item views.
K = TypeVar("K")
# Value type of ``Map`` and of its value/item views.
V = TypeVar("V")
# Type of the fallback ``default`` accepted by ``Map.get``.
_DefaultT = TypeVar("_DefaultT")


def getitem_helper(
    obj: Any,
    elem_getter: Callable[[Any, int], T],
    length: int,
    idx: SupportsIndex | slice,
) -> T | list[T]:
    """Resolve a Python-style subscript against an indexable backend.

    Parameters
    ----------
    obj : Any
        The container being indexed; forwarded unchanged to ``elem_getter``.

    elem_getter : Callable[[Any, int], T]
        Invoked as ``elem_getter(obj, i)`` with a non-negative, in-range
        position ``i``; returns the element stored at that position.

    length : int
        Number of elements in ``obj``.

    idx : SupportsIndex or slice
        The subscript that was handed to ``__getitem__``.

    Returns
    -------
    result : T or list[T]
        A single element for an integer-like subscript, or a ``list`` of
        elements for a slice.

    Raises
    ------
    TypeError
        If ``idx`` is neither a slice nor accepted by ``operator.index``.
    IndexError
        If an integer subscript lies outside ``[-length, length)``.
    """
    if not isinstance(idx, slice):
        try:
            position = operator.index(idx)
        except TypeError as exc:  # pragma: no cover - defensive, matches list behaviour
            raise TypeError(f"indices must be integers or slices, not {type(idx).__name__}") from exc

        # Negative positions count from the end, exactly like ``list``.
        if not -length <= position < length:
            raise IndexError(f"Index out of range. size: {length}, got index {position}")
        return elem_getter(obj, position + length if position < 0 else position)

    # ``slice.indices`` normalizes start/stop/step against ``length``.
    return [elem_getter(obj, i) for i in range(*idx.indices(length))]


@register_object("ffi.Array")
class Array(core.Object, Sequence[T]):
    """Sequence container backed by an ffi array object.

    Python ``list`` and ``tuple`` values are turned into this class by
    :py:func:`tvm_ffi.convert`.

    Parameters
    ----------
    input_list : Iterable[T]
        Values to store, in order, in the new array.

    See Also
    --------
    :py:func:`tvm_ffi.convert`

    Examples
    --------
    .. code-block:: python

        import tvm_ffi

        a = tvm_ffi.convert([1, 2, 3])
        assert isinstance(a, tvm_ffi.Array)
        assert len(a) == 3

    """

    def __init__(self, input_list: Iterable[T]) -> None:
        """Create the ffi array from every item yielded by ``input_list``."""
        self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)

    @overload
    def __getitem__(self, idx: SupportsIndex, /) -> T: ...
    @overload
    def __getitem__(self, idx: slice, /) -> list[T]: ...

    def __getitem__(self, idx: SupportsIndex | slice, /) -> T | list[T]:
        """Fetch the element at ``idx``; a slice produces a Python ``list``."""
        return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx)

    def __len__(self) -> int:
        """Number of elements held by the ffi array."""
        return _ffi_api.ArraySize(self)

    def __iter__(self) -> Iterator[T]:
        """Yield the elements front to back, fetching each one by position."""
        yield from (self[pos] for pos in range(len(self)))

    def __repr__(self) -> str:
        """Render like a Python list, or a placeholder for a null handle."""
        # Guard against a null (chandle == 0) handle before touching elements.
        if self.__chandle__() == 0:
            return f"{type(self).__name__}(chandle=None)"
        body = ", ".join(x.__repr__() for x in self)
        return f"[{body}]"

    def __add__(self, other: Iterable[T]) -> Array[T]:
        """Return a new array: this array's items followed by ``other``'s."""
        combined = itertools.chain(self, other)
        return type(self)(combined)

    def __radd__(self, other: Iterable[T]) -> Array[T]:
        """Return a new array: ``other``'s items followed by this array's."""
        combined = itertools.chain(other, self)
        return type(self)(combined)
class KeysView(KeysViewBase[K]):
    """Dynamic view over the keys of a :class:`Map`."""

    def __init__(self, backend_map: Map[K, V]) -> None:
        # ``collections.abc.MappingView`` keeps the mapping in its ``_mapping``
        # slot, which inherited mixins (e.g. ``__repr__``) read. Without this
        # call those inherited methods raise AttributeError.
        super().__init__(backend_map)
        self._backend_map = backend_map

    def __len__(self) -> int:
        """Number of keys in the backing map."""
        return len(self._backend_map)

    def __iter__(self) -> Iterator[K]:
        """Yield each key of the backing map."""
        size = len(self._backend_map)
        functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
        for _ in range(size):
            # functor(0) reads the current key; functor(2) presumably advances
            # and returns falsy at the end (inferred from usage; TODO confirm).
            key = cast(K, functor(0))
            yield key
            if not functor(2):
                break

    def __contains__(self, k: object) -> bool:
        """Return True if ``k`` is a key of the backing map."""
        return k in self._backend_map


class ValuesView(ValuesViewBase[V]):
    """Dynamic view over the values of a :class:`Map`."""

    def __init__(self, backend_map: Map[K, V]) -> None:
        # Populate MappingView's ``_mapping`` slot so inherited mixins work.
        super().__init__(backend_map)
        self._backend_map = backend_map

    def __len__(self) -> int:
        """Number of values in the backing map."""
        return len(self._backend_map)

    def __iter__(self) -> Iterator[V]:
        """Yield each value of the backing map."""
        size = len(self._backend_map)
        functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
        for _ in range(size):
            # functor(1) reads the current value; see KeysView.__iter__.
            value = cast(V, functor(1))
            yield value
            if not functor(2):
                break

    def __contains__(self, value: object) -> bool:
        """Return True if some value in the map is ``value`` or equals it."""
        # Scan values directly rather than the inherited key-then-lookup scan;
        # the identity-or-equality test matches collections.abc.ValuesView.
        return any(v is value or v == value for v in self)


class ItemsView(ItemsViewBase[K, V]):
    """Dynamic view over the ``(key, value)`` pairs of a :class:`Map`."""

    def __init__(self, backend_map: Map[K, V]) -> None:
        # Populate MappingView's ``_mapping`` slot so inherited mixins work.
        super().__init__(backend_map)
        self._backend_map = backend_map

    def __len__(self) -> int:
        """Number of entries in the backing map."""
        return len(self._backend_map)

    def __iter__(self) -> Iterator[tuple[K, V]]:
        """Yield each ``(key, value)`` pair of the backing map."""
        size = len(self._backend_map)
        functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
        for _ in range(size):
            key = cast(K, functor(0))
            value = cast(V, functor(1))
            yield (key, value)
            if not functor(2):
                break

    def __contains__(self, item: object) -> bool:
        """Return True when ``item`` is a ``(key, value)`` pair in the map."""
        if not isinstance(item, tuple) or len(item) != 2:
            return False
        key, value = item
        try:
            existing_value = self._backend_map[key]
        except KeyError:
            return False
        else:
            return existing_value == value
@register_object("ffi.Map")
class Map(core.Object, Mapping[K, V]):
    """Mapping container backed by an ffi map object.

    Python ``dict`` values are turned into this class by
    :py:func:`tvm_ffi.convert`.

    Parameters
    ----------
    input_dict : Mapping[K, V]
        Key/value pairs to store in the new map.

    See Also
    --------
    :py:func:`tvm_ffi.convert`

    Examples
    --------
    .. code-block:: python

        import tvm_ffi

        amap = tvm_ffi.convert({"a": 1, "b": 2})
        assert isinstance(amap, tvm_ffi.Map)
        assert len(amap) == 2
        assert amap["a"] == 1
        assert amap["b"] == 2

    """

    def __init__(self, input_dict: Mapping[K, V]) -> None:
        """Create the ffi map from the key/value pairs of ``input_dict``."""
        # The constructor receives keys and values interleaved: k0, v0, k1, v1, ...
        flat_pairs: list[Any] = [elem for key, value in input_dict.items() for elem in (key, value)]
        self.__init_handle_by_constructor__(_ffi_api.Map, *flat_pairs)

    def __getitem__(self, k: K) -> V:
        """Look up ``k``; a missing key surfaces as ``KeyError``."""
        return cast(V, _ffi_api.MapGetItem(self, k))

    def __contains__(self, k: object) -> bool:
        """Report whether ``k`` is a key, based on the ffi key count."""
        return _ffi_api.MapCount(self, k) != 0

    def keys(self) -> KeysView[K]:
        """Return a dynamic view over the map's keys."""
        return KeysView(self)

    def values(self) -> ValuesView[V]:
        """Return a dynamic view over the map's values."""
        return ValuesView(self)

    def items(self) -> ItemsView[K, V]:
        """Return a dynamic view over the map's ``(key, value)`` pairs."""
        return ItemsView(self)

    def __len__(self) -> int:
        """Number of entries in the ffi map."""
        return _ffi_api.MapSize(self)

    def __iter__(self) -> Iterator[K]:
        """Iterate over the keys, delegating to :meth:`keys`."""
        return iter(self.keys())

    @overload
    def get(self, key: K) -> V | None: ...
    @overload
    def get(self, key: K, default: V | _DefaultT) -> V | _DefaultT: ...

    def get(self, key: K, default: V | _DefaultT | None = None) -> V | _DefaultT | None:
        """Look up ``key``, falling back to ``default`` when it is absent.

        Parameters
        ----------
        key : object
            The key to look up.

        default : object
            Returned when ``key`` is not present (``None`` if omitted).

        Returns
        -------
        value: object
            The stored value, or ``default``.
        """
        try:
            value = self[key]
        except KeyError:
            value = default
        return value

    def __repr__(self) -> str:
        """Render like a Python dict, or a placeholder for a null handle."""
        # Guard against a null (chandle == 0) handle before touching entries.
        if self.__chandle__() == 0:
            return f"{type(self).__name__}(chandle=None)"
        entries = ", ".join(f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items())
        return "{" + entries + "}"