Program Listing for File registry.h

Program Listing for File registry.h

Return to documentation for file (tvm/ffi/reflection/registry.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_FFI_REFLECTION_REGISTRY_H_
#define TVM_FFI_REFLECTION_REGISTRY_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/type_traits.h>

#include <string>
#include <string_view>
#include <type_traits>
#include <utility>

namespace tvm {
namespace ffi {
namespace reflection {

/*!
 * \brief Empty tag base class for field registration traits.
 *
 * When an extra argument passed to a field registration derives from this class,
 * ReflectionDefBase::ApplyFieldInfoTrait calls its `Apply(TVMFFIFieldInfo*)` member
 * (see DefaultValue and AttachFieldFlag for the traits defined in this header).
 */
struct FieldInfoTrait {};

/*!
 * \brief Field info trait that attaches a default value to a registered field.
 *
 * Passed as an extra argument to ObjectDef::def_ro / ObjectDef::def_rw.
 */
class DefaultValue : public FieldInfoTrait {
 public:
  /*!
   * \brief Create the trait.
   * \param value The default value; it is moved into the trait to avoid an extra copy.
   */
  explicit DefaultValue(Any value) : value_(std::move(value)) {}

  /*!
   * \brief Store the default value in \p info and set kTVMFFIFieldFlagBitMaskHasDefault.
   * \param info The field info being filled in.
   * \note The stored TVMFFIAny is produced from an AnyView of value_, so it presumably does
   *       not own a reference of its own; this trait must stay alive until the field info has
   *       been handed to the registry — confirm that TVMFFITypeRegisterField takes its own
   *       reference.
   */
  TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
    info->default_value = AnyView(value_).CopyToTVMFFIAny();
    info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
  }

 private:
  /*! \brief The default value to register for the field. */
  Any value_;
};

/*!
 * \brief Field info trait that ORs an extra bit mask into a field's flags.
 *
 * Use the named factories for the common structural equality / hashing flags, or construct
 * it directly with any kTVMFFIFieldFlagBitMask* value.
 */
class AttachFieldFlag : public FieldInfoTrait {
 public:
  /*! \param flag Bit mask that Apply ORs into TVMFFIFieldInfo::flags. */
  explicit AttachFieldFlag(int32_t flag) : flag_{flag} {}

  /*! \brief Trait that attaches kTVMFFIFieldFlagBitMaskSEqHashDef. */
  TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() {
    AttachFieldFlag trait(kTVMFFIFieldFlagBitMaskSEqHashDef);
    return trait;
  }

  /*! \brief Trait that attaches kTVMFFIFieldFlagBitMaskSEqHashIgnore. */
  TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() {
    AttachFieldFlag trait(kTVMFFIFieldFlagBitMaskSEqHashIgnore);
    return trait;
  }

  /*!
   * \brief Merge the stored bit mask into \p info's flags; existing bits are preserved.
   * \param info The field info being filled in.
   */
  TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
    info->flags = info->flags | flag_;
  }

 private:
  /*! \brief The bit mask to merge into the field flags. */
  int32_t flag_;
};

/*!
 * \brief Compute the byte offset of a data member, adjusted by the object's subclass offset.
 *
 * The member's offset inside Class is obtained by applying \p field_ptr to a null Class
 * pointer and converting the resulting address to an integer (an offsetof-style trick;
 * formally undefined behavior, but a widely relied-upon idiom for member pointers). The value
 * returned by details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>() is then subtracted,
 * presumably so the result is measured from the embedded Object base rather than from the
 * start of Class — NOTE(review): confirm against ObjectUnsafe.
 *
 * \tparam Class The class whose layout is used (a member pointer of a base class converts
 *               implicitly to `T Class::*`).
 * \tparam T The field type.
 * \param field_ptr Pointer to the data member.
 * \return The adjusted byte offset of the field.
 */
template <typename Class, typename T>
TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) {
  int64_t field_offset_to_class =
      reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
  return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
}

class ReflectionDefBase {
 protected:
  template <typename T>
  static int FieldGetter(void* field, TVMFFIAny* result) {
    TVM_FFI_SAFE_CALL_BEGIN();
    *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
    TVM_FFI_SAFE_CALL_END();
  }

  template <typename T>
  static int FieldSetter(void* field, const TVMFFIAny* value) {
    TVM_FFI_SAFE_CALL_BEGIN();
    if constexpr (std::is_same_v<T, Any>) {
      *reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value);
    } else {
      *reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
    }
    TVM_FFI_SAFE_CALL_END();
  }

  template <typename T>
  static int ObjectCreatorDefault(TVMFFIObjectHandle* result) {
    TVM_FFI_SAFE_CALL_BEGIN();
    ObjectPtr<T> obj = make_object<T>();
    *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
    TVM_FFI_SAFE_CALL_END();
  }

  template <typename T>
  static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) {
    TVM_FFI_SAFE_CALL_BEGIN();
    ObjectPtr<T> obj = make_object<T>(UnsafeInit{});
    *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
    TVM_FFI_SAFE_CALL_END();
  }

  template <typename T>
  TVM_FFI_INLINE static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) {
    if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
      value.Apply(info);
    }
    if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
      info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
    }
  }

  template <typename T>
  TVM_FFI_INLINE static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) {
    if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
      info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
    }
  }

  template <typename T>
  TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata* info, const T& value) {
    if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
      info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
    }
  }

  template <typename Class, typename R, typename... Args>
  TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...)) {
    static_assert(std::is_base_of_v<ObjectRef, Class> || std::is_base_of_v<Object, Class>,
                  "Class must be derived from ObjectRef or Object");
    if constexpr (std::is_base_of_v<ObjectRef, Class>) {
      auto fwrap = [func](Class target, Args... params) -> R {
        // call method pointer
        return (target.*func)(std::forward<Args>(params)...);
      };
      return ffi::Function::FromTyped(fwrap, name);
    }

    if constexpr (std::is_base_of_v<Object, Class>) {
      auto fwrap = [func](const Class* target, Args... params) -> R {
        // call method pointer
        return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
      };
      return ffi::Function::FromTyped(fwrap, name);
    }
  }

  template <typename Class, typename R, typename... Args>
  TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...) const) {
    static_assert(std::is_base_of_v<ObjectRef, Class> || std::is_base_of_v<Object, Class>,
                  "Class must be derived from ObjectRef or Object");
    if constexpr (std::is_base_of_v<ObjectRef, Class>) {
      auto fwrap = [func](const Class target, Args... params) -> R {
        // call method pointer
        return (target.*func)(std::forward<Args>(params)...);
      };
      return ffi::Function::FromTyped(fwrap, name);
    }

    if constexpr (std::is_base_of_v<Object, Class>) {
      auto fwrap = [func](const Class* target, Args... params) -> R {
        // call method pointer
        return (target->*func)(std::forward<Args>(params)...);
      };
      return ffi::Function::FromTyped(fwrap, name);
    }
  }

  template <typename Func>
  TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) {
    return ffi::Function::FromTyped(std::forward<Func>(func), name);
  }
};

