Program Listing for File overload.h

Program Listing for File overload.h

Return to documentation for file (tvm/ffi/reflection/overload.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_FFI_EXTRA_OVERLOAD_H
#define TVM_FFI_EXTRA_OVERLOAD_H

#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/function_details.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ffi/type_traits.h>

#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace ffi {

namespace details {

// Type-erased base for one overload of a function.
//
// Concrete overloads (TypedOverload) implement argument matching and invocation; a
// dispatcher (OverloadedFunction) keeps a list of them and tries each one in turn.
struct OverloadBase {
  // Type-erased "try to call" entry point: returns true when the arguments matched and the
  // call was performed (the result may be written through the Any* out-parameter), and
  // false when the arguments did not match this overload.
  using FnPtr = bool (*)(OverloadBase*, const AnyView*, int32_t, Any*);

  // `num_args` is the arity of this overload. When `name` is provided it is stored in
  // name_ and name_ptr_ points at it; otherwise name_ is empty and name_ptr_ is null.
  explicit OverloadBase(int32_t num_args, std::optional<std::string> name)
      : num_args_(num_args),
        name_(name.has_value() ? std::move(name).value() : std::string()),
        name_ptr_(name.has_value() ? &name_ : nullptr) {}

  // Adds another overload to this object; only dispatchers give this a real implementation.
  virtual void Register(std::unique_ptr<OverloadBase> overload) = 0;
  // Returns the FnPtr that runs the concrete overload's matching + call logic.
  virtual FnPtr GetTryCallPtr() = 0;
  // Appends to `os` a human-readable explanation of why `args` did not match.
  virtual void GetMismatchMessage(std::ostringstream& os, const AnyView* args,
                                  int32_t num_args) = 0;

  virtual ~OverloadBase() = default;
  // Non-copyable: name_ptr_ refers to this object's own name_.
  OverloadBase(const OverloadBase&) = delete;
  OverloadBase& operator=(const OverloadBase&) = delete;

  // Sentinel for last_mismatch_index_ meaning "no mismatching argument recorded".
  static constexpr int32_t kAllMatched = -1;

  // Slot where a derived overload may record the index of the argument that failed to
  // convert during a match attempt (kAllMatched when nothing is recorded). It sits next to
  // num_args_ so both 32-bit fields can typically share one 8-byte slot on 64-bit targets.
  // NOTE(review): plain int32_t -- concurrent writers on a shared overload would race.
  int32_t last_mismatch_index_{kAllMatched};

  // Arity of this overload; dispatchers compare it first to skip candidates cheaply.
  const int32_t num_args_;
  // Name used in diagnostics; empty when none was provided.
  const std::string name_;
  // Points at name_ when a name was provided, nullptr otherwise.
  const std::string* const name_ptr_;
};

// Maps a parameter-list tuple to a tuple of optional, decayed "capture slots".
// Declared only; defined solely for std::tuple below.
template <typename ParamTuple>
struct CaptureTupleAux;

// std::tuple<P0, P1, ...>  ->  std::tuple<std::optional<decay_t<P0>>, std::optional<decay_t<P1>>, ...>
// Each slot can stay empty until the corresponding argument has been converted.
template <typename... Params>
struct CaptureTupleAux<std::tuple<Params...>> {
  using type = std::tuple<std::optional<std::decay_t<Params>>...>;
};

template <typename Callable>
struct TypedOverload : OverloadBase {
 public:
  static_assert(std::is_same_v<Callable, std::decay_t<Callable>>, "Callable must be value type");

  using FuncInfo = details::FunctionInfo<Callable>;
  using PackedArgs = typename FuncInfo::ArgType;
  using Ret = typename FuncInfo::RetType;
  using CaptureTuple = typename CaptureTupleAux<PackedArgs>::type;
  using OverloadBase::name_;
  using OverloadBase::name_ptr_;
  using typename OverloadBase::FnPtr;

  static constexpr auto kNumArgs = FuncInfo::num_args;
  static constexpr auto kSeq = std::make_index_sequence<kNumArgs>{};

  explicit TypedOverload(const Callable& f, std::optional<std::string> name = std::nullopt)
      : OverloadBase(kNumArgs, std::move(name)), f_(f) {}
  explicit TypedOverload(Callable&& f, std::optional<std::string> name = std::nullopt)
      : OverloadBase(kNumArgs, std::move(name)), f_(std::move(f)) {}

  bool TryCall(const AnyView* args, int32_t num_args, Any* rv) {
    if (num_args != kNumArgs) return false;
    CaptureTuple captures{};
    if (!TrySetAux(kSeq, captures, args)) return false;
    // now all captures are set
    if constexpr (std::is_same_v<Ret, void>) {
      CallAux(kSeq, captures);
      return true;
    } else {
      *rv = CallAux(kSeq, captures);
      return true;
    }
  }

  void Register(std::unique_ptr<OverloadBase> overload) override {
    TVM_FFI_ICHECK(false) << "This should never be called.";
  }

  FnPtr GetTryCallPtr() final {
    // lambda without a capture can be converted to function pointer
    return [](OverloadBase* base, const AnyView* args, int32_t num_args, Any* rv) -> bool {
      return static_cast<TypedOverload<Callable>*>(base)->TryCall(args, num_args, rv);
    };
  }

  void GetMismatchMessage(std::ostringstream& os, const AnyView* args, int32_t num_args) final {
    FGetFuncSignature f_sig = FuncInfo::Sig;
    if (num_args != kNumArgs) {
      os << "Mismatched number of arguments when calling: `" << name_ << " "
         << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << kNumArgs << " arguments";
    } else {
      GetMismatchMessageAux<0>(os, args, num_args);
    }
  }

 private:
  template <std::size_t I>
  void GetMismatchMessageAux(std::ostringstream& os, const AnyView* args, int32_t num_args) {
    if constexpr (I < kNumArgs) {
      if (this->last_mismatch_index_ == static_cast<int32_t>(I)) {
        TVMFFIAny any_data = args[I].CopyToTVMFFIAny();
        FGetFuncSignature f_sig = FuncInfo::Sig;
        using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>;
        os << "Mismatched type on argument #" << I << " when calling: `" << name_ << " "
           << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected `" << Type2Str<Type>::v()
           << "` but got `" << TypeTraits<Type>::GetMismatchTypeInfo(&any_data) << '`';
      } else {
        GetMismatchMessageAux<I + 1>(os, args, num_args);
      }
    }
    // end of recursion
  }

  template <std::size_t... I>
  Ret CallAux(std::index_sequence<I...>, CaptureTuple& tuple) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    return f_(static_cast<std::tuple_element_t<I, PackedArgs>>(std::move(*std::get<I>(tuple)))...);
  }

  template <std::size_t... I>
  bool TrySetAux(std::index_sequence<I...>, CaptureTuple& tuple, const AnyView* args) {
    return (TrySetOne<I>(tuple, args) && ...);
  }

  template <std::size_t I>
  bool TrySetOne(CaptureTuple& tuple, const AnyView* args) {
    using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>;
    auto& capture = std::get<I>(tuple);
    if constexpr (std::is_same_v<Type, AnyView>) {
      capture = args[I];
      return true;
    } else if constexpr (std::is_same_v<Type, Any>) {
      capture = Any(args[I]);
      return true;
    } else {
      capture = args[I].template try_cast<Type>();
      if (capture.has_value()) return true;
      // slow path: record the last mismatch index
      this->last_mismatch_index_ = static_cast<int32_t>(I);
      return false;
    }
  }

 protected:
  Callable f_;
};

template <typename Callable>
inline auto CreateNewOverload(Callable&& f, std::string name) {
  using Type = TypedOverload<std::decay_t<Callable>>;
  return std::make_unique<Type>(std::forward<Callable>(f), std::move(name));
}

// Dispatcher that owns a primary overload (inherited from TypedOverload) plus any number of
// additionally registered overloads, tried in registration order after the primary one.
template <typename Callable>
struct OverloadedFunction : TypedOverload<Callable> {
 public:
  using TypedBase = TypedOverload<Callable>;
  using OverloadBase::name_;
  using OverloadBase::name_ptr_;
  using TypedBase::GetTryCallPtr;
  using TypedBase::kNumArgs;
  using TypedBase::kSeq;
  using TypedBase::TypedBase;  // constructors
  using typename OverloadBase::FnPtr;
  using typename TypedBase::Ret;

  // Appends `overload` to the candidate list. Its try-call pointer is fetched once here so
  // dispatch can call it directly instead of through a virtual function.
  void Register(std::unique_ptr<OverloadBase> overload) final {
    FnPtr try_call = overload->GetTryCallPtr();
    overloads_.emplace_back(std::move(overload), try_call);
  }

  // Calls the first overload whose arity and argument types match `args`, storing any
  // result in *rv. Throws TypeError (via HandleOverloadFailure) when nothing matches.
  void operator()(const AnyView* args, int32_t num_args, Any* rv) {
    if (overloads_.empty()) {
      // Only one signature: use the regular unpacking path with minimal extra overhead.
      unpack_call<Ret>(kSeq, name_ptr_, f_, args, num_args, rv);
      return;
    }

    // The primary overload's TryCall is non-virtual and can be inlined.
    if (TypedBase::TryCall(args, num_args, rv)) return;

    // Registered candidates: a cheap arity check first, then the captured try-call pointer.
    for (const auto& entry : overloads_) {
      OverloadBase* candidate = entry.first.get();
      if (candidate->num_args_ == num_args && entry.second(candidate, args, num_args, rv)) {
        return;
      }
    }

    HandleOverloadFailure(args, num_args);
  }

 private:
  // Builds one mismatch explanation per overload (primary first, numbered from 0) and
  // throws a TypeError containing all of them.
  void HandleOverloadFailure(const AnyView* args, int32_t num_args) {
    std::ostringstream report;
    int32_t index = 0;
    report << "Overload #" << index++ << ": ";
    TypedBase::GetMismatchMessage(report, args, num_args);
    for (const auto& entry : overloads_) {
      report << "\nOverload #" << index++ << ": ";
      entry.first->GetMismatchMessage(report, args, num_args);
    }
    TVM_FFI_THROW(TypeError) << "No matching overload found when calling: `" << name_ << "` with "
                             << num_args << " arguments:\n"
                             << std::move(report).str();
  }
  using TypedBase::f_;
  std::vector<std::pair<std::unique_ptr<OverloadBase>, FnPtr>> overloads_;
};

}  // namespace details

