Program Listing for File tuple.h

Program Listing for File tuple.h#

Return to documentation for file (tvm/ffi/container/tuple.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef TVM_FFI_CONTAINER_TUPLE_H_
#define TVM_FFI_CONTAINER_TUPLE_H_

#include <tvm/ffi/container/array.h>

#include <string>
#include <tuple>
#include <utility>

namespace tvm {
namespace ffi {

template <typename... Types>
class Tuple : public ObjectRef {
 public:
  static_assert(details::all_storage_enabled_v<Types...>,
                "All types used in Tuple<...> must be compatible with Any");
  Tuple() : ObjectRef(MakeDefaultTupleNode()) {}
  explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {}
  Tuple(const Tuple<Types...>& other) : ObjectRef(other) {}
  Tuple(Tuple<Types...>&& other) : ObjectRef(std::move(other)) {}
  template <typename... UTypes,
            typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...), int>>
  Tuple(const Tuple<UTypes...>& other) : ObjectRef(other) {}

  template <typename... UTypes,
            typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...), int>>
  Tuple(Tuple<UTypes...>&& other) : ObjectRef(std::move(other)) {}

  template <typename... UTypes, typename = std::enable_if_t<
                                    sizeof...(Types) == sizeof...(UTypes) &&
                                    !(sizeof...(Types) == 1 &&
                                      (std::is_same_v<std::decay_t<UTypes>, Tuple<Types>> && ...))>>
  explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward<UTypes>(args)...)) {}

  TVM_FFI_INLINE Tuple& operator=(const Tuple<Types...>& other) {
    data_ = other.data_;
    return *this;
  }

  TVM_FFI_INLINE Tuple& operator=(Tuple<Types...>&& other) {
    data_ = std::move(other.data_);
    return *this;
  }

  template <typename... UTypes,
            typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...)>>
  TVM_FFI_INLINE Tuple& operator=(const Tuple<UTypes...>& other) {
    data_ = other.data_;
    return *this;
  }

  template <typename... UTypes,
            typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...)>>
  TVM_FFI_INLINE Tuple& operator=(Tuple<UTypes...>&& other) {
    data_ = std::move(other.data_);
    return *this;
  }

  template <size_t I>
  auto get() const {
    static_assert(I < sizeof...(Types), "Tuple index out of bounds");
    using ReturnType = std::tuple_element_t<I, std::tuple<Types...>>;
    const Any* ptr = GetArrayObj()->begin() + I;
    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<ReturnType>(*ptr);
  }

  template <size_t I, typename U>
  void Set(U&& item) {
    static_assert(I < sizeof...(Types), "Tuple index out of bounds");
    using T = std::tuple_element_t<I, std::tuple<Types...>>;
    this->CopyIfNotUnique();
    Any* ptr = GetArrayObj()->MutableBegin() + I;
    *ptr = T(std::forward<U>(item));
  }

  using ContainerType = ArrayObj;

 private:
  static ObjectPtr<ArrayObj> MakeDefaultTupleNode() {
    ObjectPtr<ArrayObj> p = ArrayObj::Empty(sizeof...(Types));
    Any* itr = p->MutableBegin();
    // increase size after each new to ensure exception safety
    ((new (itr++) Any(Types()), p->size_++), ...);
    return p;
  }

  template <typename... UTypes>
  static ObjectPtr<ArrayObj> MakeTupleNode(UTypes&&... args) {
    ObjectPtr<ArrayObj> p = ArrayObj::Empty(sizeof...(Types));
    Any* itr = p->MutableBegin();
    // increase size after each new to ensure exception safety
    ((new (itr++) Any(Types(std::forward<UTypes>(args))), p->size_++), ...);
    return p;
  }

  void CopyIfNotUnique() {
    if (!data_.unique()) {
      ObjectPtr<ArrayObj> p = ArrayObj::Empty(sizeof...(Types));
      Any* itr = p->MutableBegin();
      const Any* read = GetArrayObj()->begin();
      // increase size after each new to ensure exception safety
      for (size_t i = 0; i < sizeof...(Types); ++i) {
        new (itr++) Any(*read++);
        p->size_++;
      }
      data_ = std::move(p);
    }
  }

  ArrayObj* GetArrayObj() const { return static_cast<ArrayObj*>(data_.get()); }

  template <typename... UTypes>
  friend class Tuple;
};

template <typename... Types>
inline constexpr bool use_default_type_traits_v<Tuple<Types...>> = false;

template <typename... Types>
struct TypeTraits<Tuple<Types...>> : public ObjectRefTypeTraitsBase<Tuple<Types...>> {
  using ObjectRefTypeTraitsBase<Tuple<Types...>>::CopyFromAnyViewAfterCheck;

  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
    if (src->type_index != TypeIndex::kTVMFFIArray) {
      return TypeTraitsBase::GetMismatchTypeInfo(src);
    }
    const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
    if (n->size() != sizeof...(Types)) {
      return "Array[size=" + std::to_string(n->size()) + "]";
    }
    return GetMismatchTypeInfoHelper<0, Types...>(n->begin());
  }

  template <size_t I, typename T, typename... Rest>
  TVM_FFI_INLINE static std::string GetMismatchTypeInfoHelper(const Any* arr) {
    if constexpr (!std::is_same_v<T, Any>) {
      const Any& any_v = arr[I];
      if (!details::AnyUnsafe::CheckAnyStrict<T>(any_v) && !(any_v.try_cast<T>().has_value())) {
        // now report the accurate mismatch information
        return "Array[index " + std::to_string(I) + ": " +
               details::AnyUnsafe::GetMismatchTypeInfo<T>(any_v) + "]";
      }
    }
    if constexpr (sizeof...(Rest) > 0) {
      return GetMismatchTypeInfoHelper<I + 1, Rest...>(arr);
    }
    TVM_FFI_THROW(InternalError) << "Cannot reach here";
    TVM_FFI_UNREACHABLE();
  }

  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
    if (src->type_index != TypeIndex::kTVMFFIArray) return false;
    const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
    if (n->size() != sizeof...(Types)) return false;
    const TVMFFIAny* ffi_any_arr = reinterpret_cast<const TVMFFIAny*>(n->begin());
    return CheckAnyStrictHelper<0, Types...>(ffi_any_arr);
  }

  template <size_t I, typename T, typename... Rest>
  TVM_FFI_INLINE static bool CheckAnyStrictHelper(const TVMFFIAny* src_arr) {
    if constexpr (!std::is_same_v<T, Any>) {
      if (!TypeTraits<T>::CheckAnyStrict(src_arr + I)) {
        return false;
      }
    }
    if constexpr (sizeof...(Rest) > 0) {
      return CheckAnyStrictHelper<I + 1, Rest...>(src_arr);
    }
    return true;
  }

  TVM_FFI_INLINE static std::optional<Tuple<Types...>> TryCastFromAnyView(const TVMFFIAny* src  //
  ) {
    if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt;
    const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
    if (n->size() != sizeof...(Types)) return std::nullopt;
    // fast path, storage is already in the right type
    if (CheckAnyStrict(src)) {
      return CopyFromAnyViewAfterCheck(src);
    }
    // slow path, try to convert to each type to match the tuple storage need.
    Array<Any> arr = TypeTraits<Array<Any>>::CopyFromAnyViewAfterCheck(src);
    Any* ptr = arr.CopyOnWrite()->MutableBegin();
    if (TryConvertElements<0, Types...>(ptr)) {
      return details::ObjectUnsafe::ObjectRefFromObjectPtr<Tuple<Types...>>(
          details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(arr));
    }
    return std::nullopt;
  }

  template <size_t I, typename T, typename... Rest>
  TVM_FFI_INLINE static bool TryConvertElements(Any* arr) {
    if constexpr (!std::is_same_v<T, Any>) {
      if (auto opt_convert = arr[I].try_cast<T>()) {
        arr[I] = *std::move(opt_convert);
      } else {
        return false;
      }
    }
    if constexpr (sizeof...(Rest) > 0) {
      return TryConvertElements<I + 1, Rest...>(std::move(arr));
    } else {
      return true;
    }
  }

  TVM_FFI_INLINE static std::string TypeStr() {
    return details::ContainerTypeStr<Types...>("Tuple");
  }
};

namespace details {
template <typename... T, typename... U>
inline constexpr bool type_contains_v<Tuple<T...>, Tuple<U...>> = (type_contains_v<T, U> && ...);
}  // namespace details

}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_CONTAINER_TUPLE_H_