Program Listing for File function.h

Program Listing for File function.h#

Return to documentation for file (tvm/ffi/function.h)

  TVM_FFI_SAFE_CALL_BEGIN();
  // c++ code region here
  TVM_FFI_SAFE_CALL_END();
}
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index));
  return x + 1;
}

// Expose the function as "AddOne"
TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_);

// Expose the function as "SubOne"
TVM_FFI_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) {
  return x - 1;
});
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_FFI_FUNCTION_H_
#define TVM_FFI_FUNCTION_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/base_details.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function_details.h>

#include <functional>
#include <string>
#include <utility>
#include <vector>

namespace tvm {
namespace ffi {

#define TVM_FFI_SAFE_CALL_BEGIN() \
  try {                           \
  (void)0

#define TVM_FFI_SAFE_CALL_END()                                                                \
  return 0;                                                                                    \
  }                                                                                            \
  catch (const ::tvm::ffi::Error& err) {                                                       \
    ::tvm::ffi::details::SetSafeCallRaised(err);                                               \
    return -1;                                                                                 \
  }                                                                                            \
  catch (const ::tvm::ffi::EnvErrorAlreadySet&) {                                              \
    return -2;                                                                                 \
  }                                                                                            \
  catch (const std::exception& ex) {                                                           \
    ::tvm::ffi::details::SetSafeCallRaised(::tvm::ffi::Error("InternalError", ex.what(), "")); \
    return -1;                                                                                 \
  }                                                                                            \
  TVM_FFI_UNREACHABLE()

#define TVM_FFI_CHECK_SAFE_CALL(func)                      \
  {                                                        \
    int ret_code = (func);                                 \
    if (ret_code != 0) {                                   \
      if (ret_code == -2) {                                \
        throw ::tvm::ffi::EnvErrorAlreadySet();            \
      }                                                    \
      throw ::tvm::ffi::details::MoveFromSafeCallRaised(); \
    }                                                      \
  }

class FunctionObj : public Object, public TVMFFIFunctionCell {
 public:
  typedef void (*FCall)(const FunctionObj*, const AnyView*, int32_t, Any*);
  using TVMFFIFunctionCell::cpp_call;
  using TVMFFIFunctionCell::safe_call;
  TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const {
    // if cpp_call is set, use it to call the function, otherwise, redirect to safe_call
    // use conditional expression here so the select is branchless
    FCall call_ptr =
        this->cpp_call ? reinterpret_cast<FCall>(this->cpp_call) : CppCallDedirectToSafeCall;
    (*call_ptr)(this, args, num_args, result);
  }
  static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction;
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIFunction, FunctionObj, Object);

 protected:
  FunctionObj() {}
  friend class Function;

 private:
  static void CppCallDedirectToSafeCall(const FunctionObj* func, const AnyView* args,
                                        int32_t num_args, Any* rv) {
    FunctionObj* self = static_cast<FunctionObj*>(const_cast<FunctionObj*>(func));
    TVM_FFI_CHECK_SAFE_CALL(self->safe_call(self, reinterpret_cast<const TVMFFIAny*>(args),
                                            num_args, reinterpret_cast<TVMFFIAny*>(rv)));
  }
};

namespace details {
template <typename TCallable>
class FunctionObjImpl : public FunctionObj {
 public:
  using TStorage = typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type;
  using TSelf = FunctionObjImpl<TCallable>;
  explicit FunctionObjImpl(TCallable callable) : callable_(callable) {
    this->safe_call = SafeCall;
    this->cpp_call = reinterpret_cast<void*>(CppCall);
  }

 private:
  // implementation of call
  static void CppCall(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* result) {
    (static_cast<const TSelf*>(func))->callable_(args, num_args, result);
  }
  // Implementing safe call style
  static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) {
    TVM_FFI_SAFE_CALL_BEGIN();
    TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin);
    FunctionObj* self = static_cast<FunctionObj*>(func);
    reinterpret_cast<FCall>(self->cpp_call)(self, reinterpret_cast<const AnyView*>(args), num_args,
                                            reinterpret_cast<Any*>(result));
    TVM_FFI_SAFE_CALL_END();
  }

