Program Listing for File any.h
↰ Return to documentation for file (tvm/ffi/any.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_ANY_H_
#define TVM_FFI_ANY_H_
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/string.h>
#include <tvm/ffi/type_traits.h>
#include <string>
#include <utility>
namespace tvm {
namespace ffi {
class Any;
namespace details {
// Helper to perform
// unsafe operations related to object
struct AnyUnsafe;
} // namespace details
/*!
 * \brief View over a TVMFFIAny value that performs no reference counting.
 *
 * Copying an AnyView is a member-wise copy of the underlying TVMFFIAny and the
 * destructor is defaulted, so this class never increments or decrements an
 * object reference count. Any is the counterpart that does.
 */
class AnyView {
 protected:
  /*! \brief The underlying raw value. */
  TVMFFIAny data_;
  // Any can see AnyView
  friend class Any;

 public:
  // NOTE: reset/swap/type_index use lower-case, std-like names (unlike the
  // CamelCase helpers further below) since they are common functions
  // appearing in FFI.

  /*! \brief Set the view to None. */
  void reset() noexcept {
    data_.type_index = TypeIndex::kTVMFFINone;
    // invariance: always set the union padding part to 0
    data_.zero_padding = 0;
    data_.v_int64 = 0;
  }
  /*! \brief Exchange the content of this view with other. */
  TVM_FFI_INLINE void swap(AnyView& other) noexcept { std::swap(data_, other.data_); }
  /*! \return The type index of the stored value. */
  TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; }

  /*! \brief Construct a view holding None. */
  AnyView() noexcept {
    data_.type_index = TypeIndex::kTVMFFINone;
    data_.zero_padding = 0;
    data_.v_int64 = 0;
  }
  ~AnyView() = default;
  // constructors from any view
  AnyView(const AnyView&) = default;
  AnyView& operator=(const AnyView&) = default;
  /*!
   * \brief Move construct; other is left holding None.
   * \note noexcept: only plain field stores are performed.
   */
  AnyView(AnyView&& other) noexcept : data_(other.data_) { other.reset(); }
  /*! \brief Move assign; other is left holding None. */
  TVM_FFI_INLINE AnyView& operator=(AnyView&& other) noexcept {
    // copy-and-swap idiom
    AnyView(std::move(other)).swap(*this);  // NOLINT(*)
    return *this;
  }

  /*!
   * \brief Construct a view of other via TypeTraits<T>::CopyToAnyView.
   * \tparam T Any type whose TypeTraits has convert_enabled.
   */
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  AnyView(const T& other) {  // NOLINT(*)
    TypeTraits<T>::CopyToAnyView(other, &data_);
  }
  /*! \brief Assign from a value of a convertible type T. */
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE AnyView& operator=(const T& other) {  // NOLINT(*)
    // copy-and-swap idiom
    AnyView(other).swap(*this);  // NOLINT(*)
    return *this;
  }

  /*!
   * \brief Read the value as T without fallback conversion.
   * \return The value when TypeTraits<T>::CheckAnyStrict accepts the stored
   *         data, std::nullopt otherwise.
   */
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE std::optional<T> as() const {
    if (TypeTraits<T>::CheckAnyStrict(&data_)) {
      return TypeTraits<T>::CopyFromAnyViewAfterCheck(&data_);
    } else {
      return std::optional<T>(std::nullopt);
    }
  }
  /*!
   * \brief Overload of as() for Object subclasses.
   * \return Pointer to the stored object as const T*, or nullptr when the
   *         strict check for const T* fails.
   */
  template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
  TVM_FFI_INLINE const T* as() const {
    return this->as<const T*>().value_or(nullptr);
  }

  /*!
   * \brief Convert the value to T via TypeTraits<T>::TryCastFromAnyView.
   * \return The converted value.
   * \throws TypeError (through TVM_FFI_THROW) when the conversion yields no value.
   */
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE T cast() const {
    std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
    if (!opt.has_value()) {
      TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
                               << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
                               << TypeTraits<T>::TypeStr() << "`";
    }
    return *std::move(opt);
  }
  /*!
   * \brief Non-throwing variant of cast().
   * \return The result of TypeTraits<T>::TryCastFromAnyView.
   */
  template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
  TVM_FFI_INLINE std::optional<T> try_cast() const {
    return TypeTraits<T>::TryCastFromAnyView(&data_);
  }

  // comparison with nullptr
  /*! \return true when the view holds None. */
  TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept {
    return data_.type_index == TypeIndex::kTVMFFINone;
  }
  /*! \return true when the view holds anything other than None. */
  TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept {
    return data_.type_index != TypeIndex::kTVMFFINone;
  }
  /*! \return The type key looked up from the stored type index. */
  TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); }

  // The following functions are only used for testing purposes
  /*! \return A copy of the underlying TVMFFIAny. */
  TVM_FFI_INLINE TVMFFIAny CopyToTVMFFIAny() const { return data_; }
  /*! \return A view whose underlying value is a copy of data. */
  TVM_FFI_INLINE static AnyView CopyFromTVMFFIAny(TVMFFIAny data) {
    AnyView view;
    view.data_ = data;
    return view;
  }
};
namespace details {
/*!
 * \brief Turn the view-style content of data into owned content, in place.
 *
 * - Object handles (type index at or above kTVMFFIStaticObjectBegin) get one
 *   extra reference.
 * - A raw C string is replaced by a String, a byte array pointer by a Bytes.
 * - An object rvalue reference is taken over and its source slot is nulled.
 * - Every other type index is left untouched.
 *
 * \param data The value to convert in place.
 * \param extra_any_bytes Unused by this implementation.
 */
TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data,
                                               [[maybe_unused]] size_t extra_any_bytes = 0) {
  // Snapshot the index: the conversions below overwrite *data.
  const auto kind = data->type_index;
  if (kind >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) {
    details::ObjectUnsafe::IncRefObjectHandle(data->v_obj);
    return;
  }
  if (kind < TypeIndex::kTVMFFIRawStr) {
    // nothing to convert below the raw-string range
    return;
  }
  switch (kind) {
    case TypeIndex::kTVMFFIRawStr: {
      // convert raw string to owned string object
      String owned_str(data->v_c_str);
      TypeTraits<String>::MoveToAny(std::move(owned_str), data);
      break;
    }
    case TypeIndex::kTVMFFIByteArrayPtr: {
      // convert byte array to owned bytes object
      Bytes owned_bytes(*static_cast<TVMFFIByteArray*>(data->v_ptr));
      TypeTraits<Bytes>::MoveToAny(std::move(owned_bytes), data);
      break;
    }
    case TypeIndex::kTVMFFIObjectRValueRef: {
      // convert rvalue ref to owned object
      Object** slot = static_cast<Object**>(data->v_ptr);
      TVM_FFI_ICHECK(slot[0] != nullptr) << "RValueRef already moved";
      ObjectRef taken(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(slot[0]));
      // set the rvalue ref to nullptr to avoid double move
      slot[0] = nullptr;
      TypeTraits<ObjectRef>::MoveToAny(std::move(taken), data);
      break;
    }
    default:
      break;
  }
}
} // namespace details
class Any {
protected:
TVMFFIAny data_;
public:
TVM_FFI_INLINE void reset() {
if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) {
details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj);
}
data_.type_index = TVMFFITypeIndex::kTVMFFINone;
data_.zero_padding = 0;
data_.v_int64 = 0;
}
TVM_FFI_INLINE void swap(Any& other) noexcept { std::swap(data_, other.data_); }
TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; }
Any() {
data_.type_index = TypeIndex::kTVMFFINone;
data_.zero_padding = 0;
data_.v_int64 = 0;
}
~Any() { this->reset(); }
Any(const Any& other) : data_(other.data_) {
if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj);
}
}
Any(Any&& other) : data_(other.data_) {
other.data_.type_index = TypeIndex::kTVMFFINone;
other.data_.zero_padding = 0;
other.data_.v_int64 = 0;
}
TVM_FFI_INLINE Any& operator=(const Any& other) {
// copy-and-swap idiom
Any(other).swap(*this); // NOLINT(*)
return *this;
}
TVM_FFI_INLINE Any& operator=(Any&& other) {
// copy-and-swap idiom
Any(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
Any(const AnyView& other) : data_(other.data_) { // NOLINT(*)
details::InplaceConvertAnyViewToAny(&data_);
}
TVM_FFI_INLINE Any& operator=(const AnyView& other) {
// copy-and-swap idiom
Any(other).swap(*this); // NOLINT(*)
return *this;
}
operator AnyView() const { return AnyView::CopyFromTVMFFIAny(data_); }
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
Any(T other) { // NOLINT(*)
TypeTraits<T>::MoveToAny(std::move(other), &data_);
}
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
TVM_FFI_INLINE Any& operator=(T other) { // NOLINT(*)
// copy-and-swap idiom
Any(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
template <typename T,
typename = std::enable_if_t<TypeTraits<T>::storage_enabled || std::is_same_v<T, Any>>>
TVM_FFI_INLINE std::optional<T> as() && {
if constexpr (std::is_same_v<T, Any>) {
return std::move(*this);
} else {
if (TypeTraits<T>::CheckAnyStrict(&data_)) {
return TypeTraits<T>::MoveFromAnyAfterCheck(&data_);
} else {
return std::optional<T>(std::nullopt);
}
}
}
template <typename T,
typename = std::enable_if_t<TypeTraits<T>::convert_enabled || std::is_same_v<T, Any>>>
TVM_FFI_INLINE std::optional<T> as() const& {
if constexpr (std::is_same_v<T, Any>) {
return *this;
} else {
if (TypeTraits<T>::CheckAnyStrict(&data_)) {
return TypeTraits<T>::CopyFromAnyViewAfterCheck(&data_);
} else {
return std::optional<T>(std::nullopt);
}
}
}
template <typename T, typename = std::enable_if_t<std::is_base_of_v<Object, T>>>
TVM_FFI_INLINE const T* as() const& {
return this->as<const T*>().value_or(nullptr);
}
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
TVM_FFI_INLINE T cast() const& {
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
if (!opt.has_value()) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
<< TypeTraits<T>::TypeStr() << "`";
}
return *std::move(opt);
}
template <typename T, typename = std::enable_if_t<TypeTraits<T>::storage_enabled>>
TVM_FFI_INLINE T cast() && {
if (TypeTraits<T>::CheckAnyStrict(&data_)) {
return TypeTraits<T>::MoveFromAnyAfterCheck(&data_);
}
// slow path, try to do fallback convert
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
if (!opt.has_value()) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
<< TypeTraits<T>::TypeStr() << "`";
}
return *std::move(opt);
}
template <typename T,
typename = std::enable_if_t<TypeTraits<T>::convert_enabled || std::is_same_v<T, Any>>>
TVM_FFI_INLINE std::optional<T> try_cast() const {
if constexpr (std::is_same_v<T, Any>) {
return *this;
} else {
return TypeTraits<T>::TryCastFromAnyView(&data_);
}
}
TVM_FFI_INLINE bool same_as(const Any& other) const noexcept {
return data_.type_index == other.data_.type_index &&
data_.zero_padding == other.data_.zero_padding && data_.v_int64 == other.data_.v_int64;
}
TVM_FFI_INLINE bool same_as(const ObjectRef& other) const noexcept {
if (other.get() != nullptr) {
return (data_.type_index == other->type_index() &&
reinterpret_cast<Object*>(data_.v_obj) == other.get());
} else {
return data_.type_index == TypeIndex::kTVMFFINone;
}
}
TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept {
return data_.type_index == TypeIndex::kTVMFFINone;
}
TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept {
return data_.type_index != TypeIndex::kTVMFFINone;
}
TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); }
friend struct details::AnyUnsafe;
friend struct AnyHash;
friend struct AnyEqual;
};
// layout assert to ensure we can freely cast between the two types
static_assert(sizeof(AnyView) == sizeof(TVMFFIAny));
static_assert(sizeof(Any) == sizeof(TVMFFIAny));
namespace details {
/*!
 * \brief Map a C++ type to a printable name.
 *
 * The primary template forwards to TypeTraitsNoCR<T>::TypeStr(); the
 * specializations below give fixed names to Any, AnyView (by value and by
 * const reference) and void.
 */
template <typename T>
struct Type2Str {
  static std::string v() { return TypeTraitsNoCR<T>::TypeStr(); }
};

/*! \brief Name for Any passed by value. */
template <>
struct Type2Str<Any> {
  static std::string v() { return std::string("Any"); }
};

/*! \brief Name for Any passed by const reference; same text as by value. */
template <>
struct Type2Str<const Any&> {
  static std::string v() { return std::string("Any"); }
};

/*! \brief Name for AnyView passed by value. */
template <>
struct Type2Str<AnyView> {
  static std::string v() { return std::string("AnyView"); }
};

/*! \brief Name for AnyView passed by const reference; same text as by value. */
template <>
struct Type2Str<const AnyView&> {
  static std::string v() { return std::string("AnyView"); }
};

/*! \brief Name for void. */
template <>
struct Type2Str<void> {
  static std::string v() { return std::string("void"); }
};
// Extra unsafe method to help any manipulation
/*!
 * \brief Extra unsafe helpers for manipulating Any.
 *
 * Declared a friend of Any, so these functions read and write Any::data_
 * directly. None of them adjusts a reference count themselves.
 */
struct AnyUnsafe : public ObjectUnsafe {
  // FFI related operations

