# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI registry to register function and objects."""
from __future__ import annotations
import functools
import json
import sys
from typing import Any, Callable, Literal, Sequence, TypeVar, overload
from . import core
from .core import TypeInfo
# When True, ``register_object`` leaves classes whose type key is unknown to the
# FFI registry untouched instead of raising ``ValueError``.
_SKIP_UNKNOWN_OBJECTS: bool = False
# Type variable for class decorators, so the decorated class keeps its own type.
_T = TypeVar("_T", bound=type)
[docs]
def register_object(type_key: str | None = None) -> Callable[[_T], _T]:
    """Register a Python class as the proxy for an FFI object type.

    Parameters
    ----------
    type_key
        Key of the object type. It must already be registered on the C++
        side. When omitted, the decorated class's ``__name__`` is used as the
        key. The decorator may also be applied bare (``@register_object``
        without parentheses), in which case the class itself arrives here and
        is registered immediately under its ``__name__``.

    Returns
    -------
    decorator
        A class decorator that binds the class to the FFI type and returns
        the class. When applied bare, the registered class is returned
        directly.

    Raises
    ------
    TypeError
        If ``type_key`` is not a string, a type, or ``None``.
    ValueError
        Raised when decorating, if the key is unknown to the FFI registry and
        ``_SKIP_UNKNOWN_OBJECTS`` is false.

    Examples
    --------
    The following code registers MyObject using type key "test.MyObject", if
    the type key is already registered on the C++ side.

    .. code-block:: python

        @tvm_ffi.register_object("test.MyObject")
        class MyObject(Object):
            pass

    """

    def _bind(cls: _T, key: str) -> _T:
        """Attach the FFI type metadata registered under ``key`` to ``cls``."""
        index = core._object_type_key_to_index(key)
        if index is None:
            # Unknown key: either leave the class as-is or fail loudly.
            if not _SKIP_UNKNOWN_OBJECTS:
                raise ValueError(f"Cannot find object type index for {key}")
            return cls
        type_info = core._register_object_by_index(index, cls)
        _add_class_attrs(type_cls=cls, type_info=type_info)
        setattr(cls, "__tvm_ffi_type_info__", type_info)
        return cls

    def _use_class_name(cls: _T) -> _T:
        return _bind(cls, cls.__name__)

    if type_key is None:
        return _use_class_name
    if isinstance(type_key, type):
        # Bare ``@register_object`` usage: the class itself was passed in.
        return _use_class_name(type_key)
    if not isinstance(type_key, str):
        raise TypeError("type_key must be a string, type, or None")

    def _use_explicit_key(cls: _T) -> _T:
        return _bind(cls, type_key)

    return _use_explicit_key
[docs]
def register_global_func(
    func_name: str | Callable[..., Any],
    f: Callable[..., Any] | None = None,
    override: bool = False,
) -> Any:
    """Register a function in the FFI global registry.

    Supported forms:

    - decorator factory: ``@register_global_func("name")``
    - bare decorator: ``@register_global_func``, which uses the function's
      ``__name__``
    - direct call: with both a name and a function

    Parameters
    ----------
    func_name
        The registry name, or the function itself (its ``__name__`` is then
        used as the name).
    f
        The function to register. When omitted and ``func_name`` is a
        string, a decorator is returned instead.
    override
        Whether to override an existing entry with the same name.

    Returns
    -------
    fregister
        The registration decorator when no function was supplied; otherwise
        the result of registering the function.

    Raises
    ------
    ValueError
        If the derived function name is not a string.

    Examples
    --------
    .. code-block:: python

        import tvm_ffi

        # we can use decorator to register a function
        @tvm_ffi.register_global_func("mytest.echo")
        def echo(x):
            return x

        # After registering, we can get the function by its name
        f = tvm_ffi.get_global_func("mytest.echo")
        assert f(1) == 1

        # we can also directly register a function
        tvm_ffi.register_global_func("mytest.add_one", lambda x: x + 1)
        f = tvm_ffi.get_global_func("mytest.add_one")
        assert f(1) == 2

    See Also
    --------
    :py:func:`tvm_ffi.get_global_func`
    :py:func:`tvm_ffi.remove_global_func`

    """
    if isinstance(func_name, str):
        name = func_name
    else:
        # A function was passed in place of a name: derive the name from it.
        f = func_name
        name = f.__name__  # ty: ignore[unresolved-attribute]
    if not isinstance(name, str):
        raise ValueError("expect string function name")

    def register(myf: Callable[..., Any]) -> Any:
        """Register ``myf`` under ``name`` with the FFI core."""
        return core._register_global_func(name, myf, override)

    return register if f is None else register(f)