/*!
 * \brief Builder for registering global functions.
 *
 * Each `def*` call registers one function through TVMFFIFunctionSetGlobalFromMethodInfo
 * and returns *this so calls can be chained. Extra arguments are applied with
 * ApplyMethodInfoTrait (a C-string argument becomes the docstring).
 */
class GlobalDef : public ReflectionDefBase {
 public:
  /*!
   * \brief Register a typed callable; the FFI signature is derived by Function::FromTyped.
   * \param name Global name; also used as the Function's own name.
   * \param func The callable to register.
   * \param extra Optional method info traits (e.g. a docstring).
   */
  template <typename Func, typename... Extra>
  GlobalDef& def(const char* name, Func&& func, Extra&&... extra) {
    RegisterFunc(name, ffi::Function::FromTyped(std::forward<Func>(func), std::string(name)),
                 std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register a packed callable through Function::FromPacked.
   * \param name Global name.
   * \param func The packed callable; taken by value and moved into the Function.
   * \param extra Optional method info traits (e.g. a docstring).
   */
  template <typename Func, typename... Extra>
  GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) {
    RegisterFunc(name, ffi::Function::FromPacked(std::move(func)),
                 std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register a member function pointer (or callable) through GetMethod.
   *
   * For member function pointers the receiver becomes the first argument of the global
   * function.
   *
   * \param name Global name; also used as the Function's own name.
   * \param func Member function pointer or callable.
   * \param extra Optional method info traits (e.g. a docstring).
   */
  template <typename Func, typename... Extra>
  GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
    RegisterFunc(name, GetMethod(std::string(name), std::forward<Func>(func)),
                 std::forward<Extra>(extra)...);
    return *this;
  }

 private:
  /*!
   * \brief Fill a TVMFFIMethodInfo for \p func and publish it as a global function.
   * \note info.method is built from an AnyView of \p func, which stays alive for the whole
   *       registration call.
   */
  template <typename... Extra>
  void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) {
    // Value-initialize so any field not assigned below is zero rather than indeterminate.
    TVMFFIMethodInfo info{};
    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
    info.doc = TVMFFIByteArray{nullptr, 0};
    info.metadata = TVMFFIByteArray{nullptr, 0};
    info.flags = 0;
    info.method = AnyView(func).CopyToTVMFFIAny();
    ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
    TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0));
  }
};

