Program Listing for File any.h

Program Listing for File any.h

Return to documentation for file (tvm/ffi/any.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_FFI_ANY_H_
#define TVM_FFI_ANY_H_

#include <tvm/ffi/c_api.h>
#include <tvm/ffi/string.h>
#include <tvm/ffi/type_traits.h>

#include <string>
#include <utility>

namespace tvm {
namespace ffi {

// Forward declarations: both types are defined later in this header.
class Any;

namespace details {
// Helper granting unsafe access to the internals of Any (defined further below).
struct AnyUnsafe;
}  // namespace details

/*!
 * \brief Non-owning, type-erased view over a TVMFFIAny value.
 *
 * Copying an AnyView copies the raw TVMFFIAny bits (the copy operations are
 * defaulted) and the destructor releases nothing, so an AnyView never changes
 * reference counts. Use Any when the value must be owned.
 */
class AnyView {
 protected:
  /*! \brief The underlying FFI value; borrowed, never owned. */
  TVMFFIAny data_;
  // Any can see AnyView
  friend class Any;

 public:
  // NOTE: the following functions use snake_case style
  // since they are common functions appearing in FFI.
  /*! \brief Reset the view to None. */
  TVM_FFI_INLINE void reset() noexcept {
    data_.type_index = TypeIndex::kTVMFFINone;
    // invariance: always set the union padding part to 0
    data_.zero_padding = 0;
    data_.v_int64 = 0;
  }
  TVM_FFI_INLINE void swap(AnyView& other) noexcept { std::swap(data_, other.data_); }
  TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; }
  /*! \brief Construct a None view. */
  AnyView() noexcept {
    data_.type_index = TypeIndex::kTVMFFINone;
    data_.zero_padding = 0;
    data_.v_int64 = 0;
  }
  ~AnyView() = default;
  // constructors from any view
  AnyView(const AnyView&) = default;
  AnyView& operator=(const AnyView&) = default;
  /*!
   * \brief Move constructor; copies the payload and leaves \p other as None.
   *
   * Marked noexcept (it only copies POD fields) so standard containers can move
   * elements via std::move_if_noexcept.
   */
  AnyView(AnyView&& other) noexcept : data_(other.data_) {
    other.data_.type_index = TypeIndex::kTVMFFINone;
    other.data_.zero_padding = 0;
    other.data_.v_int64 = 0;
  }
  TVM_FFI_INLINE AnyView& operator=(AnyView&& other) noexcept {
    // copy-and-swap idiom
    AnyView(std::move(other)).swap(*this);  // NOLINT(*)
    return *this;
  }
  /*! \brief Create a view of any value whose TypeTraits enable conversion. */
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  AnyView(const T& other) {  // NOLINT(*)
    TypeTraits<T>::CopyToAnyView(other, &data_);
  }
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE AnyView& operator=(const T& other) {  // NOLINT(*)
    // copy-and-swap idiom
    AnyView(other).swap(*this);  // NOLINT(*)
    return *this;
  }

  /*!
   * \brief Strict type check and copy-out.
   * \return The value when CheckAnyStrict succeeds, std::nullopt otherwise (no fallback conversion).
   */
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE std::optional<T> as() const {
    if (TypeTraits<T>::CheckAnyStrict(&data_)) {
      return TypeTraits<T>::CopyFromAnyViewAfterCheck(&data_);
    } else {
      return std::optional<T>(std::nullopt);
    }
  }
  /*! \brief Object overload of as(): returns the object pointer or nullptr on mismatch. */
  template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
  TVM_FFI_INLINE const T* as() const {
    return this->as<const T*>().value_or(nullptr);
  }

  /*!
   * \brief Convert to T, allowing the fallback conversions of TryCastFromAnyView.
   * \throws TypeError when the conversion fails.
   */
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE T cast() const {
    std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
    if (!opt.has_value()) {
      TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
                               << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
                               << TypeTraits<T>::TypeStr() << "`";
    }
    return *std::move(opt);
  }

  /*! \brief Like cast(), but returns std::nullopt instead of throwing. */
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE std::optional<T> try_cast() const {
    return TypeTraits<T>::TryCastFromAnyView(&data_);
  }

  // comparison with nullptr
  TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept {
    return data_.type_index == TypeIndex::kTVMFFINone;
  }
  TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept {
    return data_.type_index != TypeIndex::kTVMFFINone;
  }
  TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); }
  // The following functions are only used for testing purposes
  TVM_FFI_INLINE TVMFFIAny CopyToTVMFFIAny() const { return data_; }
  TVM_FFI_INLINE static AnyView CopyFromTVMFFIAny(TVMFFIAny data) {
    AnyView view;
    view.data_ = data;
    return view;
  }
};