# Typing-only overloads for ``get_global_func``. Per these annotations, passing
# ``allow_missing=True`` may yield ``None``. With the default
# ``allow_missing=False``, a ``core.Function`` is returned.
@overload
def get_global_func(name: str, allow_missing: Literal[True]) -> core.Function | None: ...
@overload
def get_global_func(name: str, allow_missing: Literal[False] = False) -> core.Function: ...
[docs]
def get_global_func(name: str, allow_missing: bool = False) -> core.Function | None:
    """Look up a function in the FFI global registry by name.

    Parameters
    ----------
    name
        The name of the global function.
    allow_missing
        If true, return ``None`` for an unknown name instead of raising.

    Returns
    -------
    func
        The registered function, or ``None`` if it is missing and
        ``allow_missing`` is true.

    Examples
    --------
    .. code-block:: python

        import tvm_ffi

        @tvm_ffi.register_global_func("demo.echo")
        def echo(x):
            return x

        f = tvm_ffi.get_global_func("demo.echo")
        assert f(123) == 123

    See Also
    --------
    :py:func:`tvm_ffi.register_global_func`

    """
    found = core._get_global_func(name, allow_missing)
    return found
def list_global_func_names() -> list[str]:
    """Return the names of every function in the FFI global registry.

    Returns
    -------
    names
        List of global function names.

    """
    # The functor returns the number of names when called with -1, and the
    # i-th name when called with index i.
    lister = get_global_func("ffi.FunctionListGlobalNamesFunctor")()
    count = lister(-1)
    return list(map(lister, range(count)))
[docs]
def remove_global_func(name: str) -> None:
    """Unregister a function from the FFI global registry.

    Parameters
    ----------
    name
        The name of the global function to remove.

    Examples
    --------
    .. code-block:: python

        import tvm_ffi

        @tvm_ffi.register_global_func("my.temp")
        def temp():
            return 42

        assert tvm_ffi.get_global_func("my.temp", allow_missing=True) is not None
        tvm_ffi.remove_global_func("my.temp")
        assert tvm_ffi.get_global_func("my.temp", allow_missing=True) is None

    See Also
    --------
    :py:func:`tvm_ffi.register_global_func`
    :py:func:`tvm_ffi.get_global_func`

    """
    remover = get_global_func("ffi.FunctionRemoveGlobal")
    remover(name)
