tvm_ffi/collections/
shape.rs1use crate::derive::{Object, ObjectRef};
20use crate::object::{Object, ObjectArc, ObjectCoreWithExtraItems};
21use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
22use std::ops::Deref;
23use tvm_ffi_sys::TVMFFIShapeCell;
24use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;
25
26#[repr(C)]
31#[derive(Object)]
32#[type_key = "ffi.Shape"]
33#[type_index(TypeIndex::kTVMFFIShape)]
34pub struct ShapeObj {
35 object: Object,
36 data: TVMFFIShapeCell,
37}
38
39#[repr(C)]
41#[derive(ObjectRef, Clone)]
42pub struct Shape {
43 data: ObjectArc<ShapeObj>,
44}
45
46impl Shape {
47 pub fn new() -> Self {
49 let shape_obj = ShapeObj {
50 object: Object::new(),
51 data: TVMFFIShapeCell {
52 data: std::ptr::null(),
53 size: 0,
54 },
55 };
56 Self {
57 data: ObjectArc::new(shape_obj),
58 }
59 }
60
61 pub fn as_slice(&self) -> &[i64] {
63 unsafe { std::slice::from_raw_parts(self.data.data.data, self.data.data.size) }
64 }
65
66 pub fn fill_strides_from_shape<T>(shape: T, strides: &mut [i64])
68 where
69 T: AsRef<[i64]>,
70 {
71 let shape = shape.as_ref();
72 let mut stride = 1;
73 for i in (0..shape.len()).rev() {
74 strides[i] = stride;
75 stride *= shape[i];
76 }
77 }
78}
79
80unsafe impl ObjectCoreWithExtraItems for ShapeObj {
81 type ExtraItem = i64;
82 fn extra_items_count(this: &Self) -> usize {
83 this.data.size
84 }
85}
86
87impl<T> From<T> for Shape
88where
89 T: AsRef<[i64]>,
90{
91 fn from(value: T) -> Self {
92 unsafe {
93 let value_slice: &[i64] = value.as_ref();
94 let mut obj_arc = ObjectArc::new_with_extra_items(ShapeObj {
95 object: Object::new(),
96 data: TVMFFIShapeCell {
97 data: std::ptr::null(),
98 size: value_slice.len(),
99 },
100 });
101 obj_arc.data.data = ShapeObj::extra_items(&obj_arc).as_ptr();
103 let extra_items = ShapeObj::extra_items_mut(&mut obj_arc);
104 extra_items.copy_from_slice(value_slice);
105 Self { data: obj_arc }
106 }
107 }
108}
109
110impl Deref for Shape {
111 type Target = [i64];
112 #[inline]
113 fn deref(&self) -> &[i64] {
114 self.as_slice()
115 }
116}
117
118impl PartialEq for Shape {
119 #[inline]
120 fn eq(&self, other: &Self) -> bool {
121 self.as_slice() == other.as_slice()
122 }
123}
124
125impl Eq for Shape {}
126
127impl PartialOrd for Shape {
128 #[inline]
129 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
130 self.as_slice().partial_cmp(other.as_slice())
131 }
132}
133
134impl Ord for Shape {
135 #[inline]
136 fn cmp(&self, other: &Self) -> Ordering {
137 self.as_slice().cmp(other.as_slice())
138 }
139}