Program Listing for File registry.h
↰ Return to documentation for file (tvm/ffi/reflection/registry.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_REFLECTION_REGISTRY_H_
#define TVM_FFI_REFLECTION_REGISTRY_H_
#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>
#include <tvm/ffi/type_traits.h>
#include <iterator>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace tvm {
namespace ffi {
namespace reflection {
/*! \brief Ordered list of (key, value) metadata entries attached to a field or method. */
using _MetadataType = std::vector<std::pair<String, Any>>;  // NOLINT(bugprone-reserved-identifier)

/*!
 * \brief TVMFFIFieldInfo extended with an owned list of metadata entries.
 *
 * Traits append to metadata_; the registration code later serializes it into
 * the `metadata` byte array of the C struct.
 */
struct FieldInfoBuilder : TVMFFIFieldInfo {
  _MetadataType metadata_;
};

/*!
 * \brief TVMFFIMethodInfo extended with an owned list of metadata entries.
 */
struct MethodInfoBuilder : TVMFFIMethodInfo {
  _MetadataType metadata_;
};

/*!
 * \brief Empty tag base class; registration calls `Apply` on any trait deriving from it.
 */
struct InfoTrait {};
class Metadata : public InfoTrait {
public:
Metadata(std::initializer_list<std::pair<String, Any>> dict) : dict_(dict) {}
inline void Apply(FieldInfoBuilder* info) const { this->Apply(&info->metadata_); }
inline void Apply(MethodInfoBuilder* info) const { this->Apply(&info->metadata_); }
private:
friend class GlobalDef;
template <typename T>
friend class ObjectDef;
inline void Apply(_MetadataType* out) const {
std::copy(std::make_move_iterator(dict_.begin()), std::make_move_iterator(dict_.end()),
std::back_inserter(*out));
}
static std::string ToJSON(const _MetadataType& metadata) {
using ::tvm::ffi::details::StringObj;
std::ostringstream os;
os << "{";
bool first = true;
for (const auto& [key, value] : metadata) {
if (!first) {
os << ",";
}
os << "\"" << key << "\":";
if (std::optional<int> v = value.as<int>()) {
os << *v;
} else if (std::optional<bool> v = value.as<bool>()) {
os << (*v ? "true" : "false");
} else if (std::optional<String> v = value.as<String>()) {
String escaped = EscapeString(*v);
os << escaped.c_str();
} else {
TVM_FFI_LOG_AND_THROW(TypeError) << "Metadata can be only int, bool or string, but on key `"
<< key << "`, the type is " << value.GetTypeKey();
}
first = false;
}
os << "}";
return os.str();
}
std::vector<std::pair<String, Any>> dict_;
};
/*!
 * \brief Trait that gives a field a default value.
 *
 * Applying it stores the value in the field info and sets the HasDefault flag bit.
 */
class DefaultValue : public InfoTrait {
 public:
  /*! \param value The default value of the field. */
  explicit DefaultValue(Any value) : value_(std::move(value)) {}

  /*!
   * \brief Record the default value and mark the field as having one.
   * \param info The field info being built.
   */
  TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
    info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
    AnyView view(value_);
    info->default_value = view.CopyToTVMFFIAny();
  }

 private:
  Any value_;
};
/*!
 * \brief Trait that ORs extra flag bits into a field's flags.
 */
class AttachFieldFlag : public InfoTrait {
 public:
  /*! \param flag The bits to OR into TVMFFIFieldInfo::flags. */
  explicit AttachFieldFlag(int32_t flag) : flag_(flag) {}

  /*! \brief Factory for the kTVMFFIFieldFlagBitMaskSEqHashDef flag. */
  TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() {
    const int32_t bits = kTVMFFIFieldFlagBitMaskSEqHashDef;
    return AttachFieldFlag(bits);
  }

  /*! \brief Factory for the kTVMFFIFieldFlagBitMaskSEqHashIgnore flag. */
  TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() {
    const int32_t bits = kTVMFFIFieldFlagBitMaskSEqHashIgnore;
    return AttachFieldFlag(bits);
  }

  /*! \brief OR the stored bits into \p info's flags. */
  TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->flags = info->flags | flag_; }

 private:
  int32_t flag_;
};
/*!
 * \brief Compute a data member's byte offset relative to the object's Object base.
 *
 * The member offset inside Class is obtained from a null Class pointer, as in the
 * classic offsetof idiom. The offset of Class relative to its Object base, as
 * reported by details::ObjectUnsafe::GetObjectOffsetToSubclass, is then subtracted.
 *
 * \param field_ptr Pointer to the data member.
 * \return The byte offset of the member measured from the Object base.
 */
template <typename Class, typename T>
TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) {
  Class* const null_base = static_cast<Class*>(nullptr);
  const int64_t member_offset_in_class = reinterpret_cast<int64_t>(&(null_base->*field_ptr));
  const int64_t class_offset_in_object = details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
  return member_offset_in_class - class_offset_in_object;
}
/*!
 * \brief Shared helpers for the reflection registration builders.
 *
 * Provides type-erased field accessors, object creators, trait application, and
 * adaptation of member functions / callables into ffi::Function.
 */
