Program Listing for File stl.h

Program Listing for File stl.h

Return to documentation for file (tvm/ffi/extra/stl.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef TVM_FFI_EXTRA_STL_H_
#define TVM_FFI_EXTRA_STL_H_

#include <tvm/ffi/base_details.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/type_traits.h>

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <functional>
#include <iterator>
#include <map>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "tvm/ffi/function.h"

namespace tvm {
namespace ffi {
namespace details {

/*!
 * \brief Internal exception thrown when an element of an ffi container cannot be converted
 *        to the requested STL element type.
 *
 * The TryCastFromAnyView implementations in this file catch it and turn it into std::nullopt.
 */
struct STLTypeMismatch : std::exception {
  const char* what() const noexcept override {
    static constexpr const char kMessage[] = "STL type mismatch";
    return kMessage;
  }
};

/*!
 * \brief Helpers shared by the TypeTraits specializations of STL containers below.
 *
 * storage_enabled is false for all of them; presumably this means STL values are only
 * converted to/from ffi containers and never stored in an Any as-is (confirm in TypeTraitsBase).
 */
struct STLTypeTrait : public TypeTraitsBase {
 public:
  static constexpr bool storage_enabled = false;

 protected:
  /// Hand ownership of the object held by `src` over to `result`, tagging `result` with the
  /// object's own runtime type index.
  template <typename T>
  TVM_FFI_INLINE static void MoveToAnyImpl(ObjectPtr<T>&& src, TVMFFIAny* result) {
    TVMFFIObject* raw = ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(src));
    result->type_index = raw->type_index;
    result->zero_padding = 0;
    // NOTE(review): presumably zeroes unused bytes of the pointer slot; kept before the store.
    TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
    result->v_obj = raw;
  }

  /// Obtain an ObjectPtr<T> to the object referenced by `src` without taking over `src`'s
  /// reference (presumably adds a new reference — see ObjectPtrFromUnowned).
  template <typename T>
  TVM_FFI_INLINE static ObjectPtr<T> CopyFromAnyImpl(const TVMFFIAny* src) {
    return ObjectUnsafe::ObjectPtrFromUnowned<T>(src->v_obj);
  }

  /// Convert one container element to T.
  /// \return `value` itself when T is Any, otherwise the result of TypeTraits<T>::TryCastFromAnyView.
  /// \throws STLTypeMismatch if the element cannot be converted to T.
  template <typename T>
  TVM_FFI_INLINE static T ConstructFromAny(const Any& value) {
    if constexpr (std::is_same_v<T, Any>) {
      return value;
    } else {
      auto converted = TypeTraits<T>::TryCastFromAnyView(AnyUnsafe::TVMFFIAnyPtrFromAny(value));
      if (!converted.has_value()) throw STLTypeMismatch{};
      return std::move(*converted);
    }
  }
};

// Tag types. TypeTraits<ListTemplate> (backed by an ffi Array) and TypeTraits<MapTemplate>
// (backed by an ffi Map) are specialized below to hold helpers shared by the STL sequence
// and associative container traits, which inherit from them.
struct ListTemplate {};
struct MapTemplate {};

}  // namespace details

/*!
 * \brief Shared helpers for STL sequence-like types (std::array, std::vector, std::tuple)
 *        that convert to an ffi Array.
 */
template <>
struct TypeTraits<details::ListTemplate> : public details::STLTypeTrait {
 public:
  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;

 private:
  /// Build an ArrayObj from every element of a tuple-like `src`, forwarding each element
  /// (so an rvalue tuple has its elements moved into the Any values).
  template <std::size_t... Is, typename Tuple>
  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToTupleImpl(std::index_sequence<Is...>,
                                                            Tuple&& src) {
    auto array = ArrayObj::Empty(static_cast<std::int64_t>(sizeof...(Is)));
    auto dst = array->MutableBegin();
    // increase size after each new so that, if an Any constructor throws, only the
    // already-constructed elements are counted (and destroyed) by the array
    std::apply(
        [&](auto&&... elems) {
          ((::new (dst++) Any(std::forward<decltype(elems)>(elems)), array->size_++), ...);
        },
        std::forward<Tuple>(src));
    return array;
  }

  /// Build an ArrayObj from `size` elements read from `src`; pass a move_iterator to move them.
  template <typename Iter>
  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToArrayImpl(Iter src, std::size_t size) {
    auto array = ArrayObj::Empty(static_cast<std::int64_t>(size));
    auto dst = array->MutableBegin();
    // increase size after each new so a throwing Any constructor leaves a consistent array
    for (std::size_t i = 0; i < size; ++i, ++src) {
      ::new (dst++) Any(*src);
      array->size_++;
    }
    return array;
  }

 protected:
  /// Copy each element of tuple-like `src` into a new ArrayObj.
  template <typename Tuple>
  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToTuple(const Tuple& src) {
    return CopyToTupleImpl(std::make_index_sequence<std::tuple_size_v<Tuple>>{}, src);
  }

  /// Forward each element of tuple-like `src` into a new ArrayObj.
  /// `Tuple` is a forwarding reference and may be deduced as an lvalue reference, so the
  /// reference is stripped before querying std::tuple_size (which is not defined for references).
  template <typename Tuple>
  TVM_FFI_INLINE static ObjectPtr<ArrayObj> MoveToTuple(Tuple&& src) {
    using BareTuple = std::remove_reference_t<Tuple>;
    return CopyToTupleImpl(std::make_index_sequence<std::tuple_size_v<BareTuple>>{},
                           std::forward<Tuple>(src));
  }

  /// Copy every element of range `src` into a new ArrayObj.
  template <typename Range>
  TVM_FFI_INLINE static ObjectPtr<ArrayObj> CopyToArray(const Range& src) {
    return CopyToArrayImpl(std::begin(src), std::size(src));
  }

  /// Move every element of range `src` into a new ArrayObj (via std::make_move_iterator).
  template <typename Range>
  TVM_FFI_INLINE static ObjectPtr<ArrayObj> MoveToArray(Range&& src) {
    return CopyToArrayImpl(std::make_move_iterator(std::begin(src)), std::size(src));
  }
};

/*!
 * \brief Shared helpers for STL associative containers (std::map, std::unordered_map)
 *        that convert to an ffi Map.
 */
template <>
struct TypeTraits<details::MapTemplate> : public details::STLTypeTrait {
 public:
  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap;

 protected:
  /// Build a MapObj from a copy of every entry in `src`.
  template <typename MapType>
  TVM_FFI_INLINE static ObjectPtr<Object> CopyToMap(const MapType& src) {
    auto first = std::begin(src);
    auto last = std::end(src);
    return MapObj::CreateFromRange(first, last);
  }

  /// Build a MapObj from the entries of `src`, read through move iterators.
  template <typename MapType>
  TVM_FFI_INLINE static ObjectPtr<Object> MoveToMap(MapType&& src) {
    auto first = std::make_move_iterator(std::begin(src));
    auto last = std::make_move_iterator(std::end(src));
    return MapObj::CreateFromRange(first, last);
  }

  /// Convert the ffi Map held by `src` into a MapType, converting each key and value.
  /// When a key appears twice after conversion, try_emplace keeps the first entry.
  /// \tparam CanReserve whether MapType supports reserve() (true for unordered maps).
  /// \throws details::STLTypeMismatch if any key or value fails to convert.
  template <typename MapType, bool CanReserve>
  TVM_FFI_INLINE static MapType ConstructMap(const TVMFFIAny* src) {
    using K = typename MapType::key_type;
    using V = typename MapType::mapped_type;
    auto source = CopyFromAnyImpl<MapObj>(src);
    MapType converted;
    if constexpr (CanReserve) {
      converted.reserve(source->size());
    }
    for (const auto& [k, v] : *source) {
      converted.try_emplace(ConstructFromAny<K>(k), ConstructFromAny<V>(v));
    }
    return converted;
  }
};

/*!
 * \brief TypeTraits for std::array<T, Nm>, converted to and from an ffi Array of length Nm.
 */
template <typename T, std::size_t Nm>
struct TypeTraits<std::array<T, Nm>> : public TypeTraits<details::ListTemplate> {
 private:
  using Self = std::array<T, Nm>;

  /// Cheap pre-check: `src` must hold an ffi Array with exactly Nm elements.
  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
    if (src->type_index != TypeIndex::kTVMFFIArray) return false;
    const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
    // size_ is signed (int64_t); cast Nm so this is not a signed/unsigned comparison
    return n.size_ == static_cast<std::int64_t>(Nm);
  }

  /// Build the std::array in one braced initializer. Unlike default-constructing the array
  /// and assigning into it, this works for element types that are not default-constructible
  /// or not assignable. Elements of a braced-init-list are evaluated left to right.
  /// \throws details::STLTypeMismatch if any element fails to convert to T.
  template <std::size_t... Is>
  TVM_FFI_INLINE static Self ConstructArrayAux(std::index_sequence<Is...>, const ArrayObj& n) {
    return Self{{ConstructFromAny<T>(n[Is])...}};
  }

 public:
  static_assert(Nm > 0, "Zero-length std::array is not supported.");

  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
    return MoveToAnyImpl(CopyToArray(src), result);
  }

  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
    return MoveToAnyImpl(MoveToArray(std::move(src)), result);
  }

  /// \return the converted array, or std::nullopt if `src` is not an Array of length Nm
  ///         or any element cannot be converted to T.
  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
    if (!CheckAnyFast(src)) return std::nullopt;
    try {
      auto array = CopyFromAnyImpl<ArrayObj>(src);
      return ConstructArrayAux(std::make_index_sequence<Nm>{}, *array);
    } catch (const details::STLTypeMismatch&) {
      return std::nullopt;
    }
  }

  TVM_FFI_INLINE static std::string TypeStr() {
    return "std::array<" + details::Type2Str<T>::v() + ", " + std::to_string(Nm) + ">";
  }

  TVM_FFI_INLINE static std::string TypeSchema() {
    return R"({"type":"std::array","args":[)" + details::TypeSchema<T>::v() + "," +
           std::to_string(Nm) + "]}";
  }
};