  mutable TStorage callable_;
};

class ExternCFunctionObjNullHandleImpl : public FunctionObj {
 public:
  explicit ExternCFunctionObjNullHandleImpl(TVMFFISafeCallType safe_call) {
    this->safe_call = safe_call;
    this->cpp_call = nullptr;
  }
};

class ExternCFunctionObjImpl : public FunctionObj {
 public:
  ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self))
      : self_(self), safe_call_(safe_call), deleter_(deleter) {
    this->safe_call = SafeCall;
    this->cpp_call = nullptr;
  }

  ~ExternCFunctionObjImpl() {
    if (deleter_) deleter_(self_);
  }

 private:
  static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* rv) {
    ExternCFunctionObjImpl* self = reinterpret_cast<ExternCFunctionObjImpl*>(func);
    return self->safe_call_(self->self_, args, num_args, rv);
  }

  void* self_;
  TVMFFISafeCallType safe_call_;
  void (*deleter_)(void* self);
};

// Helper class to set packed arguments
class PackedArgsSetter {
 public:
  explicit PackedArgsSetter(AnyView* args) : args_(args) {}

  // NOTE: setter needs to be very carefully designed
  // such that we do not have temp variable conversion(eg. convert from lvalue to rvalue)
  // that is why we need T&& and std::forward here
  template <typename T>
  TVM_FFI_INLINE void operator()(size_t i, T&& value) const {
    args_[i].operator=(std::forward<T>(value));
  }

 private:
  AnyView* args_;
};
}  // namespace details

class PackedArgs {
 public:
  PackedArgs(const AnyView* data, int32_t size) : data_(data), size_(size) {}

  int size() const { return size_; }

  const AnyView* data() const { return data_; }

  PackedArgs Slice(int begin, int end = -1) const {
    if (end == -1) {
      end = size_;
    }
    return PackedArgs(data_ + begin, end - begin);
  }

  AnyView operator[](int i) const { return data_[i]; }

  template <typename... Args>
  TVM_FFI_INLINE static void Fill(AnyView* data, Args&&... args) {
    details::for_each(details::PackedArgsSetter(data), std::forward<Args>(args)...);
  }

 private:
  const AnyView* data_;
  int32_t size_;
};

class Function : public ObjectRef {
 public:
  Function(std::nullptr_t) : ObjectRef(nullptr) {}  // NOLINT(*)
  template <typename TCallable>
  explicit Function(TCallable packed_call) {
    *this = FromPacked(packed_call);
  }
  template <typename TCallable>
  static Function FromPacked(TCallable packed_call) {
    static_assert(
        std::is_convertible_v<TCallable, std::function<void(const AnyView*, int32_t, Any*)>> ||
            std::is_convertible_v<TCallable, std::function<void(PackedArgs args, Any*)>>,
        "tvm::ffi::Function::FromPacked requires input function signature to match packed func "
        "format");
    if constexpr (std::is_convertible_v<TCallable, std::function<void(PackedArgs args, Any*)>>) {
      auto wrapped_call = [packed_call](const AnyView* args, int32_t num_args,
                                        Any* rv) mutable -> void {
        PackedArgs args_pack(args, num_args);
        packed_call(args_pack, rv);
      };
      return FromPackedInternal(wrapped_call);
    } else {
      return FromPackedInternal(packed_call);
    }
  }