class ReflectionDefBase {
 protected:
  /*!
   * \brief Whether a trait argument of type T is a C-string docstring.
   *
   * Matches string literals (which decay to char*) as well as `char*` and
   * `const char*` values. Previously only the `char*` case matched, so a
   * docstring held in a `const char*` variable was silently ignored.
   */
  template <typename T>
  static constexpr bool kIsDocString =
      std::is_same_v<std::decay_t<T>, char*> || std::is_same_v<std::decay_t<T>, const char*>;

  /*!
   * \brief Wrap a NUL-terminated docstring as a byte array.
   * \param doc The docstring; nullptr yields an empty array.
   */
  TVM_FFI_INLINE static TVMFFIByteArray DocToByteArray(const char* doc) {
    if (doc == nullptr) {
      return TVMFFIByteArray{nullptr, 0};
    }
    return TVMFFIByteArray{doc, std::char_traits<char>::length(doc)};
  }

  /*!
   * \brief Type-erased getter: read a T stored at \p field into an owned Any in \p result.
   * \return Status code produced by the TVM_FFI_SAFE_CALL_BEGIN/END macros.
   */
  template <typename T>
  static int FieldGetter(void* field, TVMFFIAny* result) {
    TVM_FFI_SAFE_CALL_BEGIN();
    *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
    TVM_FFI_SAFE_CALL_END();
  }