/*!
 * \brief TypeTraits for std::vector<T>, converted to and from an ffi Array.
 */
template <typename T>
struct TypeTraits<std::vector<T>> : public TypeTraits<details::ListTemplate> {
 private:
  using Self = std::vector<T>;

  /// Any ffi Array is a candidate; element types are only verified during conversion.
  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
    return TypeIndex::kTVMFFIArray == src->type_index;
  }

 public:
  /// Copy the elements of `src` into a new ffi Array stored in `result`.
  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
    MoveToAnyImpl(CopyToArray(src), result);
  }

  /// Move the elements of `src` into a new ffi Array stored in `result`.
  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
    MoveToAnyImpl(MoveToArray(std::move(src)), result);
  }

  /// \return the converted vector, or std::nullopt if `src` is not an Array or any element
  ///         cannot be converted to T.
  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
    if (CheckAnyFast(src)) {
      try {
        auto array = CopyFromAnyImpl<ArrayObj>(src);
        const int64_t count = array->size_;
        auto cursor = array->MutableBegin();
        const auto stop = cursor + count;
        Self out;
        out.reserve(count);
        for (; cursor != stop; ++cursor) {
          out.emplace_back(ConstructFromAny<T>(*cursor));
        }
        return out;
      } catch (const details::STLTypeMismatch&) {
        // an element could not be converted to T; report failure below
      }
    }
    return std::nullopt;
  }

  TVM_FFI_INLINE static std::string TypeStr() {
    const std::string elem = details::Type2Str<T>::v();
    return "std::vector<" + elem + ">";
  }

  TVM_FFI_INLINE static std::string TypeSchema() {
    const std::string elem = details::TypeSchema<T>::v();
    return R"({"type":"std::vector","args":[)" + elem + "]}";
  }
};

