Source code for tvm_ffi.registry

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI registry to register function and objects."""

from __future__ import annotations

import json
import sys
from typing import Any, Callable, Literal, overload

from . import core
from .core import TypeInfo

# whether we simplify skip unknown objects regtistration
_SKIP_UNKNOWN_OBJECTS = False


[docs] def register_object(type_key: str | type | None = None) -> Callable[[type], type] | type: """Register object type. Parameters ---------- type_key The type key of the node. It requires ``type_key`` to be registered already on the C++ side. If not specified, the class name will be used. Examples -------- The following code registers MyObject using type key "test.MyObject", if the type key is already registered on the C++ side. .. code-block:: python @tvm_ffi.register_object("test.MyObject") class MyObject(Object): pass """ def _register(cls: type, object_name: str) -> type: """Register the object type with the FFI core.""" type_index = core._object_type_key_to_index(object_name) if type_index is None: if _SKIP_UNKNOWN_OBJECTS: return cls raise ValueError(f"Cannot find object type index for {object_name}") info = core._register_object_by_index(type_index, cls) _add_class_attrs(type_cls=cls, type_info=info) setattr(cls, "__tvm_ffi_type_info__", info) return cls if isinstance(type_key, str): def _decorator_with_name(cls: type) -> type: return _register(cls, type_key) return _decorator_with_name def _decorator_default(cls: type) -> type: return _register(cls, cls.__name__) if type_key is None: return _decorator_default if isinstance(type_key, type): return _decorator_default(type_key) raise TypeError("type_key must be a string, type, or None")
[docs] def register_global_func( func_name: str | Callable[..., Any], f: Callable[..., Any] | None = None, override: bool = False, ) -> Any: """Register global function. Parameters ---------- func_name The function name f The function to be registered. override Whether override existing entry. Returns ------- fregister Register function if f is not specified. Examples -------- .. code-block:: python import tvm_ffi # we can use decorator to register a function @tvm_ffi.register_global_func("mytest.echo") def echo(x): return x # After registering, we can get the function by its name f = tvm_ffi.get_global_func("mytest.echo") assert f(1) == 1 # we can also directly register a function tvm_ffi.register_global_func("mytest.add_one", lambda x: x + 1) f = tvm_ffi.get_global_func("mytest.add_one") assert f(1) == 2 See Also -------- :py:func:`tvm_ffi.get_global_func` :py:func:`tvm_ffi.remove_global_func` """ if callable(func_name): f = func_name func_name = f.__name__ if not isinstance(func_name, str): raise ValueError("expect string function name") def register(myf: Callable[..., Any]) -> Any: """Register the global function with the FFI core.""" return core._register_global_func(func_name, myf, override) if f is not None: return register(f) return register
@overload def get_global_func(name: str, allow_missing: Literal[True]) -> core.Function | None: ... @overload def get_global_func(name: str, allow_missing: Literal[False] = False) -> core.Function: ...
[docs] def get_global_func(name: str, allow_missing: bool = False) -> core.Function | None: """Get a global function by name. Parameters ---------- name The name of the global function allow_missing Whether allow missing function or raise an error. Returns ------- func The function to be returned, ``None`` if function is missing. Examples -------- .. code-block:: python import tvm_ffi @tvm_ffi.register_global_func("demo.echo") def echo(x): return x f = tvm_ffi.get_global_func("demo.echo") assert f(123) == 123 See Also -------- :py:func:`tvm_ffi.register_global_func` """ return core._get_global_func(name, allow_missing)
def list_global_func_names() -> list[str]: """Get list of global functions registered. Returns ------- names List of global functions names. """ name_functor = get_global_func("ffi.FunctionListGlobalNamesFunctor")() num_names = name_functor(-1) return [name_functor(i) for i in range(num_names)]
[docs] def remove_global_func(name: str) -> None: """Remove a global function by name. Parameters ---------- name The name of the global function. Examples -------- .. code-block:: python import tvm_ffi @tvm_ffi.register_global_func("my.temp") def temp(): return 42 assert tvm_ffi.get_global_func("my.temp", allow_missing=True) is not None tvm_ffi.remove_global_func("my.temp") assert tvm_ffi.get_global_func("my.temp", allow_missing=True) is None See Also -------- :py:func:`tvm_ffi.register_global_func` :py:func:`tvm_ffi.get_global_func` """ get_global_func("ffi.FunctionRemoveGlobal")(name)
[docs] def get_global_func_metadata(name: str) -> dict[str, Any]: """Get metadata (including type schema) for a global function. Parameters ---------- name The name of the global function. Returns ------- metadata A dictionary containing function metadata. The ``type_schema`` field encodes the callable signature. Examples -------- .. code-block:: python import tvm_ffi meta = tvm_ffi.get_global_func_metadata("testing.add_one") print(meta) See Also -------- :py:func:`tvm_ffi.get_global_func` Retrieve a callable for an existing global function. :py:func:`tvm_ffi.register_global_func` Register a Python callable as a global FFI function. """ return json.loads(get_global_func("ffi.GetGlobalFuncMetadata")(name))
[docs] def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None: """Initialize register ffi api functions into a given module. Parameters ---------- namespace The namespace of the source registry target_module_name The target module name if different from namespace Examples -------- A typical usage pattern is to create a _ffi_api.py file to register the functions under a given module. The following code populates all registered global functions prefixed with ``mypackage.`` into the current module, then we can call the function through ``_ffi_api.func_name(*args)`` which will call into the registered global function "mypackage.func_name". .. code-block:: python # _ffi_api.py import tvm_ffi tvm_ffi.init_ffi_api("mypackage", __name__) """ target_module_name = target_module_name if target_module_name else namespace if namespace.startswith("tvm."): prefix = namespace[4:] else: prefix = namespace target_module = sys.modules[target_module_name] for name in list_global_func_names(): if not name.startswith(prefix): continue fname = name[len(prefix) + 1 :] if fname.find(".") != -1: continue f = get_global_func(name) setattr(f, "__name__", fname) setattr(target_module, fname, f)
def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type: for field in type_info.fields: name = field.name if not hasattr(type_cls, name): # skip already defined attributes setattr(type_cls, name, field.as_property(type_cls)) has_c_init = False for method in type_info.methods: name = method.name if name == "__ffi_init__": name = "__c_ffi_init__" has_c_init = True if not hasattr(type_cls, name): setattr(type_cls, name, method.as_callable(type_cls)) if "__init__" not in type_cls.__dict__: if has_c_init: setattr(type_cls, "__init__", getattr(type_cls, "__ffi_init__")) elif not issubclass(type_cls, core.PyNativeObject): setattr(type_cls, "__init__", __init__invalid) return type_cls def __init__invalid(self: Any, *args: Any, **kwargs: Any) -> None: raise RuntimeError("The __init__ method of this class is not implemented.") __all__ = [ "get_global_func", "get_global_func_metadata", "init_ffi_api", "list_global_func_names", "register_global_func", "register_object", "remove_global_func", ]