# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Container classes."""
from __future__ import annotations
import itertools
import operator
from collections.abc import ItemsView as ItemsViewBase
from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import KeysView as KeysViewBase
from collections.abc import ValuesView as ValuesViewBase
from typing import Any, Callable, SupportsIndex, TypeVar, cast, overload
from . import _ffi_api, core
from .registry import register_object
# Public names exported by ``from <module> import *``.
__all__ = ["Array", "Map"]
# Type variables parameterizing the generic containers and helpers below.
T = TypeVar("T")  # element type (Array, getitem_helper)
K = TypeVar("K")  # key type (Map and its key/item views)
V = TypeVar("V")  # value type (Map and its value/item views)
_DefaultT = TypeVar("_DefaultT")  # type of the ``default`` argument of Map.get
def getitem_helper(
    obj: Any,
    elem_getter: Callable[[Any, int], T],
    length: int,
    idx: SupportsIndex | slice,
) -> T | list[T]:
    """Implement a pythonic ``__getitem__`` on top of a raw element getter.

    Parameters
    ----------
    obj : Any
        The container object; it is forwarded unchanged to ``elem_getter``.
    elem_getter : Callable[[Any, int], T]
        Called as ``elem_getter(obj, i)`` with a non-negative, in-range
        index ``i``; returns the single element at that position.
    length : int
        Number of elements in ``obj``.
    idx : SupportsIndex or slice
        The key given to ``__getitem__``. Negative integers count from the end.

    Returns
    -------
    result : T or list[T]
        The element for an integer index, or a ``list`` for a slice.

    Raises
    ------
    TypeError
        If ``idx`` is neither a slice nor accepted by ``operator.index``.
    IndexError
        If an integer index lies outside ``[-length, length)``.
    """
    if not isinstance(idx, slice):
        try:
            position = operator.index(idx)
        except TypeError as err:  # pragma: no cover - defensive, matches list behaviour
            raise TypeError(f"indices must be integers or slices, not {type(idx).__name__}") from err
        if not -length <= position < length:
            raise IndexError(f"Index out of range. size: {length}, got index {position}")
        # Translate a negative (from-the-end) index into its absolute position.
        return elem_getter(obj, position + length if position < 0 else position)
    # slice.indices clamps start/stop/step to the container length, like list slicing.
    return [elem_getter(obj, i) for i in range(*idx.indices(length))]
@register_object("ffi.Array")
class Array(core.Object, Sequence[T]):
    """Array container that represents a sequence of values in ffi.

    :py:func:`tvm_ffi.convert` will map python list/tuple to this class.

    Parameters
    ----------
    input_list : Iterable[T]
        The list of values to be stored in the array.

    See Also
    --------
    :py:func:`tvm_ffi.convert`

    Examples
    --------
    .. code-block:: python

        import tvm_ffi

        a = tvm_ffi.convert([1, 2, 3])
        assert isinstance(a, tvm_ffi.Array)
        assert len(a) == 3

    """

    def __init__(self, input_list: Iterable[T]) -> None:
        """Construct an Array from a Python iterable.

        Each element of ``input_list`` is passed as a separate positional
        argument to the ffi ``Array`` constructor.
        """
        self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)

    @overload
    def __getitem__(self, idx: SupportsIndex, /) -> T: ...
    @overload
    def __getitem__(self, idx: slice, /) -> list[T]: ...
    def __getitem__(self, idx: SupportsIndex | slice, /) -> T | list[T]:
        """Return one element, or a Python ``list`` for a slice.

        Negative indices count from the end; an out-of-range integer index
        raises ``IndexError`` and a non-integer, non-slice key raises
        ``TypeError`` (both via ``getitem_helper``).
        """
        return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx)

    def __len__(self) -> int:
        """Return the number of elements in the array."""
        return _ffi_api.ArraySize(self)

    def __iter__(self) -> Iterator[T]:
        """Iterate over the elements in the array in index order."""
        for i in range(len(self)):
            yield self[i]

    def __repr__(self) -> str:
        """Return a string representation of the array."""
        # exception safety handling for chandle=None
        if self.__chandle__() == 0:
            return type(self).__name__ + "(chandle=None)"
        return "[" + ", ".join([x.__repr__() for x in self]) + "]"

    def __add__(self, other: Iterable[T]) -> Array[T]:
        """Return a new array holding ``self``'s elements followed by ``other``'s.

        ``other`` may be any iterable. If it is not iterable, ``NotImplemented``
        is returned so Python can try ``other.__radd__`` and otherwise raise
        the standard ``unsupported operand type(s)`` ``TypeError``.
        """
        try:
            rhs = iter(other)
        except TypeError:
            return NotImplemented
        return type(self)(itertools.chain(self, rhs))

    def __radd__(self, other: Iterable[T]) -> Array[T]:
        """Return a new array holding ``other``'s elements followed by ``self``'s.

        Called for ``other + array`` (e.g. ``[1, 2] + array``). Returns
        ``NotImplemented`` when ``other`` is not iterable.
        """
        try:
            lhs = iter(other)
        except TypeError:
            return NotImplemented
        return type(self)(itertools.chain(lhs, self))
class KeysView(KeysViewBase[K]):
    """Dynamic view over the keys of a :py:class:`Map`.

    The view keeps a reference to the map and re-reads it on every
    iteration, length, or membership query.
    """

    def __init__(self, backend_map: Map[K, V]) -> None:
        # Initialize the collections.abc base so its ``_mapping`` slot is set;
        # inherited helpers such as ``__repr__`` read ``self._mapping``.
        super().__init__(backend_map)
        self._backend_map = backend_map

    def __len__(self) -> int:
        return len(self._backend_map)

    def __iter__(self) -> Iterator[K]:
        size = len(self._backend_map)
        functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
        for _ in range(size):
            # functor(0) yields the key at the current iterator position.
            key = cast(K, functor(0))
            yield key
            # functor(2) presumably advances the ffi iterator; a falsy result
            # ends iteration early.
            if not functor(2):
                break

    def __contains__(self, k: object) -> bool:
        return k in self._backend_map