namespace details {
/*!
 * \brief Turn the borrowed payload of an AnyView (stored in \p data) into an owned one, in place.
 *
 * - Object handles get an extra reference.
 * - Raw C strings and byte-array pointers are copied into String / Bytes objects.
 * - An object rvalue-ref slot is stolen, and the slot is cleared so it cannot be moved twice.
 * - Every other type index is left untouched.
 *
 * \param data The value to convert; overwritten with the owned representation.
 * \param extra_any_bytes Currently unused.
 */
TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data,
                                               [[maybe_unused]] size_t extra_any_bytes = 0) {
  if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) {
    // the view only borrowed the object; the Any needs its own reference
    details::ObjectUnsafe::IncRefObjectHandle(data->v_obj);
    return;
  }
  switch (data->type_index) {
    case TypeIndex::kTVMFFIRawStr: {
      // raw C string -> owned string object
      String owned(data->v_c_str);
      TypeTraits<String>::MoveToAny(std::move(owned), data);
      break;
    }
    case TypeIndex::kTVMFFIByteArrayPtr: {
      // byte array pointer -> owned bytes object
      Bytes owned(*static_cast<TVMFFIByteArray*>(data->v_ptr));
      TypeTraits<Bytes>::MoveToAny(std::move(owned), data);
      break;
    }
    case TypeIndex::kTVMFFIObjectRValueRef: {
      // steal the object referenced by the rvalue-ref slot
      Object** slot = static_cast<Object**>(data->v_ptr);
      TVM_FFI_ICHECK(slot[0] != nullptr) << "RValueRef already moved";
      ObjectRef owned(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(slot[0]));
      // clear the slot so a second conversion is detected instead of double-moving
      slot[0] = nullptr;
      TypeTraits<ObjectRef>::MoveToAny(std::move(owned), data);
      break;
    }
    default:
      // nothing to convert
      break;
  }
}
}  // namespace details

class Any {
 protected:
  TVMFFIAny data_;

 public:
  TVM_FFI_INLINE void reset() {
    if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) {
      details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj);
    }
    data_.type_index = TVMFFITypeIndex::kTVMFFINone;
    data_.zero_padding = 0;
    data_.v_int64 = 0;
  }
  TVM_FFI_INLINE void swap(Any& other) noexcept { std::swap(data_, other.data_); }
  TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; }
  Any() {
    data_.type_index = TypeIndex::kTVMFFINone;
    data_.zero_padding = 0;
    data_.v_int64 = 0;
  }
  ~Any() { this->reset(); }
  Any(const Any& other) : data_(other.data_) {
    if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
      details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj);
    }
  }
  Any(Any&& other) : data_(other.data_) {
    other.data_.type_index = TypeIndex::kTVMFFINone;
    other.data_.zero_padding = 0;
    other.data_.v_int64 = 0;
  }
  TVM_FFI_INLINE Any& operator=(const Any& other) {
    // copy-and-swap idiom
    Any(other).swap(*this);  // NOLINT(*)
    return *this;
  }
  TVM_FFI_INLINE Any& operator=(Any&& other) {
    // copy-and-swap idiom
    Any(std::move(other)).swap(*this);  // NOLINT(*)
    return *this;
  }
  Any(const AnyView& other) : data_(other.data_) {  // NOLINT(*)
    details::InplaceConvertAnyViewToAny(&data_);
  }
  TVM_FFI_INLINE Any& operator=(const AnyView& other) {
    // copy-and-swap idiom
    Any(other).swap(*this);  // NOLINT(*)
    return *this;
  }
  operator AnyView() const { return AnyView::CopyFromTVMFFIAny(data_); }
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  Any(T other) {  // NOLINT(*)
    TypeTraits<T>::MoveToAny(std::move(other), &data_);
  }
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE Any& operator=(T other) {  // NOLINT(*)
    // copy-and-swap idiom
    Any(std::move(other)).swap(*this);  // NOLINT(*)
    return *this;
  }

  template <typename T,
            typename = std::enable_if_t<TypeTraits<T>::storage_enabled || std::is_same_v<T, Any>>>
  TVM_FFI_INLINE std::optional<T> as() && {
    if constexpr (std::is_same_v<T, Any>) {
      return std::move(*this);
    } else {
      if (TypeTraits<T>::CheckAnyStrict(&data_)) {
        return TypeTraits<T>::MoveFromAnyAfterCheck(&data_);
      } else {
        return std::optional<T>(std::nullopt);
      }
    }
  }

  template <typename T,
            typename = std::enable_if_t<TypeTraits<T>::convert_enabled || std::is_same_v<T, Any>>>
  TVM_FFI_INLINE std::optional<T> as() const& {
    if constexpr (std::is_same_v<T, Any>) {
      return *this;
    } else {
      if (TypeTraits<T>::CheckAnyStrict(&data_)) {
        return TypeTraits<T>::CopyFromAnyViewAfterCheck(&data_);
      } else {
        return std::optional<T>(std::nullopt);
      }
    }
  }

  template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
  TVM_FFI_INLINE const T* as() const& {
    return this->as<const T*>().value_or(nullptr);
  }

  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE T cast() const& {
    std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
    if (!opt.has_value()) {
      TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
                               << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
                               << TypeTraits<T>::TypeStr() << "`";
    }
    return *std::move(opt);
  }

  template <typename T, typename = std::enable_if_t<TypeTraits<T>::storage_enabled>>
  TVM_FFI_INLINE T cast() && {
    if (TypeTraits<T>::CheckAnyStrict(&data_)) {
      return TypeTraits<T>::MoveFromAnyAfterCheck(&data_);
    }
    // slow path, try to do fallback convert
    std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
    if (!opt.has_value()) {
      TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
                               << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
                               << TypeTraits<T>::TypeStr() << "`";
    }
    return *std::move(opt);
  }

  template <typename T,
            typename = std::enable_if_t<TypeTraits<T>::convert_enabled || std::is_same_v<T, Any>>>
  TVM_FFI_INLINE std::optional<T> try_cast() const {
    if constexpr (std::is_same_v<T, Any>) {
      return *this;
    } else {
      return TypeTraits<T>::TryCastFromAnyView(&data_);
    }
  }
  TVM_FFI_INLINE bool same_as(const Any& other) const noexcept {
    return data_.type_index == other.data_.type_index &&
           data_.zero_padding == other.data_.zero_padding && data_.v_int64 == other.data_.v_int64;
  }

  TVM_FFI_INLINE bool same_as(const ObjectRef& other) const noexcept {
    if (other.get() != nullptr) {
      return (data_.type_index == other->type_index() &&
              reinterpret_cast<Object*>(data_.v_obj) == other.get());
    } else {
      return data_.type_index == TypeIndex::kTVMFFINone;
    }
  }

  TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept {
    return data_.type_index == TypeIndex::kTVMFFINone;
  }
  TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept {
    return data_.type_index != TypeIndex::kTVMFFINone;
  }

  TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); }

  friend struct details::AnyUnsafe;
  friend struct AnyHash;
  friend struct AnyEqual;
};

