tvm_ffi/
string.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19use crate::derive::Object;
20use crate::object::{unsafe_, Object, ObjectArc, ObjectCoreWithExtraItems};
21use crate::type_traits::AnyCompatible;
22use std::cmp::Ordering;
23use std::fmt::{Debug, Display};
24use std::hash::{Hash, Hasher};
25use std::ops::Deref;
26use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;
27use tvm_ffi_sys::{TVMFFIAny, TVMFFIAnyDataUnion, TVMFFIByteArray, TVMFFIObject};
28
29//-----------------------------------------------------
30// Bytes
31//-----------------------------------------------------
32/// ABI stable Bytes container for ffi
33#[repr(C)]
34pub struct Bytes {
35    data: TVMFFIAny,
36}
37
38// BytesObj for heap-allocated bytes
39#[repr(C)]
40#[derive(Object)]
41#[type_key = "ffi.Bytes"]
42#[type_index(TypeIndex::kTVMFFIBytes)]
43pub(crate) struct BytesObj {
44    object: Object,
45    data: TVMFFIByteArray,
46}
47
48impl Bytes {
49    /// Create a new empty Bytes container
50    pub fn new() -> Self {
51        Self {
52            data: TVMFFIAny {
53                type_index: TypeIndex::kTVMFFISmallBytes as i32,
54                small_str_len: 0,
55                data_union: TVMFFIAnyDataUnion { v_int64: 0 },
56            },
57        }
58    }
59
60    /// Get the length of the bytes
61    pub fn len(&self) -> usize {
62        self.as_slice().len()
63    }
64
65    /// Get the bytes as a slice
66    pub fn as_slice(&self) -> &[u8] {
67        unsafe {
68            if self.data.type_index == TypeIndex::kTVMFFISmallBytes as i32 {
69                std::slice::from_raw_parts(
70                    self.data.data_union.v_bytes.as_ptr(),
71                    self.data.small_str_len as usize,
72                )
73            } else {
74                let str_obj: &BytesObj = &*(self.data.data_union.v_obj as *const BytesObj);
75                std::slice::from_raw_parts(str_obj.data.data, str_obj.data.size)
76            }
77        }
78    }
79}
80
81impl Default for Bytes {
82    #[inline]
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88unsafe impl ObjectCoreWithExtraItems for BytesObj {
89    type ExtraItem = u8;
90    // extra item is the trailing \0 for ffi compatibility
91    #[inline]
92    /// Get the count of extra items (trailing null byte for FFI compatibility)
93    fn extra_items_count(this: &Self) -> usize {
94        return this.data.size + 1;
95    }
96}
97
98impl<T> From<T> for Bytes
99where
100    T: AsRef<[u8]>,
101{
102    #[inline]
103    /// Create Bytes from any type that can be converted to a byte slice
104    fn from(src: T) -> Self {
105        let value: &[u8] = src.as_ref();
106        // to be compatible with normal c++
107        const MAX_SMALL_BYTES_LEN: usize = 7;
108        unsafe {
109            if value.len() <= MAX_SMALL_BYTES_LEN {
110                let mut data_union = TVMFFIAnyDataUnion { v_int64: 0 };
111                data_union.v_bytes[..value.len()].copy_from_slice(value);
112                // small bytes
113                Self {
114                    data: TVMFFIAny {
115                        type_index: TypeIndex::kTVMFFISmallBytes as i32,
116                        small_str_len: value.len() as u32,
117                        data_union: data_union,
118                    },
119                }
120            } else {
121                // large bytes
122                let mut obj_arc = ObjectArc::new_with_extra_items(BytesObj {
123                    object: Object::new(),
124                    data: TVMFFIByteArray {
125                        data: std::ptr::null(),
126                        size: value.len(),
127                    },
128                });
129                // reset the data ptr correctly after Arc is created
130                obj_arc.data.data = BytesObj::extra_items(&obj_arc).as_ptr();
131                let extra_items = BytesObj::extra_items_mut(&mut obj_arc);
132                extra_items[..value.len()].copy_from_slice(value);
133                // write the trailing \0 for ffi compatibility
134                extra_items[value.len()] = 0;
135                Self {
136                    data: TVMFFIAny {
137                        type_index: TypeIndex::kTVMFFIBytes as i32,
138                        small_str_len: 0,
139                        data_union: TVMFFIAnyDataUnion {
140                            v_obj: ObjectArc::into_raw(obj_arc) as *mut BytesObj
141                                as *mut TVMFFIObject,
142                        },
143                    },
144                }
145            }
146        }
147    }
148}
149
150impl Deref for Bytes {
151    type Target = [u8];
152    #[inline]
153    fn deref(&self) -> &[u8] {
154        self.as_slice()
155    }
156}
157
158impl Clone for Bytes {
159    #[inline]
160    fn clone(&self) -> Self {
161        if self.data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin as i32 {
162            unsafe { unsafe_::inc_ref(self.data.data_union.v_obj) }
163        }
164        Self { data: self.data }
165    }
166}
167
168impl Drop for Bytes {
169    #[inline]
170    fn drop(&mut self) {
171        if self.data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin as i32 {
172            unsafe { unsafe_::dec_ref(self.data.data_union.v_obj) }
173        }
174    }
175}
176
177impl PartialEq for Bytes {
178    #[inline]
179    fn eq(&self, other: &Self) -> bool {
180        self.as_slice() == other.as_slice()
181    }
182}
183
184impl Eq for Bytes {}
185
186impl PartialOrd for Bytes {
187    #[inline]
188    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
189        self.as_slice().partial_cmp(other.as_slice())
190    }
191}
192
193impl Ord for Bytes {
194    #[inline]
195    fn cmp(&self, other: &Self) -> Ordering {
196        self.as_slice().cmp(other.as_slice())
197    }
198}
199
200impl Hash for Bytes {
201    #[inline]
202    fn hash<H: Hasher>(&self, state: &mut H) {
203        self.as_slice().hash(state);
204    }
205}
206
207impl Display for Bytes {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        Debug::fmt(&self.as_slice(), f)
210    }
211}
212
213impl Debug for Bytes {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        f.debug_struct("ffi.Bytes")
216            .field("data", &self.as_slice())
217            .finish()
218    }
219}
220
221//-----------------------------------------------------
222// String
223//-----------------------------------------------------
224
225/// ABI stable String container for ffi
226#[repr(C)]
227pub struct String {
228    data: TVMFFIAny,
229}
230
231// StringObj for heap-allocated string
232#[repr(C)]
233#[derive(Object)]
234#[type_key = "ffi.String"]
235#[type_index(TypeIndex::kTVMFFIStr)]
236pub(crate) struct StringObj {
237    object: Object,
238    data: TVMFFIByteArray,
239}
240
241unsafe impl ObjectCoreWithExtraItems for StringObj {
242    type ExtraItem = u8;
243    #[inline]
244    /// Get the count of extra items (trailing null byte for FFI compatibility)
245    fn extra_items_count(this: &Self) -> usize {
246        // extra item is the trailing \0 for ffi compatibility
247        return this.data.size + 1;
248    }
249}
250
251impl String {
252    /// Create a new empty String container
253    pub fn new() -> Self {
254        Self {
255            data: TVMFFIAny {
256                type_index: TypeIndex::kTVMFFISmallStr as i32,
257                small_str_len: 0,
258                data_union: TVMFFIAnyDataUnion { v_int64: 0 },
259            },
260        }
261    }
262
263    /// Get the length of the string in bytes
264    pub fn len(&self) -> usize {
265        self.as_bytes().len()
266    }
267
268    /// Get the string as a byte slice
269    pub fn as_bytes(&self) -> &[u8] {
270        unsafe {
271            if self.data.type_index == TypeIndex::kTVMFFISmallStr as i32 {
272                std::slice::from_raw_parts(
273                    self.data.data_union.v_bytes.as_ptr(),
274                    self.data.small_str_len as usize,
275                )
276            } else {
277                let str_obj: &StringObj = &*(self.data.data_union.v_obj as *const StringObj);
278                std::slice::from_raw_parts(str_obj.data.data, str_obj.data.size)
279            }
280        }
281    }
282
283    /// Get the string as a str slice
284    pub fn as_str(&self) -> &str {
285        unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
286    }
287}
288
289impl<T> From<T> for String
290where
291    T: AsRef<str>,
292{
293    #[inline]
294    /// Create String from any type that can be converted to a string slice
295    fn from(src: T) -> Self {
296        unsafe {
297            let value: &str = src.as_ref();
298            let bytes = value.as_bytes();
299            const MAX_SMALL_BYTES_LEN: usize = 7;
300            if bytes.len() <= MAX_SMALL_BYTES_LEN {
301                let mut data_union = TVMFFIAnyDataUnion { v_int64: 0 };
302                data_union.v_bytes[..bytes.len()].copy_from_slice(bytes);
303                Self {
304                    data: TVMFFIAny {
305                        type_index: TypeIndex::kTVMFFISmallStr as i32,
306                        small_str_len: bytes.len() as u32,
307                        data_union: data_union,
308                    },
309                }
310            } else {
311                let mut obj_arc = ObjectArc::new_with_extra_items(StringObj {
312                    object: Object::new(),
313                    data: TVMFFIByteArray {
314                        data: std::ptr::null(),
315                        size: bytes.len(),
316                    },
317                });
318                obj_arc.data.data = StringObj::extra_items(&obj_arc).as_ptr();
319                let extra_items = StringObj::extra_items_mut(&mut obj_arc);
320                extra_items[..bytes.len()].copy_from_slice(bytes);
321                // write the trailing \0 for ffi compatibility
322                extra_items[bytes.len()] = 0;
323                Self {
324                    data: TVMFFIAny {
325                        type_index: TypeIndex::kTVMFFIStr as i32,
326                        small_str_len: 0,
327                        data_union: TVMFFIAnyDataUnion {
328                            v_obj: ObjectArc::into_raw(obj_arc) as *mut StringObj
329                                as *mut TVMFFIObject,
330                        },
331                    },
332                }
333            }
334        }
335    }
336}
337
338impl Default for String {
339    #[inline]
340    fn default() -> Self {
341        Self::new()
342    }
343}
344
345impl Deref for String {
346    type Target = str;
347    #[inline]
348    fn deref(&self) -> &str {
349        self.as_str()
350    }
351}
352
353impl Clone for String {
354    #[inline]
355    fn clone(&self) -> Self {
356        if self.data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin as i32 {
357            unsafe { unsafe_::inc_ref(self.data.data_union.v_obj) }
358        }
359        Self { data: self.data }
360    }
361}
362
363impl Drop for String {
364    #[inline]
365    fn drop(&mut self) {
366        if self.data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin as i32 {
367            unsafe { unsafe_::dec_ref(self.data.data_union.v_obj) }
368        }
369    }
370}
371
372impl PartialEq for String {
373    #[inline]
374    fn eq(&self, other: &Self) -> bool {
375        self.as_bytes() == other.as_bytes()
376    }
377}
378
379// Allows `my_string == "hello"`
380impl<T> PartialEq<T> for String
381where
382    T: AsRef<str>,
383{
384    #[inline]
385    fn eq(&self, other: &T) -> bool {
386        self.as_str() == other.as_ref()
387    }
388}
389
390impl Eq for String {}
391
392impl<T> PartialEq<T> for Bytes
393where
394    T: AsRef<[u8]>,
395{
396    #[inline]
397    fn eq(&self, other: &T) -> bool {
398        self.as_slice() == other.as_ref()
399    }
400}
401
402impl PartialOrd for String {
403    #[inline]
404    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
405        self.as_str().partial_cmp(other.as_str())
406    }
407}
408
409impl Ord for String {
410    #[inline]
411    fn cmp(&self, other: &Self) -> Ordering {
412        self.as_str().cmp(other.as_str())
413    }
414}
415
416impl Hash for String {
417    #[inline]
418    fn hash<H: Hasher>(&self, state: &mut H) {
419        self.as_str().hash(state);
420    }
421}
422
423impl Display for String {
424    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425        Debug::fmt(&self.as_str(), f)
426    }
427}
428
429impl Debug for String {
430    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431        f.debug_struct("ffi.String")
432            .field("data", &self.as_str())
433            .finish()
434    }
435}
436
437//-----------------------------------------------------
438// AnyCompatible implementation for Bytes and String
439//-----------------------------------------------------
440unsafe impl AnyCompatible for Bytes {
441    fn type_str() -> std::string::String {
442        "ffi.Bytes".to_string()
443    }
444
445    unsafe fn copy_to_any_view(this: &Self, data: &mut TVMFFIAny) {
446        *data = this.data;
447    }
448
449    unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
450        *data = src.data;
451        std::mem::forget(src);
452    }
453
454    unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
455        return data.type_index == TypeIndex::kTVMFFISmallBytes as i32
456            || data.type_index == TypeIndex::kTVMFFIBytes as i32;
457    }
458
459    unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
460        if data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin as i32 {
461            unsafe { unsafe_::inc_ref(data.data_union.v_obj) }
462        }
463        Self { data: *data }
464    }
465
466    unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
467        Self { data: *data }
468    }
469
470    unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
471        if data.type_index == TypeIndex::kTVMFFIByteArrayPtr as i32 {
472            // deep copy from bytearray ptr
473            let bytes = &*(data.data_union.v_ptr as *const TVMFFIByteArray);
474            Ok(Self::from(std::slice::from_raw_parts(
475                bytes.data, bytes.size,
476            )))
477        } else if data.type_index == TypeIndex::kTVMFFISmallBytes as i32 {
478            Ok(Self { data: *data })
479        } else if data.type_index == TypeIndex::kTVMFFIBytes as i32 {
480            unsafe { unsafe_::inc_ref(data.data_union.v_obj) }
481            Ok(Self { data: *data })
482        } else {
483            Err(())
484        }
485    }
486}
487
488unsafe impl AnyCompatible for String {
489    fn type_str() -> std::string::String {
490        "ffi.String".to_string()
491    }
492
493    unsafe fn copy_to_any_view(this: &Self, data: &mut TVMFFIAny) {
494        *data = this.data;
495    }
496
497    unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
498        *data = src.data;
499        std::mem::forget(src);
500    }
501
502    unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
503        return data.type_index == TypeIndex::kTVMFFISmallStr as i32
504            || data.type_index == TypeIndex::kTVMFFIStr as i32;
505    }
506
507    unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
508        if data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin as i32 {
509            unsafe { unsafe_::inc_ref(data.data_union.v_obj) }
510        }
511        Self { data: *data }
512    }
513
514    unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
515        Self { data: *data }
516    }
517
518    unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
519        if data.type_index == TypeIndex::kTVMFFIRawStr as i32 {
520            // 1. Create a CStr wrapper from the raw pointer.
521            let c_str =
522                std::ffi::CStr::from_ptr(data.data_union.v_c_str as *const std::os::raw::c_char);
523            Ok(Self::from(c_str.to_str().expect("Invalid UTF-8")))
524        } else if data.type_index == TypeIndex::kTVMFFISmallStr as i32 {
525            Ok(Self { data: *data })
526        } else if data.type_index == TypeIndex::kTVMFFIStr as i32 {
527            unsafe { unsafe_::inc_ref(data.data_union.v_obj) }
528            Ok(Self { data: *data })
529        } else {
530            Err(())
531        }
532    }
533}