class ValuesView(ValuesViewBase[V]):
    """Dynamic view over the values of a :py:class:`Map`.

    Membership (``x in view``) is inherited from ``collections.abc.ValuesView``,
    which walks the underlying mapping through ``self._mapping``.
    """

    def __init__(self, backend_map: Map[K, V]) -> None:
        # Initialize the collections.abc base so its ``_mapping`` slot is set;
        # the inherited ``__contains__`` and ``__repr__`` both read it.
        super().__init__(backend_map)
        self._backend_map = backend_map

    def __len__(self) -> int:
        return len(self._backend_map)

    def __iter__(self) -> Iterator[V]:
        size = len(self._backend_map)
        functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
        for _ in range(size):
            # functor(1) yields the value at the current iterator position.
            value = cast(V, functor(1))
            yield value
            # functor(2) presumably advances the ffi iterator; a falsy result
            # ends iteration early.
            if not functor(2):
                break
class ItemsView(ItemsViewBase[K, V]):
    """Dynamic view over the ``(key, value)`` pairs of a :py:class:`Map`."""

    def __init__(self, backend_map: Map[K, V]) -> None:
        # Initialize the collections.abc base so its ``_mapping`` slot is set;
        # inherited helpers such as ``__repr__`` read ``self._mapping``.
        super().__init__(backend_map)
        self._backend_map = backend_map

    def __len__(self) -> int:
        return len(self._backend_map)

    def __iter__(self) -> Iterator[tuple[K, V]]:
        size = len(self._backend_map)
        functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
        for _ in range(size):
            # functor(0) / functor(1) read the key / value at the current position.
            key = cast(K, functor(0))
            value = cast(V, functor(1))
            yield (key, value)
            # functor(2) presumably advances the ffi iterator; a falsy result
            # ends iteration early.
            if not functor(2):
                break

    def __contains__(self, item: object) -> bool:
        """Return True if ``item`` is a ``(key, value)`` pair present in the map.

        Values are matched by identity first, then equality, mirroring
        ``collections.abc.ItemsView`` (so an identical NaN object is found).
        """
        if not isinstance(item, tuple) or len(item) != 2:
            return False
        key, value = item
        try:
            existing_value = self._backend_map[key]
        except KeyError:
            return False
        return existing_value is value or existing_value == value
@register_object("ffi.Map")
class Map(core.Object, Mapping[K, V]):
    """Map container backed by an ffi ``Map`` object.

    :py:func:`tvm_ffi.convert` will map python dict to this class.

    Parameters
    ----------
    input_dict : Mapping[K, V]
        The dictionary of values to be stored in the map.

    See Also
    --------
    :py:func:`tvm_ffi.convert`

    Examples
    --------
    .. code-block:: python

        import tvm_ffi

        amap = tvm_ffi.convert({"a": 1, "b": 2})
        assert isinstance(amap, tvm_ffi.Map)
        assert len(amap) == 2
        assert amap["a"] == 1
        assert amap["b"] == 2

    """

    def __init__(self, input_dict: Mapping[K, V]) -> None:
        """Construct a Map from a Python mapping.

        Keys and values are passed to the ffi ``Map`` constructor as one flat,
        interleaved argument list: ``k0, v0, k1, v1, ...``.
        """
        flat_kvs: list[Any] = [
            entry for key, value in input_dict.items() for entry in (key, value)
        ]
        self.__init_handle_by_constructor__(_ffi_api.Map, *flat_kvs)

    def __getitem__(self, k: K) -> V:
        """Return the value stored under ``k`` (a missing key raises KeyError)."""
        found = _ffi_api.MapGetItem(self, k)
        return cast(V, found)

    def __contains__(self, k: object) -> bool:
        """Report whether ``k`` is a key of the map."""
        occurrences = _ffi_api.MapCount(self, k)
        return occurrences != 0

    def keys(self) -> KeysView[K]:
        """Return a dynamic view of the map's keys."""
        return KeysView(self)

    def values(self) -> ValuesView[V]:
        """Return a dynamic view of the map's values."""
        return ValuesView(self)

    def items(self) -> ItemsView[K, V]:
        """Return a dynamic view of the map's ``(key, value)`` pairs."""
        return ItemsView(self)

    def __len__(self) -> int:
        """Return how many entries the map holds."""
        return _ffi_api.MapSize(self)

    def __iter__(self) -> Iterator[K]:
        """Iterate over the map's keys (delegates to the keys view)."""
        key_view = self.keys()
        return iter(key_view)

    @overload
    def get(self, key: K) -> V | None: ...
    @overload
    def get(self, key: K, default: V | _DefaultT) -> V | _DefaultT: ...
    def get(self, key: K, default: V | _DefaultT | None = None) -> V | _DefaultT | None:
        """Look up ``key``, falling back to ``default`` when it is absent.

        Parameters
        ----------
        key : object
            The key to look up.
        default : object
            Value returned when ``self[key]`` raises ``KeyError``.

        Returns
        -------
        value: object
            The stored value, or ``default``.
        """
        try:
            result = self[key]
        except KeyError:
            result = default
        return result

    def __repr__(self) -> str:
        """Return a dict-like string representation of the map."""
        # exception safety handling for chandle=None
        if self.__chandle__() == 0:
            return type(self).__name__ + "(chandle=None)"
        rendered = (f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items())
        return "{" + ", ".join(rendered) + "}"