Program Listing for File object.h
↰ Return to documentation for file (tvm/ffi/object.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_OBJECT_H_
#define TVM_FFI_OBJECT_H_
#include <tvm/ffi/base_details.h>
#include <tvm/ffi/c_api.h>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
namespace tvm {
namespace ffi {
// Short C++ names for the C ABI type-index constants and the type-info record
// returned by TVMFFIGetTypeInfo (both declared in <tvm/ffi/c_api.h>).
using TypeIndex = TVMFFITypeIndex;
using TypeInfo = TVMFFITypeInfo;
// Tag type that selects "unsafe" constructors. ObjectRef(UnsafeInit) leaves the
// reference null so that privileged helpers (details::ObjectUnsafe,
// ObjectRef::as) can assign the underlying pointer directly afterwards.
struct UnsafeInit {};
// Compile-time type keys of the builtin FFI types.
// Object::_type_key defaults to kTVMFFIObject, and ObjectRef::GetTypeKey()
// reports kTVMFFINone for a null reference.
struct StaticTypeKey {
// Keys without a namespace prefix: Any/None, POD values and raw pointer kinds.
static constexpr const char* kTVMFFIAny = "Any";
static constexpr const char* kTVMFFINone = "None";
static constexpr const char* kTVMFFIBool = "bool";
static constexpr const char* kTVMFFIInt = "int";
static constexpr const char* kTVMFFIFloat = "float";
static constexpr const char* kTVMFFIOpaquePtr = "void*";
static constexpr const char* kTVMFFIDataType = "DataType";
static constexpr const char* kTVMFFIDevice = "Device";
static constexpr const char* kTVMFFIDLTensorPtr = "DLTensor*";
static constexpr const char* kTVMFFIRawStr = "const char*";
static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*";
static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef";
// Keys in the "ffi." namespace.
static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr";
static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes";
static constexpr const char* kTVMFFIBytes = "ffi.Bytes";
static constexpr const char* kTVMFFIStr = "ffi.String";
static constexpr const char* kTVMFFIShape = "ffi.Shape";
static constexpr const char* kTVMFFITensor = "ffi.Tensor";
static constexpr const char* kTVMFFIObject = "ffi.Object";
static constexpr const char* kTVMFFIFunction = "ffi.Function";
static constexpr const char* kTVMFFIArray = "ffi.Array";
static constexpr const char* kTVMFFIMap = "ffi.Map";
static constexpr const char* kTVMFFIModule = "ffi.Module";
static constexpr const char* kTVMFFIOpaquePyObject = "ffi.OpaquePyObject";
};
inline std::string TypeIndexToTypeKey(int32_t type_index) {
const TypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
return std::string(type_info->type_key.data, type_info->type_key.size);
}
namespace details {
// Helper granting privileged (unsafe) access to object internals; defined later
// in this file.
struct ObjectUnsafe;
// The 64-bit combined reference counter packs two 32-bit counters:
//   bits [0, 32)  : strong reference count
//   bits [32, 64) : weak reference count
constexpr uint64_t kCombinedRefCountStrongOne = uint64_t{1};
constexpr uint64_t kCombinedRefCountWeakOne = kCombinedRefCountStrongOne << 32;
// Exactly one strong and one weak unit (the two fields occupy disjoint bits).
constexpr uint64_t kCombinedRefCountBothOne = kCombinedRefCountWeakOne + kCombinedRefCountStrongOne;
// Low-32-bit mask that extracts the strong count from the combined counter.
constexpr uint64_t kCombinedRefCountMaskUInt32 = kCombinedRefCountWeakOne - 1;
// Returns whether an object whose runtime type index is `object_type_index` is
// a TargetType (or a subclass of it). Defined later in this file.
template <typename TargetType>
TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index);
}  // namespace details
// Base class of every FFI-managed object.
//
// Object itself declares a single data member: the C ABI header (TVMFFIObject).
// The header holds the runtime type index, a deleter callback, and a 64-bit
// combined reference counter: the strong count in the low 32 bits and the weak
// count in the high 32 bits (see details::kCombinedRefCount*). While any strong
// reference exists, the weak field carries one extra unit owned collectively by
// the strong references; it is released right after the strong deleter runs.
// Reference counting is reachable only through ObjectPtr, WeakObjectPtr and
// details::ObjectUnsafe (all friends of this class).
class Object {
 protected:
  // C ABI header shared with non-C++ code.
  TVMFFIObject header_;

 public:
  // Starts with a zero reference count and no deleter. Presumably the allocator
  // that creates the object fills in the type index, count and deleter (not
  // visible in this file).
  Object() {
    header_.combined_ref_count = 0;
    header_.deleter = nullptr;
  }
  // Returns true if this object is a TargetType or a subclass of it.
  template <typename TargetType>
  bool IsInstance() const {
    return details::IsObjectInstance<TargetType>(header_.type_index);
  }
  // Runtime type index stored in the header.
  int32_t type_index() const { return header_.type_index; }
  // Registered type key of this object's runtime type.
  std::string GetTypeKey() const {
    // the function checks that the info exists
    const TypeInfo* type_info = TVMFFIGetTypeInfo(header_.type_index);
    return std::string(type_info->type_key.data, type_info->type_key.size);
  }
  // Hash of this object's type key, as recorded in the type registry.
  uint64_t GetTypeKeyHash() const {
    // the function checks that the info exists
    const TypeInfo* type_info = TVMFFIGetTypeInfo(header_.type_index);
    return type_info->type_key_hash;
  }
  // Registered type key of an arbitrary runtime type index.
  static std::string TypeIndex2Key(int32_t tindex) {
    const TypeInfo* type_info = TVMFFIGetTypeInfo(tindex);
    return std::string(type_info->type_key.data, type_info->type_key.size);
  }
  // True when exactly one strong reference is held.
  bool unique() const { return use_count() == 1; }
  // Current strong reference count. This is a relaxed snapshot and may be stale
  // when other threads are concurrently changing the count.
  uint64_t use_count() const {
    // only need relaxed load of counters
#ifdef _MSC_VER
    return ((reinterpret_cast<const volatile uint64_t*>(
                &header_.combined_ref_count))[0]  // NOLINT(*)
            ) &
           kCombinedRefCountMaskUInt32;
#else
    return __atomic_load_n(&(header_.combined_ref_count), __ATOMIC_RELAXED) &
           kCombinedRefCountMaskUInt32;
#endif
  }
  //----------------------------------------------------------------------------
  // The following fields are configuration flags for subclasses of object
  //----------------------------------------------------------------------------
  static constexpr const char* _type_key = StaticTypeKey::kTVMFFIObject;
  static constexpr bool _type_final = false;
  static constexpr bool _type_mutable = false;
  static constexpr uint32_t _type_child_slots = 0;
  static constexpr bool _type_child_slots_can_overflow = true;
  static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject;
  static constexpr int32_t _type_depth = 0;
  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUnsupported;
  // Subclasses get their own versions of the following functions from the macros
  // TVM_FFI_DECLARE_OBJECT_INFO and TVM_FFI_DECLARE_OBJECT_INFO_FINAL.
  static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; }
  static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; }

 private:
  // exposing detailed constants to here
  static constexpr uint64_t kCombinedRefCountMaskUInt32 = details::kCombinedRefCountMaskUInt32;
  static constexpr uint64_t kCombinedRefCountStrongOne = details::kCombinedRefCountStrongOne;
  static constexpr uint64_t kCombinedRefCountWeakOne = details::kCombinedRefCountWeakOne;
  static constexpr uint64_t kCombinedRefCountBothOne = details::kCombinedRefCountBothOne;
  // Adds one strong reference (relaxed: acquiring a reference needs no ordering).
  void IncRef() {
#ifdef _MSC_VER
    _InterlockedIncrement64(
        reinterpret_cast<volatile __int64*>(&header_.combined_ref_count));  // NOLINT(*)
#else
    __atomic_fetch_add(&(header_.combined_ref_count), 1, __ATOMIC_RELAXED);
#endif
  }
  // Tries to turn a weak reference into a new strong one. Succeeds (incrementing
  // the strong count and returning true) only while the strong count is non-zero.
  bool TryPromoteWeakPtr() {
#ifdef _MSC_VER
    uint64_t old_count =
        (reinterpret_cast<const volatile __int64*>(&header_.combined_ref_count))[0];  // NOLINT(*)
    while ((old_count & kCombinedRefCountMaskUInt32) != 0) {
      uint64_t new_count = old_count + kCombinedRefCountStrongOne;
      uint64_t old_count_loaded = _InterlockedCompareExchange64(
          reinterpret_cast<volatile __int64*>(&header_.combined_ref_count), new_count, old_count);
      if (old_count == old_count_loaded) {
        return true;
      }
      old_count = old_count_loaded;
    }
    return false;
#else
    uint64_t old_count = __atomic_load_n(&(header_.combined_ref_count), __ATOMIC_RELAXED);
    while ((old_count & kCombinedRefCountMaskUInt32) != 0) {
      // Must use CAS so that we are the only one increasing the reference count.
      // This avoids races when two threads try to promote weak to strong at the
      // same time, or when strong deletion happens between the load and the CAS.
      uint64_t new_count = old_count + kCombinedRefCountStrongOne;
      if (__atomic_compare_exchange_n(&(header_.combined_ref_count), &old_count, new_count, true,
                                      __ATOMIC_ACQ_REL, __ATOMIC_RELAXED)) {
        return true;
      }
    }
    return false;
#endif
  }
  // Adds one weak reference.
  void IncWeakRef() {
#ifdef _MSC_VER
    _InlineInterlockedAdd64(
        reinterpret_cast<volatile __int64*>(&header_.combined_ref_count),  // NOLINT(*)
        kCombinedRefCountWeakOne);
#else
    __atomic_fetch_add(&(header_.combined_ref_count), kCombinedRefCountWeakOne, __ATOMIC_RELAXED);
#endif
  }
  // Releases one strong reference.
  // - If this was the last strong AND the last weak unit, the deleter is called
  //   once with the Both flag (object and memory block).
  // - If this was the last strong reference but weak references remain, the
  //   deleter is called with the Strong flag (destructor). The weak unit held by
  //   the strong references is then released, and if that was the last weak
  //   reference the deleter is called again with the Weak flag.
  void DecRef() {
#ifdef _MSC_VER
    // use simpler impl in windows to ensure correctness
    uint64_t count_before_sub =
        _InterlockedDecrement64(  //
            reinterpret_cast<volatile __int64*>(&header_.combined_ref_count)  // NOLINT(*)
            ) +
        1;
    if (count_before_sub == kCombinedRefCountBothOne) {  // NOLINT(*)
      // fast path: both reference counts will go to zero
      if (header_.deleter != nullptr) {
        // full barrier is implicit in InterlockedDecrement
        header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth);
      }
    } else if ((count_before_sub & kCombinedRefCountMaskUInt32) == kCombinedRefCountStrongOne) {
      // strong reference count becomes zero, we need to first do strong deletion
      // then decrease weak reference count
      // full barrier is implicit in InterlockedAdd
      if (header_.deleter != nullptr) {
        header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong);
      }
      // Decrease the weak reference count. Negate as a signed value: unary minus
      // on the unsigned constant triggers MSVC C4146 (an error under /sdl).
      if (_InlineInterlockedAdd64(  //
              reinterpret_cast<volatile __int64*>(&header_.combined_ref_count),
              -static_cast<__int64>(kCombinedRefCountWeakOne)) == 0) {  // NOLINT(*)
        if (header_.deleter != nullptr) {
          // full barrier is implicit in InterlockedAdd
          header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
        }
      }
    }