namespace reflection {

// Reflection registration helper like ObjectDef, but methods registered under an existing
// name become additional overloads of one dispatching ffi::Function instead of separate
// entries. Field registration and type metadata are delegated to ObjectDef unchanged.
template <typename Class>
class OverloadObjectDef : private ObjectDef<Class> {
 public:
  using Super = ObjectDef<Class>;

  // Forwards to ObjectDef's constructor, which performs the type-level registration.
  template <typename... ExtraArgs>
  explicit OverloadObjectDef(ExtraArgs&&... extra_args)
      : Super(std::forward<ExtraArgs>(extra_args)...) {}

  // Registers a read-only field (delegates to ObjectDef::def_ro).
  template <typename T, typename BaseClass, typename... Extra>
  TVM_FFI_INLINE OverloadObjectDef& def_ro(const char* name, T BaseClass::* field_ptr,
                                           Extra&&... extra) {
    Super::def_ro(name, field_ptr, std::forward<Extra>(extra)...);
    return *this;
  }

  // Registers a read-write field (delegates to ObjectDef::def_rw).
  template <typename T, typename BaseClass, typename... Extra>
  TVM_FFI_INLINE OverloadObjectDef& def_rw(const char* name, T BaseClass::* field_ptr,
                                           Extra&&... extra) {
    Super::def_rw(name, field_ptr, std::forward<Extra>(extra)...);
    return *this;
  }