  /*!
   * \brief Hand the raw content of any to the caller and leave any as None.
   * \return The TVMFFIAny previously stored in any.
   */
  TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any&& any) {
    TVMFFIAny& source = any.data_;
    TVMFFIAny released = source;
    source.type_index = TypeIndex::kTVMFFINone;
    source.zero_padding = 0;
    source.v_int64 = 0;
    return released;
  }

  /*!
   * \brief Wrap the raw content of data in an Any and clear data to None.
   * \return An Any whose underlying value is the former content of data.
   */
  TVM_FFI_INLINE static Any MoveTVMFFIAnyToAny(TVMFFIAny&& data) {
    Any wrapped;
    wrapped.data_ = data;
    data.type_index = TypeIndex::kTVMFFINone;
    data.zero_padding = 0;
    data.v_int64 = 0;
    return wrapped;
  }

  /*! \return The result of TypeTraits<T>::CheckAnyStrict on the data of ref. */
  template <typename T>
  TVM_FFI_INLINE static bool CheckAnyStrict(const Any& ref) {
    return TypeTraits<T>::CheckAnyStrict(&(ref.data_));
  }

  /*!
   * \brief Copy the value of ref out as T; the caller must have checked the type.
   * \return A copy of ref itself when T is Any, otherwise the result of
   *         TypeTraits<T>::CopyFromAnyViewAfterCheck.
   */
  template <typename T>
  TVM_FFI_INLINE static T CopyFromAnyViewAfterCheck(const Any& ref) {
    if constexpr (std::is_same_v<T, Any>) {
      return ref;
    } else {
      return TypeTraits<T>::CopyFromAnyViewAfterCheck(&(ref.data_));
    }
  }