// layout assert to ensure we can freely cast between the two types:
// a pointer to one is reinterpreted as a pointer to the other, so both the size
// and the alignment must match TVMFFIAny (size equality alone does not imply alignment).
static_assert(sizeof(AnyView) == sizeof(TVMFFIAny));
static_assert(sizeof(Any) == sizeof(TVMFFIAny));
static_assert(alignof(AnyView) == alignof(TVMFFIAny));
static_assert(alignof(Any) == alignof(TVMFFIAny));

namespace details {

/*!
 * \brief Printable name of a C++ type.
 *
 * The primary template defers to TypeTraitsNoCR<Type>::TypeStr(); Any, AnyView
 * and void are spelled out by the specializations below.
 */
template <typename Type>
struct Type2Str {
  static std::string v() { return TypeTraitsNoCR<Type>::TypeStr(); }
};

template <>
struct Type2Str<Any> {
  static std::string v() { return "Any"; }
};

// const Any& prints the same as Any
template <>
struct Type2Str<const Any&> : public Type2Str<Any> {};

template <>
struct Type2Str<AnyView> {
  static std::string v() { return "AnyView"; }
};

// const AnyView& prints the same as AnyView
template <>
struct Type2Str<const AnyView&> : public Type2Str<AnyView> {};

template <>
struct Type2Str<void> {
  static std::string v() { return "void"; }
};

/*!
 * \brief Unsafe accessors for the raw TVMFFIAny inside an Any.
 *
 * Callers are responsible for any type checks the "AfterCheck" helpers assume.
 */
struct AnyUnsafe : public ObjectUnsafe {
  /*!
   * \brief Detach the raw value from \p any, leaving \p any as None.
   *
   * Because the source is left as None, its destructor releases nothing; any
   * object reference now belongs to the returned TVMFFIAny.
   */
  TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any&& any) {
    TVMFFIAny released;
    released.type_index = TypeIndex::kTVMFFINone;
    released.zero_padding = 0;
    released.v_int64 = 0;
    std::swap(released, any.data_);
    return released;
  }

  /*! \brief Adopt a raw value into an Any, leaving \p data as None. */
  TVM_FFI_INLINE static Any MoveTVMFFIAnyToAny(TVMFFIAny&& data) {
    Any adopted;  // starts as None, so the swap leaves `data` as None
    std::swap(adopted.data_, data);
    return adopted;
  }

  /*! \brief Strict type check of the stored value against T. */
  template <typename T>
  TVM_FFI_INLINE static bool CheckAnyStrict(const Any& ref) {
    const TVMFFIAny* raw = &ref.data_;
    return TypeTraits<T>::CheckAnyStrict(raw);
  }