  // Registers an instance method; repeated names add overloads.
  template <typename Func, typename... Extra>
  TVM_FFI_INLINE OverloadObjectDef& def(const char* name, Func&& func, Extra&&... extra) {
    RegisterMethod(name, false, std::forward<Func>(func), std::forward<Extra>(extra)...);
    return *this;
  }

  // Registers a static method; repeated names add overloads.
  template <typename Func, typename... Extra>
  TVM_FFI_INLINE OverloadObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) {
    RegisterMethod(name, true, std::forward<Func>(func), std::forward<Extra>(extra)...);
    return *this;
  }

  // Registers a constructor signature under kInitMethodName; several init<...> signatures
  // become overloads of the same init method.
  template <typename... Args, typename... Extra>
  TVM_FFI_INLINE OverloadObjectDef& def([[maybe_unused]] init<Args...> init_func,
                                        Extra&&... extra) {
    RegisterMethod(kInitMethodName, true, &init<Args...>::template execute<Class>,
                   std::forward<Extra>(extra)...);
    return *this;
  }

 private:
  using ReflectionDefBase::WrapFunction;
  using Super::kInitMethodName;
  using Super::type_index_;
  using Super::type_key_;

  // Creates the dispatching ffi::Function for a method name's first registration.
  // Returns the function together with a pointer to the OverloadedFunction stored in it.
  template <typename Func>
  static auto GetOverloadMethod(std::string name, Func&& func) {
    using WrapFn = decltype(WrapFunction(std::forward<Func>(func)));
    using OverloadFn = details::OverloadedFunction<std::decay_t<WrapFn>>;
    return ffi::Function::FromPackedInplace<OverloadFn>(WrapFunction(std::forward<Func>(func)),
                                                        std::move(name));
  }