#else
    // First do a release decrement; acquire is only needed before calling the deleter.
    // Optimization: a single atomic identifies the common case where both
    // reference counts reach zero.
    uint64_t count_before_sub = __atomic_fetch_sub(&(header_.combined_ref_count),
                                                   kCombinedRefCountStrongOne, __ATOMIC_RELEASE);
    if (count_before_sub == kCombinedRefCountBothOne) {
      // common case, we need to delete both the object and the memory block
      // only acquire when we need to call deleter
      __atomic_thread_fence(__ATOMIC_ACQUIRE);
      if (header_.deleter != nullptr) {
        // call deleter once
        header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth);
      }
    } else if ((count_before_sub & kCombinedRefCountMaskUInt32) == kCombinedRefCountStrongOne) {
      // The strong count has just reached zero.
      // Slower path: there is still a weak reference left
      __atomic_thread_fence(__ATOMIC_ACQUIRE);
      // call destructor first, then decrease weak reference count
      if (header_.deleter != nullptr) {
        header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong);
      }
      // now decrease weak reference count
      if (__atomic_fetch_sub(&(header_.combined_ref_count), kCombinedRefCountWeakOne,
                             __ATOMIC_RELEASE) == kCombinedRefCountWeakOne) {
        __atomic_thread_fence(__ATOMIC_ACQUIRE);
        if (header_.deleter != nullptr) {
          header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
        }
      }
    }