  /*!
   * \brief Type-erased setter: convert \p value to T and store it at \p field.
   *
   * Any-typed fields take the value as-is; other types go through cast<T>().
   * \return Status code produced by the TVM_FFI_SAFE_CALL_BEGIN/END macros.
   */
  template <typename T>
  static int FieldSetter(void* field, const TVMFFIAny* value) {
    TVM_FFI_SAFE_CALL_BEGIN();
    if constexpr (std::is_same_v<T, Any>) {
      *reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value);
    } else {
      *reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
    }
    TVM_FFI_SAFE_CALL_END();
  }

  /*! \brief Object creator for default-constructible T. */
  template <typename T>
  static int ObjectCreatorDefault(TVMFFIObjectHandle* result) {
    TVM_FFI_SAFE_CALL_BEGIN();
    ObjectPtr<T> obj = make_object<T>();
    *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
    TVM_FFI_SAFE_CALL_END();
  }

  /*! \brief Object creator for T constructible from UnsafeInit. */
  template <typename T>
  static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) {
    TVM_FFI_SAFE_CALL_BEGIN();
    ObjectPtr<T> obj = make_object<T>(UnsafeInit{});
    *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
    TVM_FFI_SAFE_CALL_END();
  }

  /*!
   * \brief Apply one field trait: InfoTrait subclasses via Apply(), C strings as the docstring.
   * Other argument types are ignored.
   */
  template <typename T>
  TVM_FFI_INLINE static void ApplyFieldInfoTrait(FieldInfoBuilder* info, const T& value) {
    if constexpr (std::is_base_of_v<InfoTrait, std::decay_t<T>>) {
      value.Apply(info);
    } else if constexpr (kIsDocString<T>) {
      info->doc = DocToByteArray(value);
    }
  }

  /*!
   * \brief Apply one method trait: InfoTrait subclasses via Apply(), C strings as the docstring.
   * Other argument types are ignored.
   */
  template <typename T>
  TVM_FFI_INLINE static void ApplyMethodInfoTrait(MethodInfoBuilder* info, const T& value) {
    if constexpr (std::is_base_of_v<InfoTrait, std::decay_t<T>>) {
      value.Apply(info);
    } else if constexpr (kIsDocString<T>) {
      info->doc = DocToByteArray(value);
    }
  }

  /*!
   * \brief Apply one type-level trait. Only C strings (the type docstring) are recognized.
   */
  template <typename T>
  TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata* info, const T& value) {
    if constexpr (kIsDocString<T>) {
      info->doc = DocToByteArray(value);
    }
  }

  /*!
   * \brief Wrap a non-const member function as an ffi::Function.
   *
   * For ObjectRef classes the receiver is passed by value. For Object classes it
   * is passed as `const Class*` and const_cast back to call the non-const method.
   * \param name Name given to the resulting Function.
   * \param func The member function pointer.
   */
  template <typename Class, typename R, typename... Args>
  TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...)) {
    static_assert(std::is_base_of_v<ObjectRef, Class> || std::is_base_of_v<Object, Class>,
                  "Class must be derived from ObjectRef or Object");
    if constexpr (std::is_base_of_v<ObjectRef, Class>) {
      auto fwrap = [func](Class target, Args... params) -> R {
        // call method pointer
        return (target.*func)(std::forward<Args>(params)...);
      };
      return ffi::Function::FromTyped(fwrap, std::move(name));
    } else {
      auto fwrap = [func](const Class* target, Args... params) -> R {
        // call method pointer
        return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
      };
      return ffi::Function::FromTyped(fwrap, std::move(name));
    }
  }

  /*!
   * \brief Wrap a const member function as an ffi::Function.
   *
   * For ObjectRef classes the receiver is taken by const reference; for Object
   * classes it is taken as `const Class*`.
   * \param name Name given to the resulting Function.
   * \param func The const member function pointer.
   */
  template <typename Class, typename R, typename... Args>
  TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...) const) {
    static_assert(std::is_base_of_v<ObjectRef, Class> || std::is_base_of_v<Object, Class>,
                  "Class must be derived from ObjectRef or Object");
    if constexpr (std::is_base_of_v<ObjectRef, Class>) {
      auto fwrap = [func](const Class& target, Args... params) -> R {
        // call method pointer
        return (target.*func)(std::forward<Args>(params)...);
      };
      return ffi::Function::FromTyped(fwrap, std::move(name));
    } else {
      auto fwrap = [func](const Class* target, Args... params) -> R {
        // call method pointer
        return (target->*func)(std::forward<Args>(params)...);
      };
      return ffi::Function::FromTyped(fwrap, std::move(name));
    }
  }

  /*! \brief Wrap any other typed callable (lambda, function pointer) as an ffi::Function. */
  template <typename Func>
  TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) {
    return ffi::Function::FromTyped(std::forward<Func>(func), std::move(name));
  }
};
/*!
 * \brief Builder for registering global functions.
 *
 * Example: GlobalDef().def("my.add", [](int a, int b) { return a + b; });
 */
class GlobalDef : public ReflectionDefBase {
 public:
  /*!
   * \brief Register a typed callable under a global name.
   * \param name The global name.
   * \param func The callable; its signature determines the recorded type schema.
   * \param extra Optional traits (docstring, Metadata, ...).
   * \return Reference to this builder for chaining.
   */
  template <typename Func, typename... Extra>
  GlobalDef& def(const char* name, Func&& func, Extra&&... extra) {
    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
    RegisterFunc(name, ffi::Function::FromTyped(std::forward<Func>(func), std::string(name)),
                 FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register a packed callable under a global name.
   *
   * The recorded type schema is the generic Function schema.
   * \param name The global name.
   * \param func The packed callable, moved into the resulting Function.
   * \param extra Optional traits (docstring, Metadata, ...).
   * \return Reference to this builder for chaining.
   */
  template <typename Func, typename... Extra>
  GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) {
    RegisterFunc(name, ffi::Function::FromPacked(std::move(func)),
                 details::TypeSchemaImpl<Function>::v(), std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register a member function (or callable) as a global function.
   *
   * Member functions take the receiver (ObjectRef value/reference or Object pointer)
   * as their first argument; see ReflectionDefBase::GetMethod.
   * \return Reference to this builder for chaining.
   */
  template <typename Func, typename... Extra>
  GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
    RegisterFunc(name, GetMethod(std::string(name), std::forward<Func>(func)),
                 FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
    return *this;
  }

 private:
  /*!
   * \brief Build the method info, serialize its metadata to JSON, and register it globally.
   * \param name The global name.
   * \param func The function; kept alive by this frame while the C API is called.
   * \param type_schema The type schema string, stored under the "type_schema" key.
   * \param extra Traits applied after the type schema entry.
   */
  template <typename... Extra>  // NOLINTNEXTLINE(performance-unnecessary-value-param)
  void RegisterFunc(const char* name, ffi::Function func, String type_schema, Extra&&... extra) {
    // Value-initialize so no field of the C struct is left indeterminate.
    MethodInfoBuilder info{};
    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
    info.doc = TVMFFIByteArray{nullptr, 0};
    info.flags = 0;
    info.method = AnyView(func).CopyToTVMFFIAny();
    info.metadata_.emplace_back("type_schema", std::move(type_schema));
    ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
    // metadata_str must outlive the registration call because info.metadata points into it.
    std::string metadata_str = Metadata::ToJSON(info.metadata_);
    info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
    TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0));
  }
};
/*!
 * \brief Tag type selecting a constructor signature for ObjectDef::def.
 *
 * Passing init<A, B>() registers "__ffi_init__", which builds Class from (A, B).
 * \tparam Args The constructor argument types.
 */
template <typename... Args>
struct init {
 public:
  constexpr init() noexcept = default;