/*!
 * \brief Builder for registering the reflection metadata of an object class.
 *
 * Construction registers the type-level metadata: size, structural eq/hash kind, an object
 * creator when one can be synthesized, and an optional docstring. def_ro/def_rw register
 * fields and def/def_static register methods; every `def*` call returns *this for chaining.
 *
 * \tparam Class The object class being registered.
 */
template <typename Class>
class ObjectDef : public ReflectionDefBase {
 public:
  /*!
   * \brief Register the type metadata of Class.
   * \param extra_args Extra info traits applied to the metadata (e.g. a docstring).
   */
  template <typename... ExtraArgs>
  explicit ObjectDef(ExtraArgs&&... extra_args)
      : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {
    RegisterExtraInfo(std::forward<ExtraArgs>(extra_args)...);
  }

  /*!
   * \brief Register a read-only field.
   * \param name Field name.
   * \param field_ptr Pointer to the data member; it may belong to a base class of Class.
   * \param extra Field info traits (DefaultValue, AttachFieldFlag, or a docstring).
   */
  template <typename T, typename BaseClass, typename... Extra>
  TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::* field_ptr, Extra&&... extra) {
    RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register a writable field; only permitted when Class::_type_mutable is true.
   * \param name Field name.
   * \param field_ptr Pointer to the data member; it may belong to a base class of Class.
   * \param extra Field info traits (DefaultValue, AttachFieldFlag, or a docstring).
   */
  template <typename T, typename BaseClass, typename... Extra>
  TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::* field_ptr, Extra&&... extra) {
    static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields");
    RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register an instance method; the wrapped Function is named "<type_key>.<name>".
   * \param name Method name.
   * \param func Member function pointer or callable.
   * \param extra Method info traits (e.g. a docstring).
   */
  template <typename Func, typename... Extra>
  TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) {
    RegisterMethod(name, false, std::forward<Func>(func), std::forward<Extra>(extra)...);
    return *this;
  }

  /*!
   * \brief Register a static method (flagged with kTVMFFIFieldFlagBitMaskIsStaticMethod).
   * \param name Method name.
   * \param func The callable.
   * \param extra Method info traits (e.g. a docstring).
   */
  template <typename Func, typename... Extra>
  TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) {
    RegisterMethod(name, true, std::forward<Func>(func), std::forward<Extra>(extra)...);
    return *this;
  }

 private:
  /*! \brief Build and register the TVMFFITypeMetadata for Class. */
  template <typename... ExtraArgs>
  void RegisterExtraInfo(ExtraArgs&&... extra_args) {
    // Value-initialize so any field not assigned below is zero rather than indeterminate.
    TVMFFITypeMetadata info{};
    info.total_size = sizeof(Class);
    info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind;
    info.creator = nullptr;
    info.doc = TVMFFIByteArray{nullptr, 0};
    // Prefer the default constructor, fall back to the UnsafeInit constructor; otherwise no
    // creator is registered.
    if constexpr (std::is_default_constructible_v<Class>) {
      info.creator = ObjectCreatorDefault<Class>;
    } else if constexpr (std::is_constructible_v<Class, UnsafeInit>) {
      info.creator = ObjectCreatorUnsafeInit<Class>;
    }
    ((ApplyExtraInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info));
  }

  /*! \brief Build and register the TVMFFIFieldInfo for one field. */
  template <typename T, typename BaseClass, typename... ExtraArgs>
  void RegisterField(const char* name, T BaseClass::* field_ptr, bool writable,
                     ExtraArgs&&... extra_args) {
    static_assert(std::is_base_of_v<BaseClass, Class>, "BaseClass must be a base class of Class");
    // Value-initialize so any field not assigned below is zero rather than indeterminate.
    TVMFFIFieldInfo info{};
    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
    info.field_static_type_index = TypeToFieldStaticTypeIndex<T>::value;
    // The field is located by byte offset, so the getter/setter thunks depend only on T and
    // one instantiation is shared by every field of the same type.
    info.offset = GetFieldByteOffsetToObject<Class, T>(field_ptr);
    info.size = sizeof(T);
    info.alignment = alignof(T);
    info.flags = 0;
    if (writable) {
      info.flags |= kTVMFFIFieldFlagBitMaskWritable;
    }
    info.getter = FieldGetter<T>;
    info.setter = FieldSetter<T>;
    // Default value starts as null; a DefaultValue trait may overwrite it.
    info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
    info.doc = TVMFFIByteArray{nullptr, 0};
    info.metadata = TVMFFIByteArray{nullptr, 0};
    ((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info));
  }

  /*! \brief Build and register the TVMFFIMethodInfo for one method. */
  template <typename Func, typename... Extra>
  void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) {
    // Value-initialize so any field not assigned below is zero rather than indeterminate.
    TVMFFIMethodInfo info{};
    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
    info.doc = TVMFFIByteArray{nullptr, 0};
    info.metadata = TVMFFIByteArray{nullptr, 0};
    info.flags = 0;
    if (is_static) {
      info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
    }
    // `method` must outlive the registration call because info.method only views it.
    Function method = GetMethod(std::string(type_key_) + "." + name, std::forward<Func>(func));
    info.method = AnyView(method).CopyToTVMFFIAny();
    ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
  }

  /*! \brief Runtime type index of Class. */
  int32_t type_index_;
  /*! \brief Type key of Class, used to prefix method Function names. */
  const char* type_key_;
};

