tvm_ffi/collections/
shape.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19use crate::derive::{Object, ObjectRef};
20use crate::object::{Object, ObjectArc, ObjectCoreWithExtraItems};
21use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
22use std::ops::Deref;
23use tvm_ffi_sys::TVMFFIShapeCell;
24use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;
25
26//-----------------------------------------------------
27// Shape
28//-----------------------------------------------------
/// Heap-allocated object backing a [`Shape`].
///
/// Registered with type key `"ffi.Shape"` and type index `kTVMFFIShape`.
/// The dimension values are not stored inline in this struct: the
/// `From<T> for Shape` conversion allocates them as trailing `i64` extra
/// items of the same object and points `data.data` at them. An empty shape
/// built by `Shape::new` instead has a null `data.data` and `data.size == 0`.
///
/// NOTE(review): `#[repr(C)]` suggests this layout mirrors the C-side shape
/// object, so the fields should not be reordered — confirm against the C ABI.
#[repr(C)]
#[derive(Object)]
#[type_key = "ffi.Shape"]
#[type_index(TypeIndex::kTVMFFIShape)]
pub struct ShapeObj {
    // Common object header; presumably must be the first field for FFI — TODO confirm.
    object: Object,
    // Pointer to the dimension values plus the number of dimensions.
    data: TVMFFIShapeCell,
}
38
/// ABI stable owned Shape for ffi.
///
/// A handle to a shared [`ShapeObj`]. Cloning clones the `ObjectArc` handle
/// (presumably a reference-count bump rather than a copy of the dimensions —
/// confirm against `ObjectArc`). Equality and ordering compare the dimension
/// values, not object identity.
#[repr(C)]
#[derive(ObjectRef, Clone)]
pub struct Shape {
    // Owning pointer to the heap-allocated shape object.
    data: ObjectArc<ShapeObj>,
}
45
46impl Shape {
47    /// Create a new empty Shape
48    pub fn new() -> Self {
49        let shape_obj = ShapeObj {
50            object: Object::new(),
51            data: TVMFFIShapeCell {
52                data: std::ptr::null(),
53                size: 0,
54            },
55        };
56        Self {
57            data: ObjectArc::new(shape_obj),
58        }
59    }
60
61    /// Get the shape as a slice
62    pub fn as_slice(&self) -> &[i64] {
63        unsafe { std::slice::from_raw_parts(self.data.data.data, self.data.data.size) }
64    }
65
66    /// Fill the strides from the shape
67    pub fn fill_strides_from_shape<T>(shape: T, strides: &mut [i64])
68    where
69        T: AsRef<[i64]>,
70    {
71        let shape = shape.as_ref();
72        let mut stride = 1;
73        for i in (0..shape.len()).rev() {
74            strides[i] = stride;
75            stride *= shape[i];
76        }
77    }
78}
79
80unsafe impl ObjectCoreWithExtraItems for ShapeObj {
81    type ExtraItem = i64;
82    fn extra_items_count(this: &Self) -> usize {
83        this.data.size
84    }
85}
86
87impl<T> From<T> for Shape
88where
89    T: AsRef<[i64]>,
90{
91    fn from(value: T) -> Self {
92        unsafe {
93            let value_slice: &[i64] = value.as_ref();
94            let mut obj_arc = ObjectArc::new_with_extra_items(ShapeObj {
95                object: Object::new(),
96                data: TVMFFIShapeCell {
97                    data: std::ptr::null(),
98                    size: value_slice.len(),
99                },
100            });
101            // reset the data ptr correctly after Arc is created
102            obj_arc.data.data = ShapeObj::extra_items(&obj_arc).as_ptr();
103            let extra_items = ShapeObj::extra_items_mut(&mut obj_arc);
104            extra_items.copy_from_slice(value_slice);
105            Self { data: obj_arc }
106        }
107    }
108}
109
110impl Deref for Shape {
111    type Target = [i64];
112    #[inline]
113    fn deref(&self) -> &[i64] {
114        self.as_slice()
115    }
116}
117
118impl PartialEq for Shape {
119    #[inline]
120    fn eq(&self, other: &Self) -> bool {
121        self.as_slice() == other.as_slice()
122    }
123}
124
125impl Eq for Shape {}
126
127impl PartialOrd for Shape {
128    #[inline]
129    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
130        self.as_slice().partial_cmp(other.as_slice())
131    }
132}
133
134impl Ord for Shape {
135    #[inline]
136    fn cmp(&self, other: &Self) -> Ordering {
137        self.as_slice().cmp(other.as_slice())
138    }
139}