tvm_ffi/
object.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19use std::ops::{Deref, DerefMut};
20use std::sync::atomic::AtomicU64;
21
22use crate::derive::ObjectRef;
23pub use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;
24/// Object related ABI handling
25use tvm_ffi_sys::{TVMFFIObject, COMBINED_REF_COUNT_BOTH_ONE};
26
/// Base object type: a bare `TVMFFIObject` header with no additional payload.
///
/// `#[repr(C)]` keeps `header` at offset 0, so a pointer to an `Object` can be
/// reinterpreted as a pointer to its `TVMFFIObject` header.
#[repr(C)]
pub struct Object {
    /// Object header (combined reference count, type index and deleter).
    header: TVMFFIObject,
}
33
/// Arc-like smart pointer that provides shared ownership of an object.
///
/// Cloning an `ObjectArc` increments the strong reference count kept in the
/// object's `TVMFFIObject` header, and dropping it decrements that count.
///
/// `T` is the type of the object being wrapped.
#[repr(C)]
pub struct ObjectArc<T: ObjectCore> {
    /// Non-null pointer to the object.
    ptr: std::ptr::NonNull<T>,
    /// Zero-sized marker tying the wrapper to `T` without storing a `T` inline.
    _phantom: std::marker::PhantomData<T>,
}
42
43unsafe impl<T: Send + Sync + ObjectCore> Send for ObjectArc<T> {}
44unsafe impl<T: Send + Sync + ObjectCore> Sync for ObjectArc<T> {}
45
/// Core trait implemented by every object type handled by the FFI object system.
///
/// # Safety
///
/// Implementations expose mutable access to the object's `TVMFFIObject` header,
/// and other code in this module reinterprets a pointer to the object as a
/// pointer to that header.
/// NOTE(review): implementors are presumably required to be `#[repr(C)]` with
/// the header as their first field — confirm against the implementors.
pub unsafe trait ObjectCore: Sized + 'static {
    /// The type key of the object.
    const TYPE_KEY: &'static str;
    /// Return the type index of the object.
    fn type_index() -> i32;
    /// Return a mutable reference to the object header.
    ///
    /// This is an associated function that takes `this` explicitly rather than
    /// a `self` method.
    ///
    /// # Arguments
    /// * `this` - The object whose header is returned
    ///
    /// # Returns
    /// * `&mut TVMFFIObject` - The object header
    unsafe fn object_header_mut(this: &mut Self) -> &mut TVMFFIObject;
}
66
67/// Traits for objects with extra items that follows the object
68///
69/// This extra trait can be helpful to implement array types and string types
70pub unsafe trait ObjectCoreWithExtraItems: ObjectCore {
71    /// type of extra items storage that follows the object
72    type ExtraItem;
73    /// Return the number of extra items
74    fn extra_items_count(this: &Self) -> usize;
75    /// Return the extra items data pointer
76    unsafe fn extra_items(this: &Self) -> &[Self::ExtraItem] {
77        let extra_items_ptr = (this as *const Self as *const u8).add(std::mem::size_of::<Self>());
78        std::slice::from_raw_parts(
79            extra_items_ptr as *const Self::ExtraItem,
80            Self::extra_items_count(this),
81        )
82    }
83    /// Return the extra items data pointer
84    unsafe fn extra_items_mut(this: &mut Self) -> &mut [Self::ExtraItem] {
85        let extra_items_ptr = (this as *mut Self as *mut u8).add(std::mem::size_of::<Self>());
86        std::slice::from_raw_parts_mut(
87            extra_items_ptr as *mut Self::ExtraItem,
88            Self::extra_items_count(this),
89        )
90    }
91}
92
/// Core operations of an `ObjectRef`-like wrapper type.
///
/// Used by the FFI `Any` system; not intended to be user facing.
///
/// # Safety
///
/// Marked unsafe because it moves the internal `ObjectArc` out of the
/// reference type.
pub unsafe trait ObjectRefCore: Sized + Clone {
    /// The object type that the wrapped `ObjectArc` points to.
    type ContainerType: ObjectCore;
    /// Borrow the underlying `ObjectArc`.
    fn data(this: &Self) -> &ObjectArc<Self::ContainerType>;
    /// Consume the reference and return the underlying `ObjectArc`.
    fn into_data(this: Self) -> ObjectArc<Self::ContainerType>;
    /// Build the reference type from an `ObjectArc`.
    fn from_data(data: ObjectArc<Self::ContainerType>) -> Self;
}
104
/// Base reference type for objects.
///
/// Holds an `ObjectArc<Object>`; cloning an `ObjectRef` clones that `ObjectArc`.
#[repr(C)]
#[derive(ObjectRef, Clone)]
pub struct ObjectRef {
    /// Shared pointer to the underlying object.
    data: ObjectArc<Object>,
}
113
114/// Unsafe operations on object
115pub(crate) mod unsafe_ {
116    use tvm_ffi_sys::{
117        COMBINED_REF_COUNT_BOTH_ONE, COMBINED_REF_COUNT_MASK_U32, COMBINED_REF_COUNT_STRONG_ONE,
118        COMBINED_REF_COUNT_WEAK_ONE,
119    };
120
121    use std::ffi::c_void;
122    use std::sync::atomic::{fence, Ordering};
123    use tvm_ffi_sys::TVMFFIObject;
124    use tvm_ffi_sys::TVMFFIObjectDeleterFlagBitMask::{
125        kTVMFFIObjectDeleterFlagBitMaskBoth, kTVMFFIObjectDeleterFlagBitMaskStrong,
126        kTVMFFIObjectDeleterFlagBitMaskWeak,
127    };
128
129    /// Increase the strong reference count of the object
130    ///
131    /// This function is same as TVMFFIObjectIncRef but implemented natively in Rust
132    ///
133    /// # Arguments
134    /// * `obj` - The object to increase the reference count
135    #[inline]
136    pub unsafe fn inc_ref(handle: *mut TVMFFIObject) {
137        let obj = &mut *handle;
138        obj.combined_ref_count.fetch_add(1, Ordering::Relaxed);
139    }
140
141    /// Decrease the strong reference count of the object
142    ///
143    /// This function is same as TVMFFIObjectDecRef but implemented natively in Rust
144    ///
145    /// # Arguments
146    /// * `obj` - The object to decrease the reference count
147    #[inline]
148    pub unsafe fn dec_ref(handle: *mut TVMFFIObject) {
149        let obj = &mut *handle;
150        let old_combined_count = obj
151            .combined_ref_count
152            .fetch_sub(COMBINED_REF_COUNT_STRONG_ONE, Ordering::Relaxed);
153        if old_combined_count == COMBINED_REF_COUNT_BOTH_ONE {
154            if let Some(deleter) = obj.deleter {
155                fence(Ordering::Acquire);
156                deleter(
157                    obj as *mut TVMFFIObject as *mut c_void,
158                    kTVMFFIObjectDeleterFlagBitMaskBoth as i32,
159                );
160            }
161        } else if (old_combined_count & COMBINED_REF_COUNT_MASK_U32)
162            == COMBINED_REF_COUNT_STRONG_ONE
163        {
164            // slow path, there is still a weak reference left
165            // need to run two phase decrement
166            fence(Ordering::Acquire);
167            if let Some(deleter) = obj.deleter {
168                deleter(
169                    obj as *mut TVMFFIObject as *mut c_void,
170                    kTVMFFIObjectDeleterFlagBitMaskStrong as i32,
171                );
172            }
173            let old_weak_count = obj
174                .combined_ref_count
175                .fetch_sub(COMBINED_REF_COUNT_WEAK_ONE, Ordering::Release);
176            if old_weak_count == COMBINED_REF_COUNT_WEAK_ONE {
177                fence(Ordering::Acquire);
178                if let Some(deleter) = obj.deleter {
179                    deleter(
180                        obj as *mut TVMFFIObject as *mut c_void,
181                        kTVMFFIObjectDeleterFlagBitMaskWeak as i32,
182                    );
183                }
184            }
185        }
186    }
187
188    #[inline]
189    pub unsafe fn strong_count(handle: *mut TVMFFIObject) -> usize {
190        let obj = &mut *handle;
191        (obj.combined_ref_count.load(Ordering::Relaxed) & COMBINED_REF_COUNT_MASK_U32) as usize
192    }
193
194    #[inline]
195    pub unsafe fn weak_count(handle: *mut TVMFFIObject) -> usize {
196        let obj = &mut *handle;
197        (obj.combined_ref_count.load(Ordering::Relaxed) >> 32) as usize
198    }
199
200    /// Generic object deleter that works for object allocated from Box then into_raw
201    pub unsafe extern "C" fn object_deleter_for_new<T>(ptr: *mut c_void, flags: i32)
202    where
203        T: super::ObjectCore,
204    {
205        let obj = ptr as *mut T;
206        if flags & kTVMFFIObjectDeleterFlagBitMaskStrong as i32 != 0 {
207            // calling destructor of the object, does not free the memory
208            std::ptr::drop_in_place(obj);
209        }
210        if flags & kTVMFFIObjectDeleterFlagBitMaskWeak as i32 != 0 {
211            // free the memory
212            std::alloc::dealloc(ptr as *mut u8, std::alloc::Layout::new::<T>());
213        }
214    }
215
216    pub unsafe extern "C" fn object_deleter_for_new_with_extra_items<T, U>(
217        ptr: *mut c_void,
218        flags: i32,
219    ) where
220        T: super::ObjectCoreWithExtraItems<ExtraItem = U>,
221    {
222        let obj = ptr as *mut T;
223        if flags == kTVMFFIObjectDeleterFlagBitMaskBoth as i32 {
224            // must get extra items count before dropping the object
225            let extra_items_count = T::extra_items_count(&(*obj));
226            std::ptr::drop_in_place(obj);
227            let layout = std::alloc::Layout::from_size_align(
228                std::mem::size_of::<T>() + extra_items_count * std::mem::size_of::<U>(),
229                std::mem::align_of::<T>(),
230            )
231            .unwrap();
232            // free the memory
233            std::alloc::dealloc(ptr as *mut u8, layout);
234        } else {
235            assert_eq!(std::mem::size_of::<T>() % std::mem::size_of::<u64>(), 0);
236            if flags & kTVMFFIObjectDeleterFlagBitMaskStrong as i32 != 0 {
237                // must get extra items count before dropping the object
238                let extra_items_count = T::extra_items_count(&(*obj));
239                // calling destructor of the object, does not free the memory
240                std::ptr::drop_in_place(obj);
241                // record extra count in the original memory
242                std::ptr::write(obj as *mut u64, extra_items_count as u64);
243            }
244            if flags & kTVMFFIObjectDeleterFlagBitMaskWeak as i32 != 0 {
245                // read extra items count from the original memory
246                // note we can no longer read it by calling T::extra_items_count(&(*obj))
247                // because the object is already dropped
248                let extra_items_count = std::ptr::read(obj as *mut u64) as usize;
249                let layout = std::alloc::Layout::from_size_align(
250                    std::mem::size_of::<T>() + extra_items_count * std::mem::size_of::<U>(),
251                    std::mem::align_of::<T>(),
252                )
253                .unwrap();
254                // free the memory
255                std::alloc::dealloc(ptr as *mut u8, layout);
256            }
257        }
258    }
259}
260
261//---------------------
262// Object
263//---------------------
264
265impl Object {
266    pub fn new() -> Self {
267        Self {
268            header: TVMFFIObject::new(),
269        }
270    }
271}
272
273unsafe impl ObjectCore for Object {
274    const TYPE_KEY: &'static str = "ffi.Object";
275    #[inline]
276    fn type_index() -> i32 {
277        TypeIndex::kTVMFFIStaticObjectBegin as i32
278    }
279    #[inline]
280    unsafe fn object_header_mut(this: &mut Self) -> &mut TVMFFIObject {
281        &mut this.header
282    }
283}
284
285//---------------------
286// ObjectArc
287//---------------------
288impl<T: ObjectCore> ObjectArc<T> {
289    pub fn new(data: T) -> Self {
290        unsafe {
291            let layout = std::alloc::Layout::new::<T>();
292            let raw_data_ptr = std::alloc::alloc(layout);
293            if raw_data_ptr.is_null() {
294                std::alloc::handle_alloc_error(layout);
295            }
296            let ptr = raw_data_ptr as *mut T;
297            std::ptr::write(ptr, data);
298            // now override the header directly
299            std::ptr::write(
300                ptr as *mut TVMFFIObject,
301                TVMFFIObject {
302                    combined_ref_count: AtomicU64::new(COMBINED_REF_COUNT_BOTH_ONE),
303                    type_index: T::type_index(),
304                    __padding: 0,
305                    deleter: Some(unsafe_::object_deleter_for_new::<T>),
306                },
307            );
308            // move into the object arc ptr
309            Self {
310                ptr: std::ptr::NonNull::new_unchecked(ptr as *mut T),
311                _phantom: std::marker::PhantomData,
312            }
313        }
314    }
315    pub fn new_with_extra_items<U>(data: T) -> Self
316    where
317        T: ObjectCoreWithExtraItems<ExtraItem = U>,
318    {
319        unsafe {
320            // ensure strict alignment requirements
321            // so we can have { T, U*extra_items } layout
322            assert_eq!(std::mem::align_of::<T>() % std::mem::align_of::<U>(), 0);
323            assert_eq!(std::mem::size_of::<T>() % std::mem::align_of::<U>(), 0);
324            let extra_items_count = T::extra_items_count(&data);
325            let layout = std::alloc::Layout::from_size_align(
326                std::mem::size_of::<T>() + extra_items_count * std::mem::size_of::<U>(),
327                std::mem::align_of::<T>(),
328            )
329            .unwrap();
330            let raw_data_ptr = std::alloc::alloc(layout);
331            if raw_data_ptr.is_null() {
332                std::alloc::handle_alloc_error(layout);
333            }
334            let ptr = raw_data_ptr as *mut T;
335            std::ptr::write(ptr, data);
336            // now override the header directly
337            std::ptr::write(
338                ptr as *mut TVMFFIObject,
339                TVMFFIObject {
340                    combined_ref_count: AtomicU64::new(COMBINED_REF_COUNT_BOTH_ONE),
341                    type_index: T::type_index(),
342                    __padding: 0,
343                    deleter: Some(unsafe_::object_deleter_for_new_with_extra_items::<T, U>),
344                },
345            );
346            // move into the object arc ptr
347            Self {
348                ptr: std::ptr::NonNull::new_unchecked(ptr as *mut T),
349                _phantom: std::marker::PhantomData,
350            }
351        }
352    }
353
354    /// Move a previously allocated object into the ObjectArc
355    ///
356    /// # Arguments
357    /// * `ptr` - The raw pointer to move into the ObjectArc
358    ///
359    /// # Returns
360    /// * `ObjectArc<T>` - The ObjectArc
361    /// \return The ObjectArc
362    #[inline]
363    pub unsafe fn from_raw(ptr: *const T) -> Self {
364        Self {
365            ptr: std::ptr::NonNull::new_unchecked(ptr as *mut T),
366            _phantom: std::marker::PhantomData,
367        }
368    }
369
370    /// Move the ObjectArc into a raw pointer
371    ///
372    /// # Arguments
373    /// * `this` - The ObjectArc to move into a raw pointer
374    ///
375    /// # Returns
376    /// * `*const T` - The raw pointer
377    #[inline]
378    pub unsafe fn into_raw(this: Self) -> *const T {
379        let droped_this = std::mem::ManuallyDrop::new(this);
380        droped_this.ptr.as_ptr() as *const T
381    }
382
383    /// Get the raw pointer from the ObjectArc
384    ///
385    /// Caller should view this as a non-owning reference
386    ///
387    /// # Arguments
388    /// * `this` - The ObjectArc to get the raw pointer
389    ///
390    /// # Returns
391    /// * `*const T` - The raw pointer
392    /// \return The raw pointer
393    #[inline]
394    pub unsafe fn as_raw(this: &Self) -> *const T {
395        this.ptr.as_ptr() as *const T
396    }
397
398    /// Get the raw mutable pointer from the ObjectArc
399    ///
400    /// Caller should view this as a non-owning reference
401    ///
402    /// # Arguments
403    /// * `this` - The ObjectArc to get the raw pointer
404    ///
405    /// # Returns
406    /// * `*mut T` - The raw pointer
407    #[inline]
408    pub unsafe fn as_raw_mut(this: &mut Self) -> *mut T {
409        this.ptr.as_mut()
410    }
411
412    /// Get the strong reference count of the ObjectArc
413    ///
414    /// # Arguments
415    /// * `this` - The ObjectArc to get the strong reference count
416    ///
417    /// # Returns
418    /// * `usize` - The strong reference count
419    #[inline]
420    pub fn strong_count(this: &Self) -> usize {
421        unsafe {
422            unsafe_::strong_count(this.ptr.as_ref() as *const T as *mut T as *mut TVMFFIObject)
423        }
424    }
425
426    /// Get the weak reference count of the ObjectArc
427    ///
428    /// # Arguments
429    /// * `this` - The ObjectArc to get the weak reference count
430    ///
431    /// # Returns
432    /// * `usize` - The weak reference count
433    #[inline]
434    pub fn weak_count(this: &Self) -> usize {
435        unsafe { unsafe_::weak_count(this.ptr.as_ref() as *const T as *mut T as *mut TVMFFIObject) }
436    }
437}
438
439// implement Deref for ObjectArc
440impl<T: ObjectCore> Deref for ObjectArc<T> {
441    type Target = T;
442    #[inline]
443    fn deref(&self) -> &Self::Target {
444        unsafe { self.ptr.as_ref() }
445    }
446}
447
448// implement DerefMut for ObjectArc
449impl<T: ObjectCore> DerefMut for ObjectArc<T> {
450    #[inline]
451    fn deref_mut(&mut self) -> &mut Self::Target {
452        unsafe { self.ptr.as_mut() }
453    }
454}
455
456// implement Drop for ObjectArc
457impl<T: ObjectCore> Drop for ObjectArc<T> {
458    fn drop(&mut self) {
459        unsafe { unsafe_::dec_ref(self.ptr.as_mut() as *mut T as *mut TVMFFIObject) }
460    }
461}
462
463// implement Clone for ObjectArc
464impl<T: ObjectCore> Clone for ObjectArc<T> {
465    #[inline]
466    fn clone(&self) -> Self {
467        unsafe { unsafe_::inc_ref(self.ptr.as_ref() as *const T as *mut T as *mut TVMFFIObject) }
468        Self {
469            ptr: self.ptr,
470            _phantom: std::marker::PhantomData,
471        }
472    }
473}