/*!
 * \brief TypeTraits for std::optional<T>: an empty optional maps to None, an engaged one
 *        is converted through TypeTraits<T>.
 */
template <typename T>
struct TypeTraits<std::optional<T>> : public TypeTraitsBase {
 public:
  using Self = std::optional<T>;

  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
    if (src.has_value()) {
      TypeTraits<T>::CopyToAnyView(*src, result);
    } else {
      TypeTraits<std::nullptr_t>::CopyToAnyView(nullptr, result);
    }
  }
  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
    if (src.has_value()) {
      TypeTraits<T>::MoveToAny(std::move(*src), result);
    } else {
      TypeTraits<std::nullptr_t>::MoveToAny(nullptr, result);
    }
  }

  /// None is always accepted; otherwise defer to T's strict check.
  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
    if (src->type_index == TypeIndex::kTVMFFINone) return true;
    return TypeTraits<T>::CheckAnyStrict(src);
  }

  TVM_FFI_INLINE static Self CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
    if (src->type_index == TypeIndex::kTVMFFINone) return Self{std::nullopt};
    return TypeTraits<T>::CopyFromAnyViewAfterCheck(src);
  }

  TVM_FFI_INLINE static Self MoveFromAnyAfterCheck(TVMFFIAny* src) {
    if (src->type_index == TypeIndex::kTVMFFINone) return Self{std::nullopt};
    return TypeTraits<T>::MoveFromAnyAfterCheck(src);
  }

  /// \return an engaged outer optional on success (holding an empty Self for None),
  ///         or std::nullopt when `src` is neither None nor convertible to T.
  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
    // None is a *successful* cast to an empty optional
    if (src->type_index == TypeIndex::kTVMFFINone) return Self{std::nullopt};
    std::optional<T> inner = TypeTraits<T>::TryCastFromAnyView(src);
    if (!inner.has_value()) return std::nullopt;  // not a T: the cast itself failed
    // std::in_place makes the nesting explicit: outer engaged, holding the engaged inner
    return std::optional<Self>{std::in_place, std::move(inner)};
  }

  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
    return TypeTraits<T>::GetMismatchTypeInfo(src);
  }

  /// Uses details::Type2Str like the other STL traits in this file. The file special-cases
  /// T == Any outside TypeTraits (see STLTypeTrait::ConstructFromAny), so TypeTraits<T>
  /// should not be assumed to provide TypeStr() for every T.
  TVM_FFI_INLINE static std::string TypeStr() {
    return "std::optional<" + details::Type2Str<T>::v() + ">";
  }

  TVM_FFI_INLINE static std::string TypeSchema() {
    return R"({"type":"std::optional","args":[)" + details::TypeSchema<T>::v() + "]}";
  }
};