#endif
  }
  // Releases one weak reference. If it was the last one, the deleter is called
  // with the Weak flag.
  void DecWeakRef() {
#ifdef _MSC_VER
    // Negate as a signed value to avoid MSVC C4146 (unary minus on unsigned).
    if (_InlineInterlockedAdd64(  //
            reinterpret_cast<volatile __int64*>(&header_.combined_ref_count),  // NOLINT(*)
            -static_cast<__int64>(kCombinedRefCountWeakOne)) == 0) {
      if (header_.deleter != nullptr) {
        header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
      }
    }
#else
    // now decrease weak reference count
    if (__atomic_fetch_sub(&(header_.combined_ref_count), kCombinedRefCountWeakOne,
                           __ATOMIC_RELEASE) == kCombinedRefCountWeakOne) {
      __atomic_thread_fence(__ATOMIC_ACQUIRE);
      if (header_.deleter != nullptr) {
        header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
      }
    }
#endif
  }
  // friend classes
  template <typename>
  friend class ObjectPtr;
  template <typename>
  friend class WeakObjectPtr;
  friend struct tvm::ffi::details::ObjectUnsafe;
};
// Intrusive strong-reference smart pointer to an Object subclass T.
//
// Copying adds a strong reference to the pointee; destruction and reset() release
// one. Moves transfer ownership without touching the reference count. Moves and
// swap are noexcept so that standard containers (e.g. std::vector reallocation,
// which uses std::move_if_noexcept) move elements instead of copying them, which
// would cost an atomic increment/decrement per element.
template <typename T>
class ObjectPtr {
 public:
  ObjectPtr() noexcept {}
  ObjectPtr(std::nullptr_t) noexcept {}  // NOLINT(*)
  // Shares ownership with `other` (adds a strong reference).
  ObjectPtr(const ObjectPtr<T>& other)  // NOLINT(*)
      : ObjectPtr(other.data_) {}
  // Upcasting copy from a pointer to a subclass U of T.
  template <typename U>
  ObjectPtr(const ObjectPtr<U>& other)  // NOLINT(*)
      : ObjectPtr(other.data_) {
    static_assert(std::is_base_of<T, U>::value,
                  "can only assign of child class ObjectPtr to parent");
  }
  // Takes over `other`'s reference and leaves `other` null.
  ObjectPtr(ObjectPtr<T>&& other) noexcept  // NOLINT(*)
      : data_(other.data_) {
    other.data_ = nullptr;
  }
  // Upcasting move from a pointer to a subclass Y of T.
  template <typename Y>
  ObjectPtr(ObjectPtr<Y>&& other) noexcept  // NOLINT(*)
      : data_(other.data_) {
    static_assert(std::is_base_of<T, Y>::value,
                  "can only assign of child class ObjectPtr to parent");
    other.data_ = nullptr;
  }
  ~ObjectPtr() { this->reset(); }
  void swap(ObjectPtr<T>& other) noexcept {  // NOLINT(*)
    std::swap(data_, other.data_);
  }
  T* get() const { return static_cast<T*>(data_); }
  T* operator->() const { return get(); }
  T& operator*() const {  // NOLINT(*)
    return *get();
  }
  ObjectPtr<T>& operator=(const ObjectPtr<T>& other) {  // NOLINT(*)
    // Copy-and-swap: the temporary acquires the new reference, and its
    // destructor releases the reference previously held by *this.
    ObjectPtr(other).swap(*this);  // NOLINT(*)
    return *this;
  }
  ObjectPtr<T>& operator=(ObjectPtr<T>&& other) noexcept {  // NOLINT(*)
    // Move-and-swap: the temporary's destructor releases the old reference.
    ObjectPtr(std::move(other)).swap(*this);  // NOLINT(*)
    return *this;
  }
  explicit operator bool() const { return get() != nullptr; }
  // Releases the held strong reference (if any) and becomes null.
  void reset() {
    if (data_ != nullptr) {
      data_->DecRef();
      data_ = nullptr;
    }
  }
  // Strong reference count of the pointee, or 0 when null.
  int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
  bool unique() const { return data_ != nullptr && data_->use_count() == 1; }
  bool operator==(const ObjectPtr<T>& other) const { return data_ == other.data_; }
  bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; }
  bool operator==(std::nullptr_t) const { return data_ == nullptr; }
  bool operator!=(std::nullptr_t) const { return data_ != nullptr; }

 private:
  Object* data_{nullptr};
  // Wraps `data`, adding a strong reference when it is non-null.
  explicit ObjectPtr(Object* data) : data_(data) {
    if (data_ != nullptr) {
      data_->IncRef();
    }
  }
  // friend classes
  friend class Object;
  friend class ObjectRef;
  friend struct ObjectPtrHash;
  template <typename>
  friend class ObjectPtr;
  template <typename>
  friend class WeakObjectPtr;
  friend struct tvm::ffi::details::ObjectUnsafe;
};
// Weak (non-owning) reference to an Object subclass T.
//
// Holds one unit of the object's weak count, which keeps the header alive but
// not the object itself. lock() yields a strong ObjectPtr only while the strong
// count is non-zero. Moves and swap are noexcept so that standard containers
// move rather than copy elements.
template <typename T>
class WeakObjectPtr {
 public:
  WeakObjectPtr() noexcept {}
  WeakObjectPtr(std::nullptr_t) noexcept {}  // NOLINT(*)
  WeakObjectPtr(const WeakObjectPtr<T>& other)  // NOLINT(*)
      : WeakObjectPtr(other.data_) {}
  WeakObjectPtr(const ObjectPtr<T>& other)  // NOLINT(*)
      : WeakObjectPtr(other.get()) {}
  template <typename U>
  WeakObjectPtr(const WeakObjectPtr<U>& other)  // NOLINT(*)
      : WeakObjectPtr(other.data_) {
    static_assert(std::is_base_of<T, U>::value,
                  "can only assign of child class ObjectPtr to parent");
  }
  template <typename U>
  WeakObjectPtr(const ObjectPtr<U>& other)  // NOLINT(*)
      : WeakObjectPtr(other.data_) {
    static_assert(std::is_base_of<T, U>::value,
                  "can only assign of child class ObjectPtr to parent");
  }
  // Takes over `other`'s weak reference and leaves `other` null.
  WeakObjectPtr(WeakObjectPtr<T>&& other) noexcept  // NOLINT(*)
      : data_(other.data_) {
    other.data_ = nullptr;
  }
  template <typename Y>
  WeakObjectPtr(WeakObjectPtr<Y>&& other) noexcept  // NOLINT(*)
      : data_(other.data_) {
    static_assert(std::is_base_of<T, Y>::value,
                  "can only assign of child class ObjectPtr to parent");
    other.data_ = nullptr;
  }
  ~WeakObjectPtr() { this->reset(); }
  void swap(WeakObjectPtr<T>& other) noexcept {  // NOLINT(*)
    std::swap(data_, other.data_);
  }
  WeakObjectPtr<T>& operator=(const WeakObjectPtr<T>& other) {  // NOLINT(*)
    // Copy-and-swap: the temporary's destructor releases the old weak reference.
    WeakObjectPtr(other).swap(*this);  // NOLINT(*)
    return *this;
  }
  WeakObjectPtr<T>& operator=(WeakObjectPtr<T>&& other) noexcept {  // NOLINT(*)
    // Move-and-swap: the temporary's destructor releases the old weak reference.
    WeakObjectPtr(std::move(other)).swap(*this);  // NOLINT(*)
    return *this;
  }
  // Returns a strong pointer if the object is still alive, otherwise null.
  ObjectPtr<T> lock() const {
    if (data_ != nullptr && data_->TryPromoteWeakPtr()) {
      ObjectPtr<T> ret;
      // TryPromoteWeakPtr already added the strong reference; adopt it as-is.
      ret.data_ = data_;
      return ret;
    }
    return nullptr;
  }
  // Releases the held weak reference (if any) and becomes null.
  void reset() {
    if (data_ != nullptr) {
      data_->DecWeakRef();
      data_ = nullptr;
    }
  }
  // Strong reference count of the referenced object, or 0 when null.
  int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
  // True when null or when no strong references remain.
  bool expired() const { return data_ == nullptr || data_->use_count() == 0; }