template <typename Class, typename = std::enable_if_t<std::is_base_of_v<Object, Class>>>
class TypeAttrDef : public ReflectionDefBase {
 public:
  template <typename... ExtraArgs>
  explicit TypeAttrDef(ExtraArgs&&... extra_args)
      : type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {}

  template <typename Func>
  TypeAttrDef& def(const char* name, Func&& func) {
    TVMFFIByteArray name_array = {name, std::char_traits<char>::length(name)};
    ffi::Function ffi_func =
        GetMethod(std::string(type_key_) + "." + name, std::forward<Func>(func));
    TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny();
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
    return *this;
  }

  template <typename T>
  TypeAttrDef& attr(const char* name, T value) {
    TVMFFIByteArray name_array = {name, std::char_traits<char>::length(name)};
    TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny();
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
    return *this;
  }

 private:
  int32_t type_index_;
  const char* type_key_;
};

inline void EnsureTypeAttrColumn(std::string_view name) {
  TVMFFIByteArray name_array = {name.data(), name.size()};
  AnyView any_view(nullptr);
  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array,
                                                 reinterpret_cast<const TVMFFIAny*>(&any_view)));
}

/*!
 * \brief Construct a new object and return it as a type-erased ObjectRef.
 *
 * \tparam T Either an Object subclass, which is allocated directly, or a reference type
 *           exposing `ContainerType`, in which case that container type is allocated.
 * \param args Constructor arguments forwarded to ffi::make_object.
 * \return An ObjectRef owning the newly created object.
 */
template <typename T, typename... Args>
inline ObjectRef init(Args&&... args) {
  if constexpr (!std::is_base_of_v<Object, T>) {
    // Reference type: allocate the node type it wraps.
    using Node = typename T::ContainerType;
    return ObjectRef(ffi::make_object<Node>(std::forward<Args>(args)...));
  } else {
    return ObjectRef(ffi::make_object<T>(std::forward<Args>(args)...));
  }
}

}  // namespace reflection
}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_REFLECTION_REGISTRY_H_