23 #ifndef TVM_RUNTIME_OBJECT_H_
24 #define TVM_RUNTIME_OBJECT_H_
27 #include <tvm/runtime/logging.h>
30 #include <type_traits>
39 #ifndef TVM_OBJECT_ATOMIC_REF_COUNTER
40 #define TVM_OBJECT_ATOMIC_REF_COUNTER 1
43 #if TVM_OBJECT_ATOMIC_REF_COUNTER
// Returns the type key string of this object's runtime type, obtained by
// passing type_index_ to TypeIndex2Key.
184 std::string
GetTypeKey()
const {
return TypeIndex2Key(type_index_); }
// Declaration of IsInstance<TargetType>(): whether this object is a TargetType
// or derives from it. The out-of-line definition returns true for Object,
// compares type indices exactly for final types, checks the reserved
// child-slot range, and otherwise may fall back to DerivedFrom().
194 template <
typename TargetType>
195 inline bool IsInstance()
const;
// Declaration of unique(); its definition is out of line and not part of this excerpt.
// NOTE(review): presumably true when use_count() == 1 — confirm against the definition.
200 inline bool unique()
const;
220 #if TVM_OBJECT_ATOMIC_REF_COUNTER
// Static type-system traits of the root Object type; derived types redeclare
// (shadow) them, e.g. via TVM_DECLARE_FINAL_OBJECT_INFO.
// Type key string of the root object type.
226 static constexpr
const char* _type_key =
"runtime.Object";
// Not final: other object types may use Object as their ParentType.
232 static constexpr
bool _type_final =
false;
// Number of type-index slots reserved for children; 0 means IsInstance only
// checks the exact index before considering overflow.
233 static constexpr uint32_t _type_child_slots = 0;
// When true, IsInstance may fall back to DerivedFrom() for indices outside the
// reserved child-slot range.
234 static constexpr
bool _type_child_slots_can_overflow =
true;
// NOTE(review): the _type_has_method_* flags are not read in this excerpt;
// presumably they advertise which reflection hooks (VisitAttrs, SEqualReduce,
// SHashReduce) a type provides — confirm.
236 static constexpr
bool _type_has_method_visit_attrs =
true;
237 static constexpr
bool _type_has_method_sequal_reduce =
false;
238 static constexpr
bool _type_has_method_shash_reduce =
false;
// Runtime type index (tag) identifying this object's dynamic type; initialized to 0.
265 uint32_t type_index_{0};
// Optional custom deleter for customized allocation. DecRef tests it against
// nullptr once the last reference is released.
273 FDeleter deleter_ =
nullptr;
277 "RefCounter ABI check.");
297 uint32_t parent_tindex, uint32_t type_child_slots,
298 bool type_child_slots_can_overflow);
// Developer function: increments this object's reference counter.
302 inline void IncRef();
// Developer function: decrements this object's reference counter. When the
// count reaches zero, the out-of-line definition consults deleter_.
307 inline void DecRef();
// Current reference count of this object. In the atomic build this is a
// relaxed-load snapshot.
314 inline int use_count()
const;
// Slow-path type check: whether this object's type derives from the type whose
// index is parent_tindex. IsInstance uses it when the fast index checks are inconclusive.
320 bool DerivedFrom(uint32_t parent_tindex)
const;
// Grants ObjectInternal access to this class's non-public members.
327 friend class ObjectInternal;
342 template <
typename RelayRefType,
typename ObjectType>
// Forward declaration of Downcast. It converts a base reference to SubRef after
// ICHECK-verifying the object's type (or, for a null reference, that SubRef is
// nullable), then moves the data pointer into the result.
353 template <
typename SubRef,
typename BaseRef>
354 inline SubRef
Downcast(BaseRef ref);
361 template <
typename T>
378 template <
typename U>
381 static_assert(std::is_base_of<T, U>::value,
382 "can only assign of child class ObjectPtr to parent");
389 : data_(other.data_) {
390 other.data_ =
nullptr;
396 template <
typename Y>
398 : data_(other.data_) {
399 static_assert(std::is_base_of<T, Y>::value,
400 "can only assign of child class ObjectPtr to parent");
401 other.data_ =
nullptr;
410 std::swap(data_, other.data_);
// Raw pointer to the managed object, cast from the stored data_ to T*
// (nullptr when this pointer is empty).
415 T*
get()
const {
return static_cast<T*
>(data_); }
// True when the pointer is non-null. Marked explicit so an ObjectPtr can be
// tested in boolean contexts without converting implicitly elsewhere.
451 explicit operator bool()
const {
return get() !=
nullptr; }
454 if (data_ !=
nullptr) {
// Reference count of the managed object, or 0 if this pointer is empty.
460 int use_count()
const {
return data_ !=
nullptr ? data_->use_count() : 0; }
// True when non-null and the managed object's reference count is exactly 1.
462 bool unique()
const {
return data_ !=
nullptr && data_->use_count() == 1; }
// Whether this pointer is empty. The nullptr_t parameter only selects the
// overload; its value is unused.
468 bool operator==(std::nullptr_t
null)
const {
return data_ ==
nullptr; }
// Whether this pointer holds an object. The nullptr_t parameter only selects
// the overload; its value is unused.
470 bool operator!=(std::nullptr_t
null)
const {
return data_ !=
nullptr; }
480 if (data !=
nullptr) {
489 static ObjectPtr<T> MoveFromRValueRefArg(
Object** ref) {
508 template <
typename RelayRefType,
typename ObjType>
510 template <
typename BaseType,
typename ObjType>
515 template <
typename T>
// Declaration of as<ObjectType>(): tries to downcast the referenced object to a
// const ObjectType*. It is only enabled when ObjectType derives from Object.
// The out-of-line definition returns the cast pointer when data_ is non-null
// and IsInstance<ObjectType>() holds.
// NOTE(review): presumably returns nullptr otherwise — confirm.
574 template <
typename ObjectType,
typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
575 inline const ObjectType*
as()
const;
594 template <
typename ObjectRefType,
595 typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
614 template <
typename T>
616 return T(std::move(ref.
data_));
630 template <
typename ObjectType>
// Downcast is a friend because it moves data_ out of the base reference
// directly into the SubRef it returns.
639 template <
typename SubRef,
typename BaseRef>
640 friend SubRef
Downcast(BaseRef ref);
651 template <
typename BaseType,
typename ObjectType>
658 template <
typename T>
660 return std::hash<Object*>()(a.
get());
668 template <
typename T>
679 #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
680 static_assert(!ParentType::_type_final, "ParentObj marked as final"); \
681 static uint32_t RuntimeTypeIndex() { \
682 static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
683 TypeName::_type_child_slots < ParentType::_type_child_slots, \
684 "Need to set _type_child_slots when parent specifies it."); \
685 if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
686 return TypeName::_type_index; \
688 return _GetOrAllocRuntimeTypeIndex(); \
690 static uint32_t _GetOrAllocRuntimeTypeIndex() { \
691 static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex( \
692 TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
693 TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \
// Declares runtime type info for a final (leaf) object type. It defines
// _type_final = true and _type_child_slots = 0, then expands
// TVM_DECLARE_BASE_OBJECT_INFO. Because _type_final is true, that macro's
// static_assert rejects any type that names this one as its ParentType.
// IsInstance can also test this type with a single type-index comparison.
702 #define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
703 static const constexpr bool _type_final = true; \
704 static const constexpr int _type_child_slots = 0; \
705 TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
708 #if defined(__GNUC__)
709 #define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
711 #define TVM_ATTRIBUTE_UNUSED
// Token-pastes __x and __y. The two-level form lets arguments that are
// themselves macros (e.g. __COUNTER__) expand before pasting.
714 #define TVM_STR_CONCAT_(__x, __y) __x##__y
715 #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
// Declaration prefix for a type-registration variable. TVM_ATTRIBUTE_UNUSED
// silences unused-variable warnings (it expands to __attribute__((unused)) under __GNUC__).
717 #define TVM_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
// Registers TypeName's runtime type index when static variables are
// initialized. It stores TypeName::_GetOrAllocRuntimeTypeIndex() in a static
// variable whose name is made unique by appending __COUNTER__.
725 #define TVM_REGISTER_OBJECT_TYPE(TypeName) \
726 TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex()
// Explicitly defaults TypeName's copy/move constructors and copy/move
// assignment operators. The reference-type macros below use it.
732 #define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
733 TypeName(const TypeName& other) = default; \
734 TypeName(TypeName&& other) = default; \
735 TypeName& operator=(const TypeName& other) = default; \
736 TypeName& operator=(TypeName&& other) = default;
744 #define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \
746 explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
747 TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
748 const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
749 const ObjectName* get() const { return operator->(); } \
750 using ContainerType = ObjectName;
// Standard methods for a reference type TypeName (deriving from ParentType)
// whose node type is ObjectName. It adds a defaulted default constructor to
// everything from TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR.
758 #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
759 TypeName() = default; \
760 TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, ObjectName)
// Methods for a reference type that must never hold null:
// - There is no default constructor.
// - Accessors return const ObjectName* (static_cast, no runtime check).
// - _type_is_nullable = false makes GetRef and Downcast ICHECK against null.
// ContainerType names the node type for GetRef/Downcast.
769 #define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
770 explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
771 TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
772 const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
773 const ObjectName* get() const { return operator->(); } \
774 static constexpr bool _type_is_nullable = false; \
775 using ContainerType = ObjectName;
// Like TVM_DEFINE_OBJECT_REF_METHODS, except operator-> returns a mutable
// ObjectName*, even when called on a const reference.
// NOTE(review): unlike the sibling macros, this one defines no get(), so get()
// resolves to the base class's version — confirm this is intended.
785 #define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
786 TypeName() = default; \
787 TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
788 explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
789 ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
790 using ContainerType = ObjectName;
// Non-nullable counterpart of the mutable reference macro:
// - There is no default constructor.
// - operator-> and get() return a mutable ObjectName*.
// - _type_is_nullable = false makes GetRef and Downcast ICHECK against null.
799 #define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
800 explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
801 TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
802 ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
803 ObjectName* get() const { return operator->(); } \
804 static constexpr bool _type_is_nullable = false; \
805 using ContainerType = ObjectName;
826 #define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
827 static_assert(ObjectName::_type_final, \
828 "TVM's CopyOnWrite may only be used for " \
829 "Object types that are declared as final, " \
830 "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \
831 ObjectName* CopyOnWrite() { \
832 ICHECK(data_ != nullptr); \
833 if (!data_.unique()) { \
834 auto n = make_object<ObjectName>(*(operator->())); \
835 ObjectPtr<Object>(std::move(n)).swap(data_); \
837 return static_cast<ObjectName*>(data_.get()); \
842 #if TVM_OBJECT_ATOMIC_REF_COUNTER
847 if (
ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
848 std::atomic_thread_fence(std::memory_order_acquire);
849 if (this->deleter_ !=
nullptr) {
// Atomic build: returns a snapshot of the reference count. The relaxed load
// imposes no ordering, so the value may already be stale when it is used.
855 inline int Object::use_count()
const {
return ref_counter_.load(std::memory_order_relaxed); }
863 if (this->deleter_ !=
nullptr) {
// Non-atomic build: returns the plain reference counter.
869 inline int Object::use_count()
const {
return ref_counter_; }
873 template <
typename TargetType>
875 const Object*
self =
this;
878 if (
self !=
nullptr) {
880 if (std::is_same<TargetType, Object>::value)
return true;
881 if (TargetType::_type_final) {
884 return self->
type_index_ == TargetType::RuntimeTypeIndex();
888 uint32_t begin = TargetType::RuntimeTypeIndex();
890 if (TargetType::_type_child_slots != 0) {
891 uint32_t end = begin + TargetType::_type_child_slots;
892 if (self->type_index_ >= begin && self->type_index_ < end)
return true;
894 if (self->type_index_ == begin)
return true;
896 if (!TargetType::_type_child_slots_can_overflow)
return false;
898 if (self->type_index_ < TargetType::RuntimeTypeIndex())
return false;
900 return self->DerivedFrom(TargetType::RuntimeTypeIndex());
909 template <
typename ObjectType,
typename>
911 if (
data_ !=
nullptr &&
data_->IsInstance<ObjectType>()) {
912 return static_cast<ObjectType*
>(
data_.get());
918 template <
typename RefType,
typename ObjType>
919 inline RefType
GetRef(
const ObjType* ptr) {
920 static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
921 "Can only cast to the ref of same container type");
922 if (!RefType::_type_is_nullable) {
923 ICHECK(ptr !=
nullptr);
928 template <
typename BaseType,
typename ObjType>
930 static_assert(std::is_base_of<BaseType, ObjType>::value,
931 "Can only cast to the ref of same container type");
935 template <
typename SubRef,
typename BaseRef>
938 ICHECK(ref->template IsInstance<typename SubRef::ContainerType>())
939 <<
"Downcast from " << ref->GetTypeKey() <<
" to " << SubRef::ContainerType::_type_key
942 ICHECK(SubRef::_type_is_nullable) <<
"Downcast from nullptr to not nullable reference of "
943 << SubRef::ContainerType::_type_key;
945 return SubRef(std::move(ref.data_));
Managed reference to RelayRefTypeNode.
Definition: type.h:577
Base class of object allocators that implements make(). Uses the curiously recurring template pattern.
Definition: memory.h:60
A custom smart pointer for Object.
Definition: object.h:362
void swap(ObjectPtr< T > &other)
Swap the managed object with that of another ObjectPtr.
Definition: object.h:409
T * get() const
Definition: object.h:415
friend class Object
Definition: object.h:496
bool operator!=(const ObjectPtr< T > &other) const
Definition: object.h:466
ObjectPtr()
default constructor
Definition: object.h:365
friend ObjectPtr< BaseType > GetObjectPtr(ObjType *ptr)
Definition: object.h:929
friend class ObjectPtr
Definition: object.h:500
T * operator->() const
Definition: object.h:419
friend RelayRefType GetRef(const ObjType *ptr)
Definition: object.h:919
ObjectPtr(std::nullptr_t)
constructor from nullptr; creates an empty pointer
Definition: object.h:367
ObjectPtr< T > & operator=(ObjectPtr< T > &&other)
move assignment
Definition: object.h:442
ObjectPtr(const ObjectPtr< T > &other)
copy constructor
Definition: object.h:372
ObjectPtr(ObjectPtr< Y > &&other)
move constructor
Definition: object.h:397
int use_count() const
Definition: object.h:460
bool operator==(const ObjectPtr< T > &other) const
Definition: object.h:464
ObjectPtr(ObjectPtr< T > &&other)
move constructor
Definition: object.h:388
void reset()
reset the content of ptr to be nullptr
Definition: object.h:453
~ObjectPtr()
destructor
Definition: object.h:404
T & operator*() const
Definition: object.h:423
ObjectPtr(const ObjectPtr< U > &other)
copy constructor
Definition: object.h:379
bool operator==(std::nullptr_t null) const
Definition: object.h:468
bool operator!=(std::nullptr_t null) const
Definition: object.h:470
bool unique() const
Definition: object.h:462
ObjectPtr< T > & operator=(const ObjectPtr< T > &other)
copy assignment
Definition: object.h:431
Base class of all object reference.
Definition: object.h:519
int use_count() const
Definition: object.h:560
bool defined() const
Definition: object.h:552
static void FFIClearAfterMove(ObjectRef *ref)
Clear the object ref data field without DecRef after we successfully moved the field.
Definition: object.h:623
const Object * operator->() const
Definition: object.h:556
static constexpr bool _type_is_nullable
Definition: object.h:601
bool operator<(const ObjectRef &other) const
Comparator.
Definition: object.h:548
friend class ObjectInternal
Definition: object.h:638
bool unique() const
Definition: object.h:558
friend SubRef Downcast(BaseRef ref)
Downcast a base reference type to a more specific type.
Definition: object.h:936
ObjectRef(ObjectPtr< Object > data)
Constructor from existing object ptr.
Definition: object.h:524
ObjectRef()=default
default constructor
bool operator!=(const ObjectRef &other) const
Comparator.
Definition: object.h:542
const Object * get() const
Definition: object.h:554
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
static T DowncastNoCheck(ObjectRef ref)
Internal helper function to downcast a ref without a type check.
Definition: object.h:615
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:530
Object * get_mutable() const
Definition: object.h:607
static ObjectPtr< ObjectType > GetDataPtr(const ObjectRef &ref)
Internal helper function to get data_ as an ObjectPtr of ObjectType.
Definition: object.h:631
bool operator==(const ObjectRef &other) const
Comparator.
Definition: object.h:536
Base class of all object containers.
Definition: object.h:171
RefCounterType ref_counter_
The internal reference counter.
Definition: object.h:267
Object()
Definition: object.h:245
uint32_t type_index() const
Definition: object.h:179
uint32_t type_index_
Type index(tag) that indicates the type of the object.
Definition: object.h:265
std::string GetTypeKey() const
Definition: object.h:184
std::atomic< int32_t > RefCounterType
Definition: object.h:221
size_t GetTypeKeyHash() const
Definition: object.h:188
static uint32_t _GetOrAllocRuntimeTypeIndex()
Definition: object.h:228
static uint32_t TypeKey2Index(const std::string &key)
Get the type index of the corresponding key from runtime.
Object & operator=(const Object &other)
Definition: object.h:255
static size_t TypeIndex2KeyHash(uint32_t tindex)
Get the type key hash of the corresponding index from runtime.
void DecRef()
Developer function: decreases the reference counter.
Definition: object.h:846
static uint32_t GetOrAllocRuntimeTypeIndex(const std::string &key, uint32_t static_tindex, uint32_t parent_tindex, uint32_t type_child_slots, bool type_child_slots_can_overflow)
Get the type index using type key.
static std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
bool IsInstance() const
Definition: object.h:874
Object(Object &&other)
Definition: object.h:253
Object(const Object &other)
Definition: object.h:251
void IncRef()
Developer function: increases the reference counter.
Definition: object.h:844
static uint32_t RuntimeTypeIndex()
Definition: object.h:229
Object & operator=(Object &&other)
Definition: object.h:258
FDeleter deleter_
Deleter of this object, enabling customized allocation. If the deleter is nullptr,...
Definition: object.h:273
bool unique() const
Definition: object.h:907
Optional container that represents a nullable variant of T.
Definition: optional.h:51
A single argument value to PackedFunc, containing both type_code and TVMValue.
Definition: packed_func.h:796
Definition: packed_func.h:1824
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
Definition: packed_func.h:856
Internal base class to handle conversion to POD values.
Definition: packed_func.h:615
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlying...
Definition: packed_func.h:946
ObjectPtr< BaseType > GetObjectPtr(ObjectType *ptr)
Get an object ptr type from a raw object ptr.
SubRef Downcast(BaseRef ref)
Downcast a base reference type to a more specific type.
Definition: object.h:936
RelayRefType GetRef(const ObjectType *ptr)
Get a reference type from a raw object ptr type.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
ObjectRef equal functor.
Definition: object.h:665
size_t operator()(const ObjectPtr< T > &a, const ObjectPtr< T > &b) const
Definition: object.h:669
bool operator()(const ObjectRef &a, const ObjectRef &b) const
Definition: object.h:666
ObjectRef hash functor.
Definition: object.h:655
size_t operator()(const ObjectPtr< T > &a) const
Definition: object.h:659
size_t operator()(const ObjectRef &a) const
Definition: object.h:656
Namespace for the list of type indices.
Definition: object.h:55
@ kRoot
Root object type.
Definition: object.h:58
@ kRuntimePackedFunc
runtime::PackedFunc.
Definition: object.h:74
@ kRuntimeRPCObjectRef
runtime::RPCObjectRef
Definition: object.h:78
@ kStaticIndexEnd
Definition: object.h:82
@ kRuntimeNDArray
runtime::NDArray.
Definition: object.h:64
@ kRuntimeMap
runtime::Map.
Definition: object.h:70
@ kRuntimeClosure
Definition: object.h:80
@ kRuntimeADT
Definition: object.h:81
@ kRuntimeString
runtime::String.
Definition: object.h:66
@ kDynamic
Type index is allocated during runtime.
Definition: object.h:84
@ kRuntimeArray
runtime::Array.
Definition: object.h:68
@ kRuntimeModule
runtime::Module.
Definition: object.h:62
@ kRuntimeDiscoDRef
runtime::DRef for disco distributed runtime
Definition: object.h:76
@ kRuntimeShapeTuple
runtime::ShapeTuple.
Definition: object.h:72