  static Function FromExternC(void* self, TVMFFISafeCallType safe_call,
                              void (*deleter)(void* self)) {
    // the other function coems from a different library
    Function func;
    if (self == nullptr && deleter == nullptr) {
      func.data_ = make_object<details::ExternCFunctionObjNullHandleImpl>(safe_call);
    } else {
      func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call, deleter);
    }
    return func;
  }
  static std::optional<Function> GetGlobal(std::string_view name) {
    TVMFFIObjectHandle handle;
    TVMFFIByteArray name_arr{name.data(), name.size()};
    TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle));
    if (handle != nullptr) {
      return Function(
          details::ObjectUnsafe::ObjectPtrFromOwned<FunctionObj>(static_cast<Object*>(handle)));
    } else {
      return std::nullopt;
    }
  }

  static std::optional<Function> GetGlobal(const std::string& name) {
    return GetGlobal(std::string_view(name.data(), name.length()));
  }

  static std::optional<Function> GetGlobal(const String& name) {
    return GetGlobal(std::string_view(name.data(), name.length()));
  }

  static std::optional<Function> GetGlobal(const char* name) {
    return GetGlobal(std::string_view(name));
  }
  static Function GetGlobalRequired(std::string_view name) {
    std::optional<Function> res = GetGlobal(name);
    if (!res.has_value()) {
      TVM_FFI_THROW(ValueError) << "Function " << name << " not found";
    }
    return *res;
  }

  static Function GetGlobalRequired(const std::string& name) {
    return GetGlobalRequired(std::string_view(name.data(), name.length()));
  }

  static Function GetGlobalRequired(const String& name) {
    return GetGlobalRequired(std::string_view(name.data(), name.length()));
  }

  static Function GetGlobalRequired(const char* name) {
    return GetGlobalRequired(std::string_view(name));
  }
  static void SetGlobal(std::string_view name, Function func, bool override = false) {
    TVMFFIByteArray name_arr{name.data(), name.size()};
    TVM_FFI_CHECK_SAFE_CALL(
        TVMFFIFunctionSetGlobal(&name_arr, details::ObjectUnsafe::GetHeader(func.get()), override));
  }
  static std::vector<String> ListGlobalNames() {
    Function fname_functor =
        GetGlobalRequired("ffi.FunctionListGlobalNamesFunctor")().cast<Function>();
    std::vector<String> names;
    int len = fname_functor(-1).cast<int>();
    for (int i = 0; i < len; ++i) {
      names.push_back(fname_functor(i).cast<String>());
    }
    return names;
  }
  static void RemoveGlobal(const String& name) {
    static Function fremove = GetGlobalRequired("ffi.FunctionRemoveGlobal");
    fremove(name);
  }
  template <typename TCallable>
  static Function FromTyped(TCallable callable) {
    using FuncInfo = details::FunctionInfo<TCallable>;
    auto call_packed = [callable](const AnyView* args, int32_t num_args, Any* rv) mutable -> void {
      details::unpack_call<typename FuncInfo::RetType>(
          std::make_index_sequence<FuncInfo::num_args>{}, nullptr, callable, args, num_args, rv);
    };
    return FromPackedInternal(call_packed);
  }
  template <typename TCallable>
  static Function FromTyped(TCallable callable, std::string name) {
    using FuncInfo = details::FunctionInfo<TCallable>;
    auto call_packed = [callable, name](const AnyView* args, int32_t num_args,
                                        Any* rv) mutable -> void {
      details::unpack_call<typename FuncInfo::RetType>(
          std::make_index_sequence<FuncInfo::num_args>{}, &name, callable, args, num_args, rv);
    };
    return FromPackedInternal(call_packed);
  }
  template <typename... Args>
  TVM_FFI_INLINE Any operator()(Args&&... args) const {
    const int kNumArgs = sizeof...(Args);
    const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
    AnyView args_pack[kArraySize];
    PackedArgs::Fill(args_pack, std::forward<Args>(args)...);
    Any result;
    static_cast<FunctionObj*>(data_.get())->CallPacked(args_pack, kNumArgs, &result);
    return result;
  }
  TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const {
    static_cast<FunctionObj*>(data_.get())->CallPacked(args, num_args, result);
  }
  TVM_FFI_INLINE void CallPacked(PackedArgs args, Any* result) const {
    static_cast<FunctionObj*>(data_.get())->CallPacked(args.data(), args.size(), result);
  }

  TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; }
  TVM_FFI_INLINE bool operator!=(std::nullptr_t) const { return data_ != nullptr; }

  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, ObjectRef, FunctionObj);

  class Registry;

 private:
  template <typename TCallable>
  static Function FromPackedInternal(TCallable packed_call) {
    using ObjType = typename details::FunctionObjImpl<TCallable>;
    Function func;
    func.data_ = make_object<ObjType>(std::forward<TCallable>(packed_call));
    return func;
  }
};