[docs]
def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None:
    """Populate a module with the global functions registered under a namespace.

    Let ``<prefix>`` be ``namespace`` with any leading ``"tvm."`` removed.
    Every global function named exactly ``<prefix>.<func_name>`` is handled,
    where ``<func_name>`` is non-empty and contains no further dots. It is
    fetched, its ``__name__`` is set to ``<func_name>``, and it is bound as
    attribute ``<func_name>`` on the target module. Functions in nested
    sub-namespaces (``<prefix>.sub.func``) and in sibling namespaces that
    merely share the prefix text (``<prefix>2.func``) are skipped.

    Parameters
    ----------
    namespace
        The namespace of the source registry.
    target_module_name
        Name of the module to populate. It is looked up in ``sys.modules``.
        Defaults to ``namespace``.

    Examples
    --------
    A typical usage pattern is to create a _ffi_api.py file to register the
    functions under a given module. The following code populates all
    registered global functions prefixed with ``mypackage.`` into the current
    module. Calling ``_ffi_api.func_name(*args)`` then calls the registered
    global function "mypackage.func_name".

    .. code-block:: python

        # _ffi_api.py
        import tvm_ffi

        tvm_ffi.init_ffi_api("mypackage", __name__)

    """
    target_module_name = target_module_name if target_module_name else namespace
    prefix = namespace[len("tvm.") :] if namespace.startswith("tvm.") else namespace
    # Match on "<prefix>." rather than "<prefix>". A bare startswith would also
    # accept sibling namespaces such as "<prefix>2.foo" or "<prefix>foo" and
    # bind them under a truncated, wrong attribute name.
    qualified_prefix = prefix + "."
    target_module = sys.modules[target_module_name]
    for name in list_global_func_names():
        if not name.startswith(qualified_prefix):
            continue
        fname = name[len(qualified_prefix) :]
        # Skip nested sub-namespaces and the degenerate "<prefix>." entry,
        # which has no function name to bind.
        if not fname or "." in fname:
            continue
        f = get_global_func(name)
        setattr(f, "__name__", fname)
        setattr(target_module, fname, f)
def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
    """Install FFI-backed fields, methods, ``__init__`` and copy hooks on a class.

    Parameters
    ----------
    type_cls
        The Python class being registered. It is mutated in place.
    type_info
        Reflection metadata whose ``fields``, ``methods`` and ``type_key``
        drive what is attached.

    Returns
    -------
    type
        ``type_cls`` itself.

    """
    # Fields become properties, but never shadow an attribute the class (or
    # one of its bases) already provides.
    for fld in type_info.fields:
        attr = fld.name
        if hasattr(type_cls, attr):
            continue
        setattr(type_cls, attr, fld.as_property(type_cls))

    # These members are type-specific (constructor signature, shallow copy),
    # so they are rebound even when an inherited version already exists.
    always_rebind = ("__c_ffi_init__", "__ffi_shallow_copy__")
    has_c_init = False
    has_shallow_copy = False
    for meth in type_info.methods:
        attr = meth.name
        if attr == "__ffi_init__":
            # The raw FFI constructor is exposed under a distinct name.
            attr = "__c_ffi_init__"
            has_c_init = True
        elif attr == "__ffi_shallow_copy__":
            has_shallow_copy = True
        if attr in always_rebind or not hasattr(type_cls, attr):
            setattr(type_cls, attr, meth.as_callable(type_cls))

    if "__init__" not in type_cls.__dict__:
        if has_c_init:
            # NOTE(review): ``__ffi_init__`` is read from the class here
            # (presumably supplied by a base class); the FFI method itself
            # was bound above as ``__c_ffi_init__``.
            setattr(type_cls, "__init__", getattr(type_cls, "__ffi_init__"))
        elif not issubclass(type_cls, core.PyNativeObject):
            setattr(type_cls, "__init__", __init__invalid)

    container_keys = ("ffi.Array", "ffi.Map", "ffi.List", "ffi.Dict")
    _setup_copy_methods(
        type_cls,
        has_shallow_copy,
        is_container=type_info.type_key in container_keys,
    )
    return type_cls