/*!
 * \brief TypeTraits for std::variant<Args...>.
 *
 * Conversions from Any try the alternatives in declaration order and produce the first
 * alternative that accepts the value.
 */
template <typename... Args>
struct TypeTraits<std::variant<Args...>> : public TypeTraitsBase {
 private:
  using Self = std::variant<Args...>;
  static constexpr std::size_t Nm = sizeof...(Args);

  /// Copy-construct the first alternative (starting at index Is) whose CheckAnyStrict
  /// accepts `src`. Callers must have checked CheckAnyStrict(src) beforehand, so running
  /// past the last alternative indicates a bug.
  template <std::size_t Is = 0>
  TVM_FFI_INLINE static Self CopyUnsafeAux(const TVMFFIAny* src) {
    if constexpr (Is >= Nm) {
      TVM_FFI_ICHECK(false) << "Unreachable: variant TryCast failed.";
      throw;  // unreachable
    } else {
      using ElemType = std::variant_alternative_t<Is, Self>;
      if (TypeTraits<ElemType>::CheckAnyStrict(src)) {
        return Self{std::in_place_index<Is>, TypeTraits<ElemType>::CopyFromAnyViewAfterCheck(src)};
      } else {
        return CopyUnsafeAux<Is + 1>(src);
      }
    }
  }

  /// Move-construct the first alternative whose CheckAnyStrict accepts `src`.
  /// `src` must be non-const: MoveFromAnyAfterCheck takes a mutable TVMFFIAny*
  /// (presumably because it may take over what `src` holds), so a const pointer here
  /// could not be forwarded and failed to compile on instantiation.
  template <std::size_t Is = 0>
  TVM_FFI_INLINE static Self MoveUnsafeAux(TVMFFIAny* src) {
    if constexpr (Is >= Nm) {
      TVM_FFI_ICHECK(false) << "Unreachable: variant TryCast failed.";
      throw;  // unreachable
    } else {
      using ElemType = std::variant_alternative_t<Is, Self>;
      if (TypeTraits<ElemType>::CheckAnyStrict(src)) {
        return Self{std::in_place_index<Is>, TypeTraits<ElemType>::MoveFromAnyAfterCheck(src)};
      } else {
        return MoveUnsafeAux<Is + 1>(src);
      }
    }
  }