template <typename FType>
class TypedFunction;

template <typename R, typename... Args>
class TypedFunction<R(Args...)> {
 public:
  using TSelf = TypedFunction<R(Args...)>;
  TypedFunction() {}
  TypedFunction(std::nullptr_t null) {}  // NOLINT(*)
  TypedFunction(Function packed) : packed_(packed) {}  // NOLINT(*)
  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
                                  FLambda, std::function<R(Args...)>>::value>::type>
  TypedFunction(FLambda typed_lambda, std::string name) {  // NOLINT(*)
    packed_ = Function::FromTyped(typed_lambda, name);
  }
  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
                                  FLambda, std::function<R(Args...)>>::value>::type>
  TypedFunction(const FLambda& typed_lambda) {  // NOLINT(*)
    packed_ = Function::FromTyped(typed_lambda);
  }
  template <typename FLambda, typename = typename std::enable_if<
                                  std::is_convertible<FLambda,
                                                      std::function<R(Args...)>>::value>::type>
  TSelf& operator=(FLambda typed_lambda) {  // NOLINT(*)
    packed_ = Function::FromTyped(typed_lambda);
    return *this;
  }
  TSelf& operator=(Function packed) {
    packed_ = std::move(packed);
    return *this;
  }
  TVM_FFI_INLINE R operator()(Args... args) const {
    if constexpr (std::is_same_v<R, void>) {
      packed_(std::forward<Args>(args)...);
    } else {
      Any res = packed_(std::forward<Args>(args)...);
      if constexpr (std::is_same_v<R, Any>) {
        return res;
      } else {
        return std::move(res).cast<R>();
      }
    }
  }
  operator Function() const { return packed(); }
  const Function& packed() const& { return packed_; }
  constexpr Function&& packed() && { return std::move(packed_); }
  bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
  bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; }

 private:
  Function packed_;
};

template <typename FType>
inline constexpr bool use_default_type_traits_v<TypedFunction<FType>> = false;

template <typename FType>
struct TypeTraits<TypedFunction<FType>> : public TypeTraitsBase {
  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction;

  TVM_FFI_INLINE static void CopyToAnyView(const TypedFunction<FType>& src, TVMFFIAny* result) {
    TypeTraits<Function>::CopyToAnyView(src.packed(), result);
  }

  TVM_FFI_INLINE static void MoveToAny(TypedFunction<FType> src, TVMFFIAny* result) {
    TypeTraits<Function>::MoveToAny(std::move(src.packed()), result);
  }

  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
    return src->type_index == TypeIndex::kTVMFFIFunction;
  }

  TVM_FFI_INLINE static TypedFunction<FType> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
    return TypedFunction<FType>(TypeTraits<Function>::CopyFromAnyViewAfterCheck(src));
  }

  TVM_FFI_INLINE static std::optional<TypedFunction<FType>> TryCastFromAnyView(
      const TVMFFIAny* src) {
    std::optional<Function> opt = TypeTraits<Function>::TryCastFromAnyView(src);
    if (opt.has_value()) {
      return TypedFunction<FType>(*std::move(opt));
    } else {
      return std::nullopt;
    }
  }

  TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo<FType>::Sig(); }
};

inline int32_t TypeKeyToIndex(std::string_view type_key) {
  int32_t type_index;
  TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()};
  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
  return type_index;
}

#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function)                            \
  extern "C" {                                                                         \
  TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, const TVMFFIAny* args,     \
                                                int32_t num_args, TVMFFIAny* result) { \
    TVM_FFI_SAFE_CALL_BEGIN();                                                         \
    using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>;            \
    static std::string name = #ExportName;                                             \
    ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>(                      \
        std::make_index_sequence<FuncInfo::num_args>{}, &name, Function,               \
        reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args,                  \
        reinterpret_cast<::tvm::ffi::Any*>(result));                                   \
    TVM_FFI_SAFE_CALL_END();                                                           \
  }                                                                                    \
  }
}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_FUNCTION_H_