Source code for tvm_ffi.registry

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI registry to register function and objects."""

from __future__ import annotations

import sys
from typing import Any, Callable, Literal, overload

from . import core
from .core import TypeInfo

# Whether to simply skip registration of unknown objects.
# When True, register_object returns the class unchanged if the FFI core has
# no type index for its type key; when False (the default) it raises
# ValueError instead.
_SKIP_UNKNOWN_OBJECTS = False


def register_object(type_key: str | type | None = None) -> Callable[[type], type] | type:
    """Register a Python class as the wrapper of an FFI object type.

    The function can be used as ``@register_object("some.Key")``,
    as ``@register_object()`` or directly as ``@register_object``.

    Parameters
    ----------
    type_key : str or type, optional
        When a string, it is the type key to register the decorated class
        under. When ``None``, a decorator is returned that uses the class
        ``__name__`` as the type key. When a class, that class is registered
        immediately under its ``__name__`` and returned.

    Returns
    -------
    result : Callable[[type], type] or type
        A class decorator when ``type_key`` is a string or ``None``;
        the registered class itself when ``type_key`` is a class.

    Raises
    ------
    ValueError
        If no type index is known for the type key and
        ``_SKIP_UNKNOWN_OBJECTS`` is false.
    TypeError
        If ``type_key`` is not a string, a class, or ``None``.

    Examples
    --------
    The following code registers MyObject using type key "test.MyObject"

    .. code-block:: python

      @tvm_ffi.register_object("test.MyObject")
      class MyObject(Object):
          pass

    """

    def _bind(cls: type, object_name: str) -> type:
        """Attach ``cls`` to the type index that ``object_name`` resolves to."""
        index = core._object_type_key_to_index(object_name)
        if index is not None:
            type_info = core._register_object_by_index(index, cls)
            _add_class_attrs(type_cls=cls, type_info=type_info)
            return cls
        # Unknown type key: either fail loudly or hand the class back untouched.
        if not _SKIP_UNKNOWN_OBJECTS:
            raise ValueError(f"Cannot find object type index for {object_name}")
        return cls

    def _by_class_name(cls: type) -> type:
        return _bind(cls, cls.__name__)

    if type_key is None:
        return _by_class_name

    if isinstance(type_key, str):

        def _by_explicit_key(cls: type) -> type:
            return _bind(cls, type_key)

        return _by_explicit_key

    if isinstance(type_key, type):
        # Bare ``@register_object`` usage: the argument is the class itself.
        return _by_class_name(type_key)

    raise TypeError("type_key must be a string, type, or None")
def register_global_func(
    func_name: str | Callable[..., Any],
    f: Callable[..., Any] | None = None,
    override: bool = False,
) -> Any:
    """Register a Python callable as a global FFI function.

    Parameters
    ----------
    func_name : str or function
        The name to register under. If a callable is passed instead, it is
        the function to register and its ``__name__`` is used as the name.

    f : function, optional
        The function to be registered.

    override : bool, optional
        Whether to override an existing entry.

    Returns
    -------
    fregister : function
        The registration decorator when no function was supplied; otherwise
        the result of registering the supplied function.

    Raises
    ------
    ValueError
        If the resolved function name is not a string.

    Examples
    --------
    .. code-block:: python

      import tvm_ffi

      # we can use decorator to register a function
      @tvm_ffi.register_global_func("mytest.echo")
      def echo(x):
          return x

      # After registering, we can get the function by its name
      f = tvm_ffi.get_global_func("mytest.echo")
      assert f(1) == 1

      # we can also directly register a function
      tvm_ffi.register_global_func("mytest.add_one", lambda x: x + 1)
      f = tvm_ffi.get_global_func("mytest.add_one")
      assert f(1) == 2

    See Also
    --------
    :py:func:`tvm_ffi.get_global_func`
    :py:func:`tvm_ffi.remove_global_func`

    """
    if callable(func_name):
        # Called with the function in the first position: derive the name
        # from the function itself.
        func_name, f = func_name.__name__, func_name

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    def register(myf: Callable[..., Any]) -> Any:
        """Register the global function with the FFI core."""
        return core._register_global_func(func_name, myf, override)

    return register if f is None else register(f)
@overload
def get_global_func(name: str, allow_missing: Literal[True]) -> core.Function | None:
    ...


@overload
def get_global_func(name: str, allow_missing: Literal[False] = False) -> core.Function:
    ...


def get_global_func(name: str, allow_missing: bool = False) -> core.Function | None:
    """Look up a global function by its registered name.

    Parameters
    ----------
    name : str
        The name of the global function.

    allow_missing : bool
        Whether a missing function is allowed instead of raising an error.

    Returns
    -------
    func : Function
        The function to be returned, None if function is missing.

    See Also
    --------
    :py:func:`tvm_ffi.register_global_func`

    """
    # The lookup, including the missing-function handling, is delegated to
    # the FFI core.
    func = core._get_global_func(name, allow_missing)
    return func
def list_global_func_names() -> list[str]:
    """Return the names of all registered global functions.

    Returns
    -------
    names : list
       List of global functions names.

    """
    make_functor = get_global_func("ffi.FunctionListGlobalNamesFunctor")
    lookup = make_functor()
    # The functor is queried with -1 for the number of names, then with each
    # index in turn for the name at that position.
    count = lookup(-1)
    return list(map(lookup, range(count)))


def remove_global_func(name: str) -> None:
    """Remove a global function from the registry by name.

    Parameters
    ----------
    name : str
        The name of the global function

    """
    remover = get_global_func("ffi.FunctionRemoveGlobal")
    remover(name)
def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None:
    """Initialize register ffi api functions into a given module.

    Every registered global function named ``<prefix>.<func_name>`` (with no
    further dots in ``<func_name>``) is installed on the target module as the
    attribute ``<func_name>``.

    Parameters
    ----------
    namespace : str
       The namespace of the source registry. A leading ``"tvm."`` is dropped
       before matching, so namespace ``"tvm.x"`` matches functions named
       ``"x.<func_name>"``.

    target_module_name : str
       The target module name if different from namespace

    Raises
    ------
    KeyError
        If the target module is not present in ``sys.modules``.

    Examples
    --------
    A typical usage pattern is to create a _ffi_api.py file to register
    the functions under a given module. The following
    code populates all registered global functions
    prefixed with ``mypackage.`` into the current module,
    then we can call the function through ``_ffi_api.func_name(*args)``
    which will call into the registered global function "mypackage.func_name".

    .. code-block:: python

       # _ffi_api.py
       import tvm_ffi

       tvm_ffi.init_ffi_api("mypackage", __name__)

    """
    target_module_name = target_module_name if target_module_name else namespace
    prefix = namespace[4:] if namespace.startswith("tvm.") else namespace
    target_module = sys.modules[target_module_name]

    # Match on the prefix *including* the "." separator. Matching on the bare
    # prefix would also accept unrelated names that merely share leading
    # characters (e.g. "mypackagex_func" for namespace "mypackage") and
    # install them under a truncated attribute name.
    dotted_prefix = prefix + "."

    for name in list_global_func_names():
        if not name.startswith(dotted_prefix):
            continue

        fname = name[len(dotted_prefix) :]
        # Skip empty names and functions that live in a nested sub-namespace.
        if not fname or "." in fname:
            continue

        f = get_global_func(name)
        setattr(f, "__name__", fname)
        setattr(target_module, fname, f)
def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
    """Install the fields and methods described by ``type_info`` on ``type_cls``.

    Attributes that ``type_cls`` already exposes (as reported by ``hasattr``)
    are left untouched. A method named ``__ffi_init__`` is installed under the
    name ``__c_ffi_init__`` instead.

    Parameters
    ----------
    type_cls : type
        The class to receive the attributes.

    type_info : TypeInfo
        The type information whose ``fields`` and ``methods`` are installed.

    Returns
    -------
    type_cls : type
        The same class object that was passed in.

    """
    for field_info in type_info.fields:
        attr_name = field_info.name
        if hasattr(type_cls, attr_name):
            # skip already defined attributes
            continue
        setattr(type_cls, attr_name, field_info.as_property(type_cls))

    for method_info in type_info.methods:
        attr_name = "__c_ffi_init__" if method_info.name == "__ffi_init__" else method_info.name
        if hasattr(type_cls, attr_name):
            continue
        setattr(type_cls, attr_name, method_info.as_callable(type_cls))

    return type_cls


__all__ = [
    "get_global_func",
    "init_ffi_api",
    "list_global_func_names",
    "register_global_func",
    "register_object",
    "remove_global_func",
]