 private:
  // ObjectDef::def(init<...>) takes the address of execute, so it needs access.
  template <typename Class>
  friend class ObjectDef;

  /*!
   * \brief Construct a Class from \p args and return it as an ObjectRef.
   */
  template <typename Class>
  static inline ObjectRef execute(Args&&... args) {
    auto created = ffi::make_object<Class>(std::forward<Args>(args)...);
    return ObjectRef(std::move(created));
  }
};
/*!
 * \brief Builder for registering reflection information of an object type.
 *
 * Constructing it registers the type metadata (size, creator, docstring). The
 * def_* methods then register fields and methods on the type.
 *
 * \tparam Class The Object subclass being registered.
 */
template <typename Class>
class ObjectDef : public ReflectionDefBase {
 public:
  /*!
   * \brief Register the type-level metadata of Class.
   * \param extra_args Optional type-level traits; a string literal is taken as the type docstring.
   */
  template <typename... ExtraArgs>
  explicit ObjectDef(ExtraArgs&&... extra_args)
      : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {
    RegisterExtraInfo(std::forward<ExtraArgs>(extra_args)...);
  }

  /*!
   * \brief Register a read-only field.
   * \param name Field name.
   * \param field_ptr Pointer to the data member; it may belong to a base class of Class.
   * \param extra Optional field traits (docstring, DefaultValue, Metadata, AttachFieldFlag).
   * \return Reference to this builder for chaining.
   */
  template <typename T, typename BaseClass, typename... Extra>
  TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::* field_ptr, Extra&&... extra) {
    RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register a writable field. Requires Class::_type_mutable.
   * \param name Field name.
   * \param field_ptr Pointer to the data member; it may belong to a base class of Class.
   * \param extra Optional field traits.
   * \return Reference to this builder for chaining.
   */
  template <typename T, typename BaseClass, typename... Extra>
  TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::* field_ptr, Extra&&... extra) {
    static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields");
    RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register an instance method.
   * \return Reference to this builder for chaining.
   */
  template <typename Func, typename... Extra>
  TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) {
    RegisterMethod(name, false, std::forward<Func>(func), std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register a static method.
   * \return Reference to this builder for chaining.
   */
  template <typename Func, typename... Extra>
  TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) {
    RegisterMethod(name, true, std::forward<Func>(func), std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register a constructor as the static "__ffi_init__" method.
   * \param init_func Tag whose template arguments are the constructor argument types.
   * \return Reference to this builder for chaining.
   */
  template <typename... Args, typename... Extra>
  TVM_FFI_INLINE ObjectDef& def([[maybe_unused]] init<Args...> init_func, Extra&&... extra) {
    RegisterMethod(kInitMethodName, true, &init<Args...>::template execute<Class>,
                   std::forward<Extra>(extra)...);
    return *this;
  }

 private:
  /*!
   * \brief Register the size, structural eq/hash kind, creator and docstring of Class.
   */
  template <typename... ExtraArgs>
  void RegisterExtraInfo(ExtraArgs&&... extra_args) {
    // Value-initialize so no field of the C struct is left indeterminate.
    TVMFFITypeMetadata info{};
    info.total_size = sizeof(Class);
    info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind;
    info.creator = nullptr;
    info.doc = TVMFFIByteArray{nullptr, 0};
    // Pick a creator if Class can be constructed without arguments.
    if constexpr (std::is_default_constructible_v<Class>) {
      info.creator = ObjectCreatorDefault<Class>;
    } else if constexpr (std::is_constructible_v<Class, UnsafeInit>) {
      info.creator = ObjectCreatorUnsafeInit<Class>;
    }
    // apply extra info traits
    ((ApplyExtraInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info));
  }

  /*!
   * \brief Register a single field with its offset, accessors, default value and metadata.
   * \param writable Whether to set kTVMFFIFieldFlagBitMaskWritable.
   */
  template <typename T, typename BaseClass, typename... ExtraArgs>
  void RegisterField(const char* name, T BaseClass::* field_ptr, bool writable,
                     ExtraArgs&&... extra_args) {
    static_assert(std::is_base_of_v<BaseClass, Class>, "BaseClass must be a base class of Class");
    // Value-initialize so no field of the C struct is left indeterminate.
    FieldInfoBuilder info{};
    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
    info.field_static_type_index = TypeToFieldStaticTypeIndex<T>::value;
    // store byte offset and setter, getter
    // so the same setter can be reused for all the same type
    info.offset = GetFieldByteOffsetToObject<Class, T>(field_ptr);
    info.size = sizeof(T);
    info.alignment = alignof(T);
    info.flags = 0;
    if (writable) {
      info.flags |= kTVMFFIFieldFlagBitMaskWritable;
    }
    info.getter = FieldGetter<T>;
    info.setter = FieldSetter<T>;
    // initialize default value to nullptr
    info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
    info.doc = TVMFFIByteArray{nullptr, 0};
    info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
    // apply field info traits
    ((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
    // metadata_str must outlive the registration call because info.metadata points into it.
    std::string metadata_str = Metadata::ToJSON(info.metadata_);
    info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info));
  }

  /*!
   * \brief Register a method; the wrapped Function is named "<type_key>.<name>".
   * \param is_static Whether to set kTVMFFIFieldFlagBitMaskIsStaticMethod.
   */
  template <typename Func, typename... Extra>
  void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) {
    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
    // Value-initialize so no field of the C struct is left indeterminate.
    MethodInfoBuilder info{};
    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
    info.doc = TVMFFIByteArray{nullptr, 0};
    info.flags = 0;
    if (is_static) {
      info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
    }
    // obtain the method function
    Function method = GetMethod(std::string(type_key_) + "." + name, std::forward<Func>(func));
    info.method = AnyView(method).CopyToTVMFFIAny();
    info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema());
    // apply method info traits
    ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
    // metadata_str must outlive the registration call because info.metadata points into it.
    std::string metadata_str = Metadata::ToJSON(info.metadata_);
    info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
  }

  int32_t type_index_;
  const char* type_key_;
  static constexpr const char* kInitMethodName = "__ffi_init__";
};
/*!
 * \brief Builder for registering per-type attributes (functions or values) of an Object type.
 * \tparam Class The Object subclass whose attributes are registered.
 */
template <typename Class, typename = std::enable_if_t<std::is_base_of_v<Object, Class>>>
class TypeAttrDef : public ReflectionDefBase {
 public:
  /*!
   * \brief Bind the builder to Class.
   * \param extra_args Currently ignored; marked [[maybe_unused]] to silence unused-parameter warnings.
   */
  template <typename... ExtraArgs>
  explicit TypeAttrDef([[maybe_unused]] ExtraArgs&&... extra_args)
      : type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {}