 private:
  Object* data_{nullptr};
  // Wraps `data`, adding a weak reference when it is non-null.
  explicit WeakObjectPtr(Object* data) : data_(data) {
    if (data_ != nullptr) {
      data_->IncWeakRef();
    }
  }
  template <typename>
  friend class WeakObjectPtr;
  friend struct tvm::ffi::details::ObjectUnsafe;
};
// Forward declaration of Optional (defined elsewhere). The defaulted second
// template parameter is presumably a hook for SFINAE-based specializations;
// confirm against the definition.
template <typename T, typename = void>
class Optional;
// Base class of all typed, reference-counted handles to Objects.
//
// An ObjectRef owns one strong reference (through an ObjectPtr<Object>) or is
// null. Equality, ordering and hashing use object identity (the pointer value).
class ObjectRef {
 public:
  ObjectRef() = default;
  ObjectRef(const ObjectRef& other) = default;
  ObjectRef(ObjectRef&& other) = default;
  ObjectRef& operator=(const ObjectRef& other) = default;
  ObjectRef& operator=(ObjectRef&& other) = default;
  // Takes shared ownership of `data`. The by-value parameter is moved into the
  // member, so passing an rvalue costs no extra reference-count round trip.
  explicit ObjectRef(ObjectPtr<Object> data) : data_(std::move(data)) {}
  // Creates a null reference for privileged code that assigns data_ afterwards.
  explicit ObjectRef(UnsafeInit) : data_(nullptr) {}
  // Identity comparison: true if both refer to the same object (or both are null).
  bool same_as(const ObjectRef& other) const { return data_ == other.data_; }
  bool operator==(const ObjectRef& other) const { return data_ == other.data_; }
  bool operator!=(const ObjectRef& other) const { return data_ != other.data_; }
  // Orders by object address.
  bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); }
  bool defined() const { return data_ != nullptr; }
  const Object* get() const { return data_.get(); }
  const Object* operator->() const { return get(); }
  bool unique() const { return data_.unique(); }
  int use_count() const { return data_.use_count(); }
  // Returns the object as `const ObjectType*` if it is an instance of
  // ObjectType (or a subclass), otherwise nullptr.
  template <typename ObjectType, typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
  const ObjectType* as() const {
    if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
      return static_cast<ObjectType*>(data_.get());
    } else {
      return nullptr;
    }
  }
  // Returns a new ObjectRefType handle sharing this object if the object is an
  // instance of ObjectRefType::ContainerType, otherwise std::nullopt (also for
  // a null reference).
  template <typename ObjectRefType,
            typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
  TVM_FFI_INLINE std::optional<ObjectRefType> as() const {
    if (data_ == nullptr || !data_->IsInstance<typename ObjectRefType::ContainerType>()) {
      return std::nullopt;
    }
    ObjectRefType ref(UnsafeInit{});
    ref.data_ = data_;
    return ref;
  }
  // Runtime type index of the object, or TypeIndex::kTVMFFINone when null.
  int32_t type_index() const {
    return data_ != nullptr ? data_->type_index() : TypeIndex::kTVMFFINone;
  }
  // Type key of the object, or StaticTypeKey::kTVMFFINone ("None") when null.
  std::string GetTypeKey() const {
    return data_ != nullptr ? data_->GetTypeKey() : StaticTypeKey::kTVMFFINone;
  }
  using ContainerType = Object;
  static constexpr bool _type_is_nullable = true;

 protected:
  ObjectPtr<Object> data_;
  // Mutable access to the referenced object for subclasses.
  Object* get_mutable() const { return data_.get(); }
  // friend classes.
  friend struct ObjectPtrHash;
  friend struct tvm::ffi::details::ObjectUnsafe;
};
// Forward declaration of Variant (defined elsewhere). It is needed by the
// ObjectPtrHash / ObjectPtrEqual overloads below, which are declared here and
// defined out of line.
template <typename... V>
class Variant;
struct ObjectPtrHash {
size_t operator()(const ObjectRef& a) const { return operator()(a.data_); }
template <typename T>
size_t operator()(const ObjectPtr<T>& a) const {
return std::hash<Object*>()(a.get());
}
template <typename... V>
TVM_FFI_INLINE size_t operator()(const Variant<V...>& a) const;
};
struct ObjectPtrEqual {
bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); }
template <typename T>
bool operator()(const ObjectPtr<T>& a, const ObjectPtr<T>& b) const {
return a == b;
}
template <typename... V>
TVM_FFI_INLINE bool operator()(const Variant<V...>& a, const Variant<V...>& b) const;
};
// Emits, inside the body of object class TypeName, the type-registration members
// for a type whose index is fixed at compile time (TypeName::_type_index):
//   - _type_depth: ParentType::_type_depth + 1;
//   - _GetOrAllocRuntimeTypeIndex(): calls TVMFFITypeGetOrAllocIndex with
//     TypeName's key, static index, depth and child-slot settings, and
//     ParentType's runtime index. That index is obtained first, as a call
//     argument. The result is cached in a function-local static;
//   - _register_type_index: an inline static whose initializer triggers the
//     registration during static initialization.
// Compile-time checks: ParentType must not be final, and if both types reserve
// child slots, TypeName must reserve strictly fewer than ParentType.
#define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) \
static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \
static int32_t _GetOrAllocRuntimeTypeIndex() { \
static_assert(!ParentType::_type_final, "ParentType marked as final"); \
static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
TVMFFIByteArray type_key{TypeName::_type_key, \
std::char_traits<char>::length(TypeName::_type_key)}; \
static int32_t tindex = TVMFFITypeGetOrAllocIndex( \
&type_key, TypeName::_type_index, TypeName::_type_depth, TypeName::_type_child_slots, \
TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \
return tindex; \
} \
static inline int32_t _register_type_index = _GetOrAllocRuntimeTypeIndex()
// Declares type info for an object type with a compile-time type index: sets
// _type_key to TypeKey, makes RuntimeTypeIndex() return TypeName::_type_index
// (which the class itself must define), and registers the type through
// TVM_FFI_REGISTER_STATIC_TYPE_INFO.
#define TVM_FFI_DECLARE_OBJECT_INFO_STATIC(TypeKey, TypeName, ParentType) \
static constexpr const char* _type_key = TypeKey; \
static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \
TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType)
// Declares type info for an object type whose index is assigned at runtime.
// The class must already define _type_key. This is the same as
// TVM_FFI_REGISTER_STATIC_TYPE_INFO except that:
//   - -1 is passed as the requested static index (presumably "allocate a new
//     index"; confirm in the TVMFFITypeGetOrAllocIndex documentation);
//   - RuntimeTypeIndex() returns the cached registry result;
//   - the inline static _type_index holds the allocated index. It hides the
//     constexpr _type_index inherited from the parent.
#define TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(TypeName, ParentType) \
static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \
static int32_t _GetOrAllocRuntimeTypeIndex() { \
static_assert(!ParentType::_type_final, "ParentType marked as final"); \
static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
TVMFFIByteArray type_key{TypeName::_type_key, \
std::char_traits<char>::length(TypeName::_type_key)}; \
static int32_t tindex = TVMFFITypeGetOrAllocIndex( \
&type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \
TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \
return tindex; \
} \
static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); } \
static inline int32_t _type_index = _GetOrAllocRuntimeTypeIndex()
// Declares type info for a non-final object type: sets _type_key to TypeKey and
// registers the type with a runtime-assigned index
// (TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY).
#define TVM_FFI_DECLARE_OBJECT_INFO(TypeKey, TypeName, ParentType) \
static constexpr const char* _type_key = TypeKey; \
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(TypeName, ParentType)
// Same as TVM_FFI_DECLARE_OBJECT_INFO, but marks the type final (no subclasses,
// zero reserved child slots). details::IsObjectInstance can then check
// membership with a single index comparison.
#define TVM_FFI_DECLARE_OBJECT_INFO_FINAL(TypeKey, TypeName, ParentType) \
static const constexpr int _type_child_slots [[maybe_unused]] = 0; \
static const constexpr bool _type_final [[maybe_unused]] = true; \
TVM_FFI_DECLARE_OBJECT_INFO(TypeKey, TypeName, ParentType)
// Defines the standard members of a nullable reference class TypeName (derived
// from ParentType) that refers to ObjectName objects:
//   - a defaulted default constructor;
//   - an explicit constructor from ObjectPtr<ObjectName>. The pointer is moved
//     into ParentType, so no extra reference-count increment/decrement occurs;
//   - an explicit UnsafeInit constructor forwarded to ParentType;
//   - copy/move members via TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN
//     (defined elsewhere);
//   - operator-> and get(), returning ObjectName* when ObjectName::_type_mutable
//     is true and const ObjectName* otherwise;
//   - _type_is_nullable = true and the ContainerType alias.
#define TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TypeName, ParentType, ObjectName) \
  TypeName() = default; \
  explicit TypeName(::tvm::ffi::ObjectPtr<ObjectName> n) : ParentType(::std::move(n)) {} \
  explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \
  TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
  using __PtrType = std::conditional_t<ObjectName::_type_mutable, ObjectName*, const ObjectName*>; \
  __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } \
  __PtrType get() const { return static_cast<__PtrType>(data_.get()); } \
  [[maybe_unused]] static constexpr bool _type_is_nullable = true; \
  using ContainerType = ObjectName