  /*! \brief Copy out as T; assumes CheckAnyStrict<T> already passed (T = Any copies). */
  template <typename T>
  TVM_FFI_INLINE static T CopyFromAnyViewAfterCheck(const Any& ref) {
    if constexpr (std::is_same_v<T, Any>) {
      return ref;
    } else {
      return TypeTraits<T>::CopyFromAnyViewAfterCheck(&ref.data_);
    }
  }

  /*! \brief Move out as T; assumes CheckAnyStrict<T> already passed (T = Any moves). */
  template <typename T>
  TVM_FFI_INLINE static T MoveFromAnyAfterCheck(Any&& ref) {
    if constexpr (std::is_same_v<T, Any>) {
      return std::move(ref);
    } else {
      return TypeTraits<T>::MoveFromAnyAfterCheck(&ref.data_);
    }
  }

  /*! \brief Reinterpret the payload as an Object pointer (caller ensures it is an object). */
  TVM_FFI_INLINE static Object* ObjectPtrFromAnyAfterCheck(const Any& ref) {
    return reinterpret_cast<Object*>(ref.data_.v_obj);
  }

  /*! \brief Address of the raw TVMFFIAny stored in \p ref. */
  TVM_FFI_INLINE static const TVMFFIAny* TVMFFIAnyPtrFromAny(const Any& ref) { return &ref.data_; }

  /*! \brief Description of why the stored value does not match T. */
  template <typename T>
  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const Any& ref) {
    return TypeTraits<T>::GetMismatchTypeInfo(&ref.data_);
  }
};
}  // namespace details

struct AnyHash {
  uint64_t operator()(const Any& src) const {
    if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) {
      // for small string, we use the same type key hash as normal string
      // so heap allocated string and on stack string will have the same hash
      return details::StableHashCombine(TypeIndex::kTVMFFIStr,
                                        details::StableHashSmallStrBytes(&src.data_));
    } else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) {
      // use byte the same type key as bytes
      return details::StableHashCombine(TypeIndex::kTVMFFIBytes,
                                        details::StableHashSmallStrBytes(&src.data_));
    } else if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
               src.data_.type_index == TypeIndex::kTVMFFIBytes) {
      const details::BytesObjBase* src_str =
          details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
      return details::StableHashCombine(src.data_.type_index,
                                        details::StableHashBytes(src_str->data, src_str->size));
    } else {
      return details::StableHashCombine(src.data_.type_index, src.data_.v_uint64);
    }
  }
};

struct AnyEqual {
  bool operator()(const Any& lhs, const Any& rhs) const {
    // header with type index
    const int64_t* lhs_as_int64 = reinterpret_cast<const int64_t*>(&lhs.data_);
    const int64_t* rhs_as_int64 = reinterpret_cast<const int64_t*>(&rhs.data_);
    static_assert(sizeof(TVMFFIAny) == 16);
    // fast path, check byte equality
    if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] == rhs_as_int64[1]) {
      return true;
    }
    // common false case type index match, in this case we only need to pay attention to string
    // equality
    if (lhs.data_.type_index == rhs.data_.type_index) {
      // specialy handle string hash
      if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
          lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
        const details::BytesObjBase* lhs_str =
            details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
        const details::BytesObjBase* rhs_str =
            details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
        return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
      }
      return false;
    } else {
      // type_index mismatch, if index is not string, return false
      if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr &&
          lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index != kTVMFFIBytes) {
        return false;
      }
      // small string and normal string comparison
      if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index == kTVMFFISmallStr) {
        const details::BytesObjBase* lhs_str =
            details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
        return Bytes::memequal(lhs_str->data, rhs.data_.v_bytes, lhs_str->size,
                               rhs.data_.small_str_len);
      }
      if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index == kTVMFFIStr) {
        const details::BytesObjBase* rhs_str =
            details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
        return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data, lhs.data_.small_str_len,
                               rhs_str->size);
      }
      if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index == kTVMFFISmallBytes) {
        const details::BytesObjBase* lhs_bytes =
            details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
        return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes, lhs_bytes->size,
                               rhs.data_.small_str_len);
      }
      if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index == kTVMFFIBytes) {
        const details::BytesObjBase* rhs_bytes =
            details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
        return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data, lhs.data_.small_str_len,
                               rhs_bytes->size);
      }
      return false;
    }
  }
};
}  // namespace ffi

// Re-export the value containers into namespace tvm, so callers can write
// tvm::Any / tvm::AnyView directly (the names are unambiguous even at tvm:: scope).
using tvm::ffi::AnyView;
using tvm::ffi::Any;

}  // namespace tvm
#endif  // TVM_FFI_ANY_H_