  /*!
   * \brief Move the value of ref out as T; the caller must have checked the type.
   * \return ref itself (moved) when T is Any, otherwise the result of
   *         TypeTraits<T>::MoveFromAnyAfterCheck.
   */
  template <typename T>
  TVM_FFI_INLINE static T MoveFromAnyAfterCheck(Any&& ref) {
    if constexpr (std::is_same_v<T, Any>) {
      return std::move(ref);
    } else {
      return TypeTraits<T>::MoveFromAnyAfterCheck(&(ref.data_));
    }
  }

  /*! \return The v_obj field of ref reinterpreted as Object*; no type check is done. */
  TVM_FFI_INLINE static Object* ObjectPtrFromAnyAfterCheck(const Any& ref) {
    return reinterpret_cast<Object*>(ref.data_.v_obj);
  }

  /*! \return Pointer to the TVMFFIAny stored inside ref. */
  TVM_FFI_INLINE static const TVMFFIAny* TVMFFIAnyPtrFromAny(const Any& ref) {
    return &(ref.data_);
  }

  /*! \return The result of TypeTraits<T>::GetMismatchTypeInfo on the data of ref. */
  template <typename T>
  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const Any& ref) {
    return TypeTraits<T>::GetMismatchTypeInfo(&(ref.data_));
  }
};
} // namespace details
struct AnyHash {
uint64_t operator()(const Any& src) const {
if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) {
// for small string, we use the same type key hash as normal string
// so heap allocated string and on stack string will have the same hash
return details::StableHashCombine(TypeIndex::kTVMFFIStr,
details::StableHashSmallStrBytes(&src.data_));
} else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) {
// use byte the same type key as bytes
return details::StableHashCombine(TypeIndex::kTVMFFIBytes,
details::StableHashSmallStrBytes(&src.data_));
} else if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
const details::BytesObjBase* src_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
return details::StableHashCombine(src.data_.type_index,
details::StableHashBytes(src_str->data, src_str->size));
} else {
return details::StableHashCombine(src.data_.type_index, src.data_.v_uint64);
}
}
};
struct AnyEqual {
bool operator()(const Any& lhs, const Any& rhs) const {
// header with type index
const int64_t* lhs_as_int64 = reinterpret_cast<const int64_t*>(&lhs.data_);
const int64_t* rhs_as_int64 = reinterpret_cast<const int64_t*>(&rhs.data_);
static_assert(sizeof(TVMFFIAny) == 16);
// fast path, check byte equality
if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] == rhs_as_int64[1]) {
return true;
}
// common false case type index match, in this case we only need to pay attention to string
// equality
if (lhs.data_.type_index == rhs.data_.type_index) {
// specialy handle string hash
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
const details::BytesObjBase* lhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
const details::BytesObjBase* rhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
}
return false;
} else {
// type_index mismatch, if index is not string, return false
if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr &&
lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index != kTVMFFIBytes) {
return false;
}
// small string and normal string comparison
if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index == kTVMFFISmallStr) {
const details::BytesObjBase* lhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
return Bytes::memequal(lhs_str->data, rhs.data_.v_bytes, lhs_str->size,
rhs.data_.small_str_len);
}
if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index == kTVMFFIStr) {
const details::BytesObjBase* rhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data, lhs.data_.small_str_len,
rhs_str->size);
}
if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index == kTVMFFISmallBytes) {
const details::BytesObjBase* lhs_bytes =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes, lhs_bytes->size,
rhs.data_.small_str_len);
}
if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index == kTVMFFIBytes) {
const details::BytesObjBase* rhs_bytes =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data, lhs.data_.small_str_len,
rhs_bytes->size);
}
return false;
}
}
};
} // namespace ffi
// Expose to the tvm namespace for usability
// Rationale: no ambiguity even in root
using tvm::ffi::Any;
using tvm::ffi::AnyView;
} // namespace tvm
#endif // TVM_FFI_ANY_H_