// Defines the standard members of a non-nullable reference class TypeName
// (derived from ParentType) that refers to ObjectName objects. Unlike the
// NULLABLE variant, it declares no default or ObjectPtr constructor. It provides
// the UnsafeInit constructor, copy/move members via
// TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN (defined elsewhere), operator-> /
// get() (mutable or const pointer depending on ObjectName::_type_mutable),
// _type_is_nullable = false, and the ContainerType alias.
#define TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TypeName, ParentType, ObjectName) \
explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \
TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
using __PtrType = std::conditional_t<ObjectName::_type_mutable, ObjectName*, const ObjectName*>; \
__PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } \
__PtrType get() const { return static_cast<__PtrType>(data_.get()); } \
[[maybe_unused]] static constexpr bool _type_is_nullable = false; \
using ContainerType = ObjectName
namespace details {
// Returns whether an object whose runtime type index is `object_type_index` is a
// TargetType or one of its subclasses.
//
// Strategy, resolved at compile time where possible:
//   1. TargetType == Object: always true.
//   2. Final TargetType: exact index match.
//   3. Otherwise, a match inside TargetType's reserved index range
//      [index, index + _type_child_slots] succeeds immediately.
//   4. If subclasses may overflow that range, fall back to a runtime lookup of
//      the object's ancestor chain at TargetType's depth.
template <typename TargetType>
TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) {
  static_assert(std::is_base_of_v<Object, TargetType>);
  if constexpr (std::is_same_v<TargetType, Object>) {
    return true;
  } else if constexpr (TargetType::_type_final) {
    // A final type has no subclasses, so only an exact match qualifies.
    return object_type_index == TargetType::RuntimeTypeIndex();
  } else {
    const int32_t target_index = TargetType::RuntimeTypeIndex();
    bool in_reserved_range = (object_type_index == target_index);
    if constexpr (TargetType::_type_child_slots != 0) {
      // The reserved range covers the type itself plus its child slots.
      const int32_t range_end = target_index + TargetType::_type_child_slots + 1;
      in_reserved_range = object_type_index >= target_index && object_type_index < range_end;
    }
    if (in_reserved_range) return true;
    if constexpr (!TargetType::_type_child_slots_can_overflow) {
      return false;
    } else {
      // A subclass's index is always larger than its parent's.
      if (object_type_index < target_index) return false;
      // TVMFFIGetTypeInfo is relied on to validate the index (no null check here).
      const TypeInfo* info = TVMFFIGetTypeInfo(object_type_index);
      return info->type_depth > TargetType::_type_depth &&
             info->type_ancestors[TargetType::_type_depth]->type_index == target_index;
    }
  }
}
// Privileged helper (a friend of Object, ObjectPtr, WeakObjectPtr and ObjectRef)
// that converts between ObjectRef, ObjectPtr, Object* and the C ABI
// TVMFFIObject*, and adjusts reference counts directly. Each function notes
// whether it adopts an existing reference, adds a new one, or transfers one.
struct ObjectUnsafe {
  // Returns the C ABI header embedded in `src`, or nullptr for a null `src`.
  // The explicit check avoids forming a member access through a null pointer.
  // That would be undefined behavior, and would let the optimizer assume the
  // result is non-null when a null ObjectRef/ObjectPtr is converted.
  TVM_FFI_INLINE static TVMFFIObject* GetHeader(const Object* src) {
    if (src == nullptr) {
      return nullptr;
    }
    return const_cast<TVMFFIObject*>(&(src->header_));
  }
  // Difference between the address of header_ reached through a Class* and
  // through an Object*, both formed from nullptr (offsetof-style idiom). This is
  // presumably the offset of the Object base inside Class.
  template <typename Class>
  TVM_FFI_INLINE static int64_t GetObjectOffsetToSubclass() {
    return (reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->header_)) -
            reinterpret_cast<int64_t>(&(static_cast<Object*>(nullptr)->header_)));
  }
  // Builds a T (an ObjectRef subclass) sharing `ptr` (adds a strong reference).
  // No type check is performed.
  template <typename T>
  TVM_FFI_INLINE static T ObjectRefFromObjectPtr(const ObjectPtr<Object>& ptr) {
    T ref(UnsafeInit{});
    ref.data_ = ptr;
    return ref;
  }
  // Builds a T that takes over `ptr`'s reference (no count change).
  template <typename T>
  TVM_FFI_INLINE static T ObjectRefFromObjectPtr(ObjectPtr<Object>&& ptr) {
    T ref(UnsafeInit{});
    ref.data_ = std::move(ptr);
    return ref;
  }
  // Returns a new strong ObjectPtr<T> to `ref`'s object. No type check.
  template <typename T>
  TVM_FFI_INLINE static ObjectPtr<T> ObjectPtrFromObjectRef(const ObjectRef& ref) {
    if constexpr (std::is_same_v<T, Object>) {
      return ref.data_;
    } else {
      return tvm::ffi::ObjectPtr<T>(ref.data_.data_);
    }
  }
  // Moves `ref`'s reference into an ObjectPtr<T> and leaves `ref` null.
  template <typename T>
  TVM_FFI_INLINE static ObjectPtr<T> ObjectPtrFromObjectRef(ObjectRef&& ref) {
    if constexpr (std::is_same_v<T, Object>) {
      return std::move(ref.data_);
    } else {
      ObjectPtr<T> result;
      result.data_ = std::exchange(ref.data_.data_, nullptr);
      return result;
    }
  }
  // Adopts a reference the caller already owns (no IncRef).
  template <typename T>
  TVM_FFI_INLINE static ObjectPtr<T> ObjectPtrFromOwned(Object* raw_ptr) {
    tvm::ffi::ObjectPtr<T> ptr;
    ptr.data_ = raw_ptr;
    return ptr;
  }
  // Adopts an owned reference given as a C ABI header pointer (no IncRef).
  template <typename T>
  TVM_FFI_INLINE static ObjectPtr<T> ObjectPtrFromOwned(TVMFFIObject* obj_ptr) {
    return ObjectPtrFromOwned<T>(reinterpret_cast<Object*>(obj_ptr));
  }
  // Views a C ABI header pointer as T* without touching the reference count.
  template <typename T>
  TVM_FFI_INLINE static T* RawObjectPtrFromUnowned(TVMFFIObject* obj_ptr) {
    // NOTE: this is important to first cast to Object*
    // then cast back to T* because objptr and tptr may not be the same
    // depending on how sub-class allocates the space.
    return static_cast<T*>(reinterpret_cast<Object*>(obj_ptr));
  }
  // Create ObjectPtr from unowned ptr (adds a strong reference).
  template <typename T>
  TVM_FFI_INLINE static ObjectPtr<T> ObjectPtrFromUnowned(Object* raw_ptr) {
    return tvm::ffi::ObjectPtr<T>(raw_ptr);
  }
  template <typename T>
  TVM_FFI_INLINE static ObjectPtr<T> ObjectPtrFromUnowned(TVMFFIObject* obj_ptr) {
    return tvm::ffi::ObjectPtr<T>(reinterpret_cast<Object*>(obj_ptr));
  }
  // Releases one strong reference on an opaque handle (must be non-null).
  TVM_FFI_INLINE static void DecRefObjectHandle(TVMFFIObjectHandle handle) {
    reinterpret_cast<Object*>(handle)->DecRef();
  }
  // Adds one strong reference to an opaque handle (must be non-null).
  TVM_FFI_INLINE static void IncRefObjectHandle(TVMFFIObjectHandle handle) {
    reinterpret_cast<Object*>(handle)->IncRef();
  }
  // Raw Object* held by `src` (borrowed; no count change).
  TVM_FFI_INLINE static Object* RawObjectPtrFromObjectRef(const ObjectRef& src) {
    return src.data_.data_;
  }
  // C ABI header of `src`'s object (borrowed), or nullptr for a null ref.
  TVM_FFI_INLINE static TVMFFIObject* TVMFFIObjectPtrFromObjectRef(const ObjectRef& src) {
    return GetHeader(src.data_.data_);
  }
  // C ABI header of `src`'s object (borrowed), or nullptr for a null pointer.
  template <typename T>
  TVM_FFI_INLINE static TVMFFIObject* TVMFFIObjectPtrFromObjectPtr(const ObjectPtr<T>& src) {
    return GetHeader(src.data_);
  }
  // Transfers `src`'s reference to the returned header pointer; `src` becomes null.
  template <typename T>
  TVM_FFI_INLINE static TVMFFIObject* MoveObjectPtrToTVMFFIObjectPtr(ObjectPtr<T>&& src) {
    return GetHeader(std::exchange(src.data_, nullptr));
  }
  // Transfers `src`'s reference to the returned header pointer; `src` becomes null.
  TVM_FFI_INLINE static TVMFFIObject* MoveObjectRefToTVMFFIObjectPtr(ObjectRef&& src) {
    return GetHeader(std::exchange(src.data_.data_, nullptr));
  }
};
} // namespace details
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_OBJECT_H_