def _setup_copy_methods(
    type_cls: type, has_shallow_copy: bool, *, is_container: bool = False
) -> None:
    """Install ``__copy__``, ``__deepcopy__`` and ``__replace__`` on ``type_cls``.

    Any of these hooks that the class defines in its own ``__dict__`` is left
    untouched.

    Parameters
    ----------
    type_cls
        The class to patch.
    has_shallow_copy
        Whether the FFI type exposes ``__ffi_shallow_copy__``. If so, all
        three hooks are supported. Otherwise they raise ``TypeError``, except
        as noted for ``is_container``.
    is_container
        When true and ``has_shallow_copy`` is false, ``__deepcopy__`` is still
        supported. FFI container types (Array, Map, List, Dict) can be
        deep-copied through ``ffi.DeepCopy`` without a shallow-copy method.

    """
    if has_shallow_copy:
        hooks = {
            "__copy__": _copy_supported,
            "__deepcopy__": _deepcopy_supported,
            "__replace__": _replace_supported,
        }
    else:
        hooks = {
            "__copy__": _copy_unsupported,
            "__deepcopy__": _deepcopy_supported if is_container else _deepcopy_unsupported,
            "__replace__": _replace_unsupported,
        }
    own_attrs = type_cls.__dict__
    for dunder, impl in hooks.items():
        if dunder not in own_attrs:
            setattr(type_cls, dunder, impl)
def __init__invalid(self: Any, *args: Any, **kwargs: Any) -> None:
    """Placeholder ``__init__`` for classes without an FFI constructor; always raises."""
    message = "The __init__ method of this class is not implemented."
    raise RuntimeError(message)
def _copy_supported(self: Any) -> Any:
    """``__copy__`` implementation delegating to ``self.__ffi_shallow_copy__()``."""
    shallow_copy = self.__ffi_shallow_copy__
    return shallow_copy()
def _deepcopy_supported(self: Any, memo: Any = None) -> Any:
    """``__deepcopy__`` implementation delegating to the cached ``ffi.DeepCopy`` function.

    ``memo`` is accepted for ``copy.deepcopy`` protocol compatibility but is
    not used.
    """
    deep_copy = _get_deep_copy_func()
    return deep_copy(self)
@functools.lru_cache(maxsize=1)
def _get_deep_copy_func() -> core.Function:
    """Return the ``ffi.DeepCopy`` global function, looked up once and then cached."""
    deep_copy = get_global_func("ffi.DeepCopy")
    return deep_copy
def _replace_supported(self: Any, **kwargs: Any) -> Any:
    """``__replace__`` implementation: shallow-copy ``self``, then assign each keyword."""
    import copy  # noqa: PLC0415

    clone = copy.copy(self)
    for attr in kwargs:
        setattr(clone, attr, kwargs[attr])
    return clone
def _copy_unsupported(self: Any) -> Any:
    """``__copy__`` stub for types without shallow-copy support; always raises ``TypeError``."""
    type_name = type(self).__name__
    raise TypeError(
        f"Type `{type_name}` does not support copy. "
        "The underlying C++ type is not copy-constructible."
    )
def _deepcopy_unsupported(self: Any, memo: Any = None) -> Any:
    """``__deepcopy__`` stub for types without deep-copy support; always raises ``TypeError``."""
    type_name = type(self).__name__
    raise TypeError(
        f"Type `{type_name}` does not support deepcopy. "
        "The underlying C++ type is not copy-constructible."
    )
def _replace_unsupported(self: Any, **kwargs: Any) -> Any:
    """``__replace__`` stub for types without shallow-copy support; always raises ``TypeError``."""
    type_name = type(self).__name__
    raise TypeError(
        f"Type `{type_name}` does not support replace. "
        "The underlying C++ type is not copy-constructible."
    )
def get_registered_type_keys() -> Sequence[str]:
    """Return every type key currently registered with TVM-FFI.

    Returns
    -------
    type_keys
        List of valid type keys.

    """
    lister = get_global_func("ffi.GetRegisteredTypeKeys")
    return lister()
# Public API of this module (the names exported by ``from ... import *``).
# NOTE(review): "get_global_func_metadata" is listed here, but no definition of
# it is visible in this module's source. If it is truly absent,
# ``from <module> import *`` raises AttributeError. Confirm it is defined
# elsewhere (the otherwise-unused ``json`` import hints it may once have
# existed) or drop the entry.
__all__ = [
    "get_global_func",
    "get_global_func_metadata",
    "get_registered_type_keys",
    "init_ffi_api",
    "list_global_func_names",
    "register_global_func",
    "register_object",
    "remove_global_func",
]