tvm_ffi/collections/
array.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19use std::fmt::Debug;
20use std::marker::PhantomData;
21use std::ops::Deref;
22
23use crate::any::TryFromTemp;
24use crate::derive::Object;
25use crate::object::{Object, ObjectArc};
26use crate::{Any, AnyCompatible, AnyView, ObjectCoreWithExtraItems, ObjectRefCore};
27use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;
28use tvm_ffi_sys::{TVMFFIAny, TVMFFIObject};
29
/// Object record backing an FFI array of `TVMFFIAny` elements.
///
/// Registered under the type key `ffi.Array` with the type index
/// `TypeIndex::kTVMFFIArray`. The layout is `#[repr(C)]`; presumably it has to
/// match the array object defined on the C/C++ side of the FFI — confirm
/// before reordering or adding fields.
#[repr(C)]
#[derive(Object)]
#[type_key = "ffi.Array"]
#[type_index(TypeIndex::kTVMFFIArray)]
pub struct ArrayObj {
    /// Embedded base `Object`, placed first in the `repr(C)` layout.
    pub object: Object,
    /// Pointer to the start of the element buffer (AddressOf(0)).
    pub data: *mut core::ffi::c_void,
    /// Number of elements stored in the buffer; this is what `Array::len` reports.
    pub size: i64,
    /// Number of element slots recorded for the buffer.
    pub capacity: i64,
    /// Optional custom deleter for the data pointer.
    pub data_deleter: Option<unsafe extern "C" fn(*mut core::ffi::c_void)>,
}
43
44unsafe impl ObjectCoreWithExtraItems for ArrayObj {
45    type ExtraItem = TVMFFIAny;
46    fn extra_items_count(this: &Self) -> usize {
47        this.size as usize
48    }
49}
50
/// Typed handle to a shared [`ArrayObj`].
///
/// `T` is the element type that stored values are cast to when they are read
/// back (see `Array::get`). Elements themselves are kept as raw `TVMFFIAny`
/// values inside the object, so `T` is only carried as a marker.
#[repr(C)]
#[derive(Clone)]
pub struct Array<T: AnyCompatible + Clone> {
    /// Handle to the underlying array object.
    data: ObjectArc<ArrayObj>,
    /// Records the element type without storing a `T`.
    _marker: PhantomData<T>,
}
57
58impl<T: AnyCompatible + Clone> Debug for Array<T> {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        let full_name = std::any::type_name::<T>();
61        let short_name = full_name.split("::").last().unwrap_or(full_name);
62        write!(f, "Array<{}>[{}]", short_name, self.len())
63    }
64}
65
66impl<T: AnyCompatible + Clone> Default for Array<T> {
67    fn default() -> Self {
68        Self::new(vec![])
69    }
70}
71
72unsafe impl<T: AnyCompatible + Clone> ObjectRefCore for Array<T> {
73    type ContainerType = ArrayObj;
74
75    fn data(this: &Self) -> &ObjectArc<Self::ContainerType> {
76        &this.data
77    }
78
79    fn into_data(this: Self) -> ObjectArc<Self::ContainerType> {
80        this.data
81    }
82
83    fn from_data(data: ObjectArc<Self::ContainerType>) -> Self {
84        Self {
85            data,
86            _marker: PhantomData,
87        }
88    }
89}
90
91impl<T: AnyCompatible + Clone> Array<T> {
92    /// Creates a new Array from a vector of items.
93    pub fn new(items: Vec<T>) -> Self {
94        let capacity = items.len();
95        Self::new_with_capacity(items, capacity)
96    }
97
98    /// Internal helper to allocate an ArrayObj with specific headroom.
99    fn new_with_capacity(items: Vec<T>, capacity: usize) -> Self {
100        let size = items.len();
101
102        // Allocate with capacity
103        let arc = ObjectArc::<ArrayObj>::new_with_extra_items(ArrayObj {
104            object: Object::new(),
105            data: core::ptr::null_mut(),
106            size: size as i64,
107            capacity: capacity as i64,
108            data_deleter: None,
109        });
110
111        unsafe {
112            let raw_ptr = ObjectArc::as_raw(&arc) as *mut ArrayObj;
113            let container = &mut *raw_ptr;
114
115            let base_ptr = ArrayObj::extra_items_mut(container).as_ptr() as *mut TVMFFIAny;
116            container.data = base_ptr as *mut _;
117
118            for (i, item) in items.into_iter().enumerate() {
119                let any: Any = Any::from(item);
120                let raw = Any::into_raw_ffi_any(any);
121                core::ptr::write(base_ptr.add(i), raw);
122            }
123        }
124        Self::from_data(arc)
125    }
126
127    pub fn len(&self) -> usize {
128        self.data.size as usize
129    }
130
131    pub fn is_empty(&self) -> bool {
132        self.len() == 0
133    }
134
135    /// Retrieves an item at the given index.
136    pub fn get(&self, index: usize) -> Result<T, crate::Error> {
137        if index >= self.len() {
138            crate::bail!(crate::error::INDEX_ERROR, "Array get index out of bound");
139        }
140        unsafe {
141            let container = self.data.deref();
142            let base_ptr = container.data as *const TVMFFIAny;
143            let raw_any_ref = &*base_ptr.add(index);
144
145            match T::try_cast_from_any_view(raw_any_ref) {
146                Ok(val) => Ok(val),
147                Err(_) => crate::bail!(
148                    crate::error::TYPE_ERROR,
149                    "Failed to cast element at {} to {}",
150                    index,
151                    T::type_str()
152                ),
153            }
154        }
155    }
156
157    pub fn iter(&'_ self) -> ArrayIterator<'_, T> {
158        ArrayIterator {
159            array: self,
160            index: 0,
161            len: self.len(),
162        }
163    }
164
165    #[inline]
166    fn as_container(&self) -> &ArrayObj {
167        unsafe {
168            let ptr = ObjectArc::as_raw(&self.data) as *const ArrayObj;
169            &*ptr
170        }
171    }
172}
173
174// --- Index Implementation ---
175
176impl<T: AnyCompatible + Clone> std::ops::Index<usize> for Array<T> {
177    type Output = AnyView<'static>;
178
179    fn index(&self, index: usize) -> &Self::Output {
180        let container = self.as_container();
181        let len = container.size as usize;
182        if index >= len {
183            panic!(
184                "Index out of bounds: the len is {} but the index is {}",
185                len, index
186            );
187        }
188        unsafe {
189            let ptr = (container.data as *const AnyView<'static>).add(index);
190            &*ptr
191        }
192    }
193}
194
195// --- Iterator Implementations ---
196
/// Iterator over an [`Array`] that yields each element by value.
///
/// Every step calls `Array::get` with the current index, so items are
/// produced by casting the stored value to `T`.
pub struct ArrayIterator<'a, T: AnyCompatible + Clone> {
    /// Array being traversed.
    array: &'a Array<T>,
    /// Index of the next element to fetch.
    index: usize,
    /// Array length captured when the iterator was created.
    len: usize,
}
202
203impl<'a, T: AnyCompatible + Clone> Iterator for ArrayIterator<'a, T> {
204    type Item = T;
205
206    fn next(&mut self) -> Option<Self::Item> {
207        if self.index < self.len {
208            let item = self.array.get(self.index).ok();
209            self.index += 1;
210            item
211        } else {
212            None
213        }
214    }
215}
216
217impl<'a, T: AnyCompatible + Clone> IntoIterator for &'a Array<T> {
218    type Item = T;
219    type IntoIter = ArrayIterator<'a, T>;
220
221    fn into_iter(self) -> Self::IntoIter {
222        self.iter()
223    }
224}
225
226impl<T: AnyCompatible + Clone> FromIterator<T> for Array<T> {
227    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
228        let items: Vec<T> = iter.into_iter().collect();
229        Self::new(items)
230    }
231}
232
233// --- Any Type System Conversions ---
234
235unsafe impl<T> AnyCompatible for Array<T>
236where
237    T: AnyCompatible + Clone + 'static,
238{
239    fn type_str() -> String {
240        format!("Array<{}>", T::type_str())
241    }
242
243    unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
244        if data.type_index != TypeIndex::kTVMFFIArray as i32 {
245            return false;
246        }
247
248        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<Any>() {
249            return true;
250        }
251
252        let container = &*(data.data_union.v_obj as *const ArrayObj);
253        let base_ptr = container.data as *const TVMFFIAny;
254        for i in 0..container.size {
255            let elem_any = &*base_ptr.add(i as usize);
256            if !T::check_any_strict(elem_any) {
257                return false;
258            }
259        }
260        true
261    }
262
263    unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) {
264        data.type_index = TypeIndex::kTVMFFIArray as i32;
265        data.data_union.v_obj = ObjectArc::as_raw(Self::data(src)) as *mut TVMFFIObject;
266        data.small_str_len = 0;
267    }
268
269    unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
270        data.type_index = TypeIndex::kTVMFFIArray as i32;
271        data.data_union.v_obj = ObjectArc::into_raw(Self::into_data(src)) as *mut TVMFFIObject;
272        data.small_str_len = 0;
273    }
274
275    unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
276        let ptr = data.data_union.v_obj as *const ArrayObj;
277        crate::object::unsafe_::inc_ref(ptr as *mut TVMFFIObject);
278        Self::from_data(ObjectArc::from_raw(ptr))
279    }
280
281    unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
282        let ptr = data.data_union.v_obj as *const ArrayObj;
283        let obj = Self::from_data(ObjectArc::from_raw(ptr));
284
285        data.type_index = TypeIndex::kTVMFFINone as i32;
286        data.data_union.v_int64 = 0;
287
288        obj
289    }
290
291    unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
292        if data.type_index != TypeIndex::kTVMFFIArray as i32 {
293            return Err(());
294        }
295
296        // Fast path: if types match exactly, we can just copy the reference.
297        if Self::check_any_strict(data) {
298            return Ok(Self::copy_from_any_view_after_check(data));
299        }
300
301        // Slow path: try to convert element by element.
302        let container = &*(data.data_union.v_obj as *const ArrayObj);
303        let base_ptr = container.data as *const TVMFFIAny;
304        let mut items = Vec::with_capacity(container.size as usize);
305
306        for i in 0..container.size {
307            let any_v = &*base_ptr.add(i as usize);
308            if let Ok(item) = T::try_cast_from_any_view(any_v) {
309                items.push(item);
310            } else {
311                return Err(());
312            }
313        }
314
315        Ok(Array::new(items))
316    }
317}
318
319impl<T> TryFrom<Any> for Array<T>
320where
321    T: AnyCompatible + Clone + 'static,
322{
323    type Error = crate::error::Error;
324
325    fn try_from(value: Any) -> Result<Self, Self::Error> {
326        let temp: TryFromTemp<Self> = TryFromTemp::try_from(value)?;
327        Ok(TryFromTemp::into_value(temp))
328    }
329}
330
331impl<'a, T> TryFrom<AnyView<'a>> for Array<T>
332where
333    T: AnyCompatible + Clone + 'static,
334{
335    type Error = crate::error::Error;
336
337    fn try_from(value: AnyView<'a>) -> Result<Self, Self::Error> {
338        let temp: TryFromTemp<Self> = TryFromTemp::try_from(value)?;
339        Ok(TryFromTemp::into_value(temp))
340    }
341}