  /*!
   * \brief Register a callable as a type attribute.
   *
   * The callable is wrapped by GetMethod into a Function named "<type_key>.<name>".
   * \param name Attribute name.
   * \param func The callable or member function pointer.
   * \return Reference to this builder for chaining.
   */
  template <typename Func>
  TypeAttrDef& def(const char* name, Func&& func) {
    TVMFFIByteArray name_array = {name, std::char_traits<char>::length(name)};
    ffi::Function ffi_func =
        GetMethod(std::string(type_key_) + "." + name, std::forward<Func>(func));
    TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny();
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
    return *this;
  }

  /*!
   * \brief Register a plain value as a type attribute.
   * \param name Attribute name.
   * \param value The value; \p value stays alive for the duration of the registration call.
   * \return Reference to this builder for chaining.
   */
  template <typename T>
  TypeAttrDef& attr(const char* name, T value) {
    TVMFFIByteArray name_array = {name, std::char_traits<char>::length(name)};
    TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny();
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
    return *this;
  }

 private:
  int32_t type_index_;
  const char* type_key_;
};
inline void EnsureTypeAttrColumn(std::string_view name) {
TVMFFIByteArray name_array = {name.data(), name.size()};
AnyView any_view(nullptr);
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array,
reinterpret_cast<const TVMFFIAny*>(&any_view)));
}
} // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_REFLECTION_REGISTRY_H_