  // Creates a standalone overload to be appended to an existing dispatcher.
  template <typename Func>
  static auto NewOverload(std::string name, Func&& func) {
    return details::CreateNewOverload(WrapFunction(std::forward<Func>(func)), std::move(name));
  }

  // Registers `func` as method `name`.
  // First registration under `name`: builds a dispatcher, registers it with the type
  // (flags from `is_static`, metadata and `extra` traits), then records it for later lookups.
  // Later registrations under the same `name`: only append an overload to that dispatcher.
  // The type-level method entry is not revisited, so `is_static` and `extra` are not
  // applied for them, and the "type_schema" metadata reflects the first overload only.
  template <typename Func, typename... Extra>
  void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) {
    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
    std::string method_name = std::string(type_key_) + "." + name;

    // An overload set with this name exists: append to it and stop.
    if (const auto overload_it = registered_methods_.find(name);
        overload_it != registered_methods_.end()) {
      overload_it->second->Register(NewOverload(std::move(method_name), std::forward<Func>(func)));
      return;
    }

    // First time registering this method name.
    auto [method, overload_ptr] =
        GetOverloadMethod(std::move(method_name), std::forward<Func>(func));

    MethodInfoBuilder info;
    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
    info.doc = TVMFFIByteArray{nullptr, 0};
    info.flags = 0;
    if (is_static) {
      info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
    }
    info.method = AnyView(method).CopyToTVMFFIAny();
    info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema());
    // apply method info traits
    ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
    std::string metadata_str = Metadata::ToJSON(info.metadata_);
    info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));

    // Record the raw dispatcher pointer only after registration succeeded. If any step above
    // throws, `method` is released during unwinding, and an earlier-recorded pointer would
    // dangle for the next same-name registration.
    // NOTE(review): the pointer stays valid afterwards only because the type registry
    // presumably retains a reference to `method` -- confirm against TVMFFITypeRegisterMethod.
    registered_methods_.try_emplace(name, overload_ptr);
  }

  // Method name -> dispatcher created by its first registration (not owned; see above).
  std::unordered_map<std::string, details::OverloadBase*> registered_methods_;
};

}  // namespace reflection
}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_EXTRA_OVERLOAD_H