  /// Try each alternative's TryCastFromAnyView in order; std::nullopt if none succeeds.
  template <std::size_t Is = 0>
  TVM_FFI_INLINE static std::optional<Self> TryCastAux(const TVMFFIAny* src) {
    if constexpr (Is >= Nm) {
      return std::nullopt;
    } else {
      using ElemType = std::variant_alternative_t<Is, Self>;
      if (auto opt = TypeTraits<ElemType>::TryCastFromAnyView(src)) {
        return Self{std::in_place_index<Is>, std::move(*opt)};
      } else {
        return TryCastAux<Is + 1>(src);
      }
    }
  }

 public:
  /// Store the currently active alternative via its own TypeTraits.
  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
    return std::visit(
        [&](const auto& value) {
          using ValueType = std::decay_t<decltype(value)>;
          TypeTraits<ValueType>::CopyToAnyView(value, result);
        },
        src);
  }

  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
    return std::visit(
        [&](auto&& value) {
          using ValueType = std::decay_t<decltype(value)>;
          TypeTraits<ValueType>::MoveToAny(std::forward<ValueType>(value), result);
        },
        std::move(src));
  }

  /// True if any alternative strictly accepts `src`.
  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
    return (TypeTraits<Args>::CheckAnyStrict(src) || ...);
  }

  TVM_FFI_INLINE static Self CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
    // find the first possible type to copy
    return CopyUnsafeAux(src);
  }

  TVM_FFI_INLINE static Self MoveFromAnyAfterCheck(TVMFFIAny* src) {
    // find the first possible type to move
    return MoveUnsafeAux(src);
  }

  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
    // try to find the first possible type to copy
    return TryCastAux(src);
  }

  TVM_FFI_INLINE static std::string TypeStr() {
    std::ostringstream os;
    os << "std::variant<";
    const char* sep = "";
    ((os << sep << details::Type2Str<Args>::v(), sep = ", "), ...);
    os << ">";
    return std::move(os).str();
  }

  TVM_FFI_INLINE static std::string TypeSchema() {
    std::ostringstream os;
    os << R"({"type":"std::variant","args":[)";
    const char* sep = "";
    ((os << sep << details::TypeSchema<Args>::v(), sep = ", "), ...);
    os << "]}";
    return std::move(os).str();
  }
};

/*!
 * \brief TypeTraits for std::tuple<Args...>, converted to and from an ffi Array whose
 *        length equals sizeof...(Args).
 */
template <typename... Args>
struct TypeTraits<std::tuple<Args...>> : public TypeTraits<details::ListTemplate> {
 private:
  using Self = std::tuple<Args...>;
  static constexpr std::size_t Nm = sizeof...(Args);
  static_assert(Nm > 0, "Zero-length std::tuple is not supported.");

  /// Cheap pre-check: `src` must hold an ffi Array with exactly Nm elements.
  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
    if (src->type_index != TypeIndex::kTVMFFIArray) return false;
    const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
    // size_ is signed (int64_t); cast Nm so this is not a signed/unsigned comparison
    return n.size_ == static_cast<std::int64_t>(Nm);
  }

  /// Strictly check one element against the expected element type `Elem`.
  /// `Any` accepts every value, mirroring how ConstructFromAny special-cases it.
  template <typename Elem>
  TVM_FFI_INLINE static bool CheckElemStrict(const Any& value) {
    if constexpr (std::is_same_v<Elem, Any>) {
      return true;
    } else {
      return TypeTraits<Elem>::CheckAnyStrict(AnyUnsafe::TVMFFIAnyPtrFromAny(value));
    }
  }

  /// Strictly check every element in order; stops at the first mismatch.
  /// (Previously referenced by CheckAnyStrict but never defined.)
  template <std::size_t... Is>
  TVM_FFI_INLINE static bool CheckSubTypeAux(std::index_sequence<Is...>, const ArrayObj& n) {
    return (CheckElemStrict<std::tuple_element_t<Is, Self>>(n[Is]) && ...);
  }

  /// Convert every element; braced-init-list evaluation is left to right.
  /// \throws details::STLTypeMismatch if any element fails to convert.
  template <std::size_t... Is>
  TVM_FFI_INLINE static Self ConstructTupleAux(std::index_sequence<Is...>, const ArrayObj& n) {
    return Self{ConstructFromAny<std::tuple_element_t<Is, Self>>(n[Is])...};
  }

 public:
  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;

  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
    return MoveToAnyImpl(CopyToTuple(src), result);
  }

  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
    return MoveToAnyImpl(MoveToTuple(std::move(src)), result);
  }

  /// True iff `src` is an Array of length Nm whose elements strictly match Args... in order.
  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
    // check type index and static length first
    if (!CheckAnyFast(src)) return false;
    const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
    // then check element types
    return CheckSubTypeAux(std::make_index_sequence<Nm>{}, n);
  }

  /// \return the converted tuple, or std::nullopt if `src` is not an Array of length Nm
  ///         or any element cannot be converted.
  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
    if (!CheckAnyFast(src)) return std::nullopt;
    try {
      auto array = CopyFromAnyImpl<ArrayObj>(src);
      return ConstructTupleAux(std::make_index_sequence<Nm>{}, *array);
    } catch (const details::STLTypeMismatch&) {
      return std::nullopt;
    }
  }

  TVM_FFI_INLINE static std::string TypeStr() {
    std::ostringstream os;
    os << "std::tuple<";
    const char* sep = "";
    ((os << sep << details::Type2Str<Args>::v(), sep = ", "), ...);
    os << ">";
    return std::move(os).str();
  }

  TVM_FFI_INLINE static std::string TypeSchema() {
    std::ostringstream os;
    os << R"({"type":"std::tuple","args":[)";
    const char* sep = "";
    ((os << sep << details::TypeSchema<Args>::v(), sep = ", "), ...);
    os << "]}";
    return std::move(os).str();
  }
};

/*!
 * \brief TypeTraits for std::map<K, V>, converted to and from an ffi Map.
 */
template <typename K, typename V>
struct TypeTraits<std::map<K, V>> : public TypeTraits<details::MapTemplate> {
 private:
  using Self = std::map<K, V>;

  /// Only an ffi Map is a candidate; keys and values are checked while converting.
  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
    const bool is_map = src->type_index == TypeIndex::kTVMFFIMap;
    return is_map;
  }

 public:
  /// Copy every entry of `src` into a new ffi Map stored in `result`.
  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
    MoveToAnyImpl(CopyToMap(src), result);
  }

  /// Build a new ffi Map from `src` through move iterators and store it in `result`.
  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
    MoveToAnyImpl(MoveToMap(std::move(src)), result);
  }

  /// \return the converted map, or std::nullopt if `src` is not a Map or any key/value
  ///         cannot be converted.
  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
    if (CheckAnyFast(src)) {
      try {
        return ConstructMap<Self, /*CanReserve=*/false>(src);
      } catch (const details::STLTypeMismatch&) {
        // a key or value failed to convert; report failure below
      }
    }
    return std::nullopt;
  }

  TVM_FFI_INLINE static std::string TypeStr() {
    const std::string key_str = details::Type2Str<K>::v();
    const std::string value_str = details::Type2Str<V>::v();
    return "std::map<" + key_str + ", " + value_str + ">";
  }

  TVM_FFI_INLINE static std::string TypeSchema() {
    const std::string key_schema = details::TypeSchema<K>::v();
    const std::string value_schema = details::TypeSchema<V>::v();
    return R"({"type":"std::map","args":[)" + key_schema + "," + value_schema + "]}";
  }
};

/*!
 * \brief TypeTraits for std::unordered_map<K, V>, converted to and from an ffi Map.
 */
template <typename K, typename V>
struct TypeTraits<std::unordered_map<K, V>> : public TypeTraits<details::MapTemplate> {
 private:
  using Self = std::unordered_map<K, V>;

  /// Only an ffi Map is a candidate; keys and values are checked while converting.
  TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
    const bool is_map = src->type_index == TypeIndex::kTVMFFIMap;
    return is_map;
  }

 public:
  /// Copy every entry of `src` into a new ffi Map stored in `result`.
  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
    MoveToAnyImpl(CopyToMap(src), result);
  }

  /// Build a new ffi Map from `src` through move iterators and store it in `result`.
  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
    MoveToAnyImpl(MoveToMap(std::move(src)), result);
  }

  /// \return the converted map (pre-reserved to the source size), or std::nullopt if `src`
  ///         is not a Map or any key/value cannot be converted.
  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
    if (CheckAnyFast(src)) {
      try {
        return ConstructMap<Self, /*CanReserve=*/true>(src);
      } catch (const details::STLTypeMismatch&) {
        // a key or value failed to convert; report failure below
      }
    }
    return std::nullopt;
  }

  TVM_FFI_INLINE static std::string TypeStr() {
    const std::string key_str = details::Type2Str<K>::v();
    const std::string value_str = details::Type2Str<V>::v();
    return "std::unordered_map<" + key_str + ", " + value_str + ">";
  }

  TVM_FFI_INLINE static std::string TypeSchema() {
    const std::string key_schema = details::TypeSchema<K>::v();
    const std::string value_schema = details::TypeSchema<V>::v();
    return R"({"type":"std::unordered_map","args":[)" + key_schema + "," + value_schema + "]}";
  }
};

/*!
 * \brief TypeTraits for std::function<Ret(Args...)>, bridged through TypedFunction<Ret(Args...)>.
 */
template <typename Ret, typename... Args>
struct TypeTraits<std::function<Ret(Args...)>> : TypeTraitsBase {
 private:
  using Self = std::function<Ret(Args...)>;
  using Function = TypedFunction<Ret(Args...)>;
  using ProxyTrait = TypeTraits<Function>;

  /// Wrap a TypedFunction in a std::function that forwards every call to it.
  TVM_FFI_INLINE static Self GetFunc(Function&& f) {
    auto forwarder = [callee = std::move(f)](Args... args) -> Ret {
      return callee(std::forward<Args>(args)...);
    };
    return Self(std::move(forwarder));
  }

 public:
  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction;
  static constexpr bool storage_enabled = false;

  /// Wrap a copy of `src` in a TypedFunction and store it in `result`.
  TVM_FFI_INLINE static void CopyToAnyView(const Self& src, TVMFFIAny* result) {
    ProxyTrait::MoveToAny(Function{src}, result);
  }

  /// Wrap `src` (moved) in a TypedFunction and store it in `result`.
  TVM_FFI_INLINE static void MoveToAny(Self&& src, TVMFFIAny* result) {
    ProxyTrait::MoveToAny(Function{std::move(src)}, result);
  }

  /// \return a std::function forwarding to the ffi function in `src`, or std::nullopt if
  ///         `src` cannot be cast to TypedFunction<Ret(Args...)>.
  TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const TVMFFIAny* src) {
    auto typed = ProxyTrait::TryCastFromAnyView(src);
    if (!typed.has_value()) return std::nullopt;
    return GetFunc(std::move(*typed));
  }

  TVM_FFI_INLINE static std::string TypeStr() {
    std::string out = "std::function<" + details::Type2Str<Ret>::v() + "(";
    bool first = true;
    auto append = [&out, &first](const std::string& piece) {
      if (!first) out += ", ";
      out += piece;
      first = false;
    };
    (append(details::Type2Str<Args>::v()), ...);
    out += ")>";
    return out;
  }

  TVM_FFI_INLINE static std::string TypeSchema() {
    std::string out = R"({"type":"std::function","args":[)" + details::TypeSchema<Ret>::v() + ",[";
    bool first = true;
    auto append = [&out, &first](const std::string& piece) {
      if (!first) out += ", ";
      out += piece;
      first = false;
    };
    (append(details::TypeSchema<Args>::v()), ...);
    out += "]]}";
    return out;
  }
};

}  // namespace ffi
}  // namespace tvm

#endif  // TVM_FFI_EXTRA_STL_H_