Program Listing for File variant.h

Program Listing for File variant.h#

Return to documentation for file (tvm/ffi/container/variant.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef TVM_FFI_CONTAINER_VARIANT_H_
#define TVM_FFI_CONTAINER_VARIANT_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/container/container_details.h>
#include <tvm/ffi/optional.h>

#include <string>
#include <type_traits>
#include <utility>

namespace tvm {
namespace ffi {
namespace details {
template <bool all_storage_object = false>
class VariantBase {
 public:
  TVM_FFI_INLINE bool same_as(const VariantBase<all_storage_object>& other) const {
    return data_.same_as(other.data_);
  }

 protected:
  template <typename T>
  explicit VariantBase(T other) : data_(std::move(other)) {}

  TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); }

  TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); }

  TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); }

  Any data_;
};

// Specialization for all object ref case, backed by ObjectRef.
template <>
class VariantBase<true> : public ObjectRef {
 protected:
  template <typename T>
  explicit VariantBase(const T& other) : ObjectRef(other) {}
  template <typename T>
  explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {}
  explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {}
  explicit VariantBase(Any other)
      : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck<ObjectRef>(std::move(other))) {}

  TVM_FFI_INLINE void SetData(ObjectPtr<Object> other) { data_ = std::move(other); }

  TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); }

  TVM_FFI_INLINE AnyView ToAnyView() const {
    TVMFFIAny any_data;
    if (data_ == nullptr) {
      any_data.type_index = TypeIndex::kTVMFFINone;
      any_data.zero_padding = 0;
      any_data.v_int64 = 0;
    } else {
      TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
      any_data.type_index = data_->type_index();
      any_data.zero_padding = 0;
      any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_);
    }
    return AnyView::CopyFromTVMFFIAny(any_data);
  }
};
}  // namespace details

template <typename... V>
class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
 public:
  using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
  static_assert(details::all_storage_enabled_v<V...>,
                "All types used in Variant<...> must be compatible with Any");
  /*
   * \brief Helper utility to check if the type can be contained in the variant
   */
  template <typename T>
  static constexpr bool variant_contains_v = (details::type_contains_v<V, T> || ...);
  /* \brief Helper utility for SFINAE if the type is part of the variant */
  template <typename T>
  using enable_if_variant_contains_t = std::enable_if_t<variant_contains_v<T>>;

  Variant(const Variant<V...>& other) : TParent(other.data_) {}
  Variant(Variant<V...>&& other) : TParent(std::move(other.data_)) {}

  TVM_FFI_INLINE Variant& operator=(const Variant<V...>& other) {
    this->SetData(other.data_);
    return *this;
  }

  TVM_FFI_INLINE Variant& operator=(Variant<V...>&& other) {
    this->SetData(std::move(other.data_));
    return *this;
  }

  template <typename T, typename = enable_if_variant_contains_t<T>>
  Variant(T other) : TParent(std::move(other)) {}  // NOLINT(*)

  template <typename T, typename = enable_if_variant_contains_t<T>>
  TVM_FFI_INLINE Variant& operator=(T other) {
    return operator=(Variant(std::move(other)));
  }

  template <typename T, typename = enable_if_variant_contains_t<T>>
  TVM_FFI_INLINE std::optional<T> as() const {
    return this->TParent::ToAnyView().template as<T>();
  }

  template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
  TVM_FFI_INLINE const T* as() const {
    return this->TParent::ToAnyView().template as<const T*>().value_or(nullptr);
  }

  template <typename T, typename = enable_if_variant_contains_t<T>>
  TVM_FFI_INLINE T get() const& {
    return this->TParent::ToAnyView().template cast<T>();
  }

  template <typename T, typename = enable_if_variant_contains_t<T>>
  TVM_FFI_INLINE T get() && {
    return std::move(*this).TParent::MoveToAny().template cast<T>();
  }

  TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); }

 private:
  friend struct TypeTraits<Variant<V...>>;
  friend struct ObjectPtrHash;
  friend struct ObjectPtrEqual;
  // constructor from any
  explicit Variant(Any data) : TParent(std::move(data)) {}
  TVM_FFI_INLINE Object* GetObjectPtrForHashEqual() const {
    constexpr bool all_object_v = (std::is_base_of_v<ObjectRef, V> && ...);
    static_assert(all_object_v,
                  "All types used in Variant<...> must be derived from ObjectRef "
                  "to enable ObjectPtrHash/ObjectPtrEqual");
    return this->data_.get();
  }
  // rexpose to friend class
  using TParent::MoveToAny;
  using TParent::ToAnyView;
};

template <typename... V>
inline constexpr bool use_default_type_traits_v<Variant<V...>> = false;

template <typename... V>
struct TypeTraits<Variant<V...>> : public TypeTraitsBase {
  TVM_FFI_INLINE static void CopyToAnyView(const Variant<V...>& src, TVMFFIAny* result) {
    *result = src.ToAnyView().CopyToTVMFFIAny();
  }

  TVM_FFI_INLINE static void MoveToAny(Variant<V...> src, TVMFFIAny* result) {
    *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny());
  }

  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
    return TypeTraitsBase::GetMismatchTypeInfo(src);
  }

  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
    return (TypeTraits<V>::CheckAnyStrict(src) || ...);
  }

  TVM_FFI_INLINE static Variant<V...> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
    return Variant<V...>(Any(AnyView::CopyFromTVMFFIAny(*src)));
  }

  TVM_FFI_INLINE static Variant<V...> MoveFromAnyAfterCheck(TVMFFIAny* src) {
    return Variant<V...>(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src)));
  }

  TVM_FFI_INLINE static std::optional<Variant<V...>> TryCastFromAnyView(const TVMFFIAny* src) {
    // fast path, storage is already in the right type
    if (CheckAnyStrict(src)) {
      return CopyFromAnyViewAfterCheck(src);
    }
    // More expensive path, try to convert to each type, in order of declaration
    return TryVariantTypes<V...>(src);
  }

  template <typename VariantType, typename... Rest>
  TVM_FFI_INLINE static std::optional<Variant<V...>> TryVariantTypes(const TVMFFIAny* src) {
    if (auto opt_convert = TypeTraits<VariantType>::TryCastFromAnyView(src)) {
      return Variant<V...>(*std::move(opt_convert));
    }
    if constexpr (sizeof...(Rest) > 0) {
      return TryVariantTypes<Rest...>(src);
    }
    return std::nullopt;
  }

  TVM_FFI_INLINE static std::string TypeStr() { return details::ContainerTypeStr<V...>("Variant"); }
};

template <typename... V>
TVM_FFI_INLINE size_t ObjectPtrHash::operator()(const Variant<V...>& a) const {
  return std::hash<Object*>()(a.GetObjectPtrForHashEqual());
}

template <typename... V>
TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant<V...>& a,
                                               const Variant<V...>& b) const {
  return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual();
}

namespace details {
template <typename... V, typename T>
inline constexpr bool type_contains_v<Variant<V...>, T> = (type_contains_v<V, T> || ...);
}  // namespace details
}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_CONTAINER_VARIANT_H_