tvm_ffi/collections/tensor.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
use crate::collections::shape::Shape;
use crate::derive::{Object, ObjectRef};
use crate::dtype::AsDLDataType;
use crate::dtype::DLDataTypeExt;
use crate::error::Result;
use crate::object::{Object, ObjectArc, ObjectCore, ObjectCoreWithExtraItems};
use tvm_ffi_sys::dlpack::{DLDataType, DLDevice, DLDeviceType, DLTensor};
use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;

//-----------------------------------------------------
// NDAllocator Trait
//-----------------------------------------------------
/// Trait for n-dimensional array allocators
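///
/// # Safety
///
/// Implementations must uphold the allocator contract assumed by [`Tensor`]:
/// `alloc_data` must return memory that is valid and writable for the tensor
/// described by `prototype` (at least `numel() * item_size()` bytes, aligned
/// to at least `MIN_ALIGN`), and `free_data` must only be passed tensors whose
/// `data` pointer was produced by the matching `alloc_data`.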
pub unsafe trait NDAllocator: 'static {
    /// The minimum alignment of the data allocated by the allocator
    const MIN_ALIGN: usize;
    /// Allocate data for the given DLTensor
    ///
    /// # Arguments
    /// * `prototype` - The DLTensor describing the shape, dtype, and device
    ///   of the requested allocation
    ///
    /// # Returns
    /// A pointer to the newly allocated data, suitable for use as the data
    /// pointer of the DLTensor.
    unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void;

    /// Free data for the given DLTensor
    ///
    /// # Arguments
    /// * `tensor` - The DLTensor whose data pointer should be freed
    ///
    /// This method should free the data pointer of the DLTensor.
    unsafe fn free_data(&mut self, tensor: &DLTensor);
}

/// Extension methods for `DLTensor`
///
/// This trait provides methods to get the number of elements and the item
/// size (in bytes) of a DLTensor.
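///
/// For example, a `float32` element (`bits = 32`, `lanes = 1`) has an
/// `item_size` of `(32 * 1 + 7) / 8 == 4` bytes, while a 1-bit element
/// (`bits = 1`, `lanes = 1`) still occupies one full byte.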
pub trait DLTensorExt {
    /// Total number of elements: the product of all shape extents
    fn numel(&self) -> usize;
    /// Size in bytes of a single element, rounded up to whole bytes
    fn item_size(&self) -> usize;
}

impl DLTensorExt for DLTensor {
    fn numel(&self) -> usize {
        unsafe {
            std::slice::from_raw_parts(self.shape, self.ndim as usize)
                .iter()
                .product::<i64>() as usize
        }
    }

    fn item_size(&self) -> usize {
        // Round up to whole bytes so sub-byte dtypes still occupy storage.
        (self.dtype.bits as usize * self.dtype.lanes as usize + 7) / 8
    }
}

//-----------------------------------------------------
// Tensor
//-----------------------------------------------------
// TensorObj for heap-allocated tensor
#[repr(C)]
#[derive(Object)]
#[type_key = "ffi.Tensor"]
#[type_index(TypeIndex::kTVMFFITensor)]
pub struct TensorObj {
    object: Object,
    dltensor: DLTensor,
}

/// ABI stable owned Tensor for FFI
#[repr(C)]
#[derive(ObjectRef, Clone)]
pub struct Tensor {
    data: ObjectArc<TensorObj>,
}

impl Tensor {
    /// Get the data pointer of the Tensor
    ///
    /// # Returns
    /// * `*const core::ffi::c_void` - The data pointer of the Tensor
    pub fn data_ptr(&self) -> *const core::ffi::c_void {
        self.data.dltensor.data
    }
    /// Get the mutable data pointer of the Tensor
    ///
    /// # Returns
    /// * `*mut core::ffi::c_void` - The mutable data pointer of the Tensor
    pub fn data_ptr_mut(&mut self) -> *mut core::ffi::c_void {
        self.data.dltensor.data
    }
    /// Check if the Tensor is contiguous (compact row-major)
    ///
    /// # Returns
    /// * `bool` - True if the Tensor is contiguous, false otherwise
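    ///
    /// # Examples
    ///
    /// Illustrative only (assumes `f32: AsDLDataType` in this crate):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[0f32; 6], &[2, 3])?;
    /// // Freshly allocated tensors are compact row-major: strides == [3, 1].
    /// assert!(t.is_contiguous());
    /// ```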
    pub fn is_contiguous(&self) -> bool {
        let strides = self.strides();
        let shape = self.shape();
        // Walk dimensions from innermost to outermost; each stride must equal
        // the product of the extents of the dimensions inside it.
        let mut expected_stride = 1;
        for i in (0..self.ndim()).rev() {
            if strides[i] != expected_stride {
                return false;
            }
            expected_stride *= shape[i];
        }
        true
    }

    /// Get the data as a slice
    ///
    /// Fails if the requested element type does not match the tensor's dtype,
    /// if the tensor is not on CPU, or if it is not contiguous.
    ///
    /// # Returns
    /// * `Result<&[T]>` - The data as a slice
    pub fn data_as_slice<T: AsDLDataType>(&self) -> Result<&[T]> {
        let dtype = T::DL_DATA_TYPE;
        if self.dtype() != dtype {
            crate::bail!(
                crate::error::TYPE_ERROR,
                "Data type mismatch: expected {}, got {}",
                dtype.to_string(),
                self.dtype().to_string()
            );
        }
        if self.device().device_type != DLDeviceType::kDLCPU {
            crate::bail!(crate::error::RUNTIME_ERROR, "Tensor is not on CPU");
        }
        crate::ensure!(
            self.is_contiguous(),
            crate::error::RUNTIME_ERROR,
            "Tensor is not contiguous"
        );

        unsafe {
            Ok(std::slice::from_raw_parts(
                self.data.dltensor.data as *const T,
                self.numel(),
            ))
        }
    }
    /// Get the data as a mutable slice
    ///
    /// Note that this takes `&self` yet hands out a mutable slice: as in
    /// low-level deep learning frameworks, we allow mutable data access
    /// through shared copies of the Tensor, so avoiding aliased writes is
    /// the caller's responsibility.
    ///
    /// # Type parameters
    /// * `T` - The element type of the data
    ///
    /// # Returns
    /// * `Result<&mut [T]>` - The data as a mutable slice
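    ///
    /// # Examples
    ///
    /// Illustrative only (assumes `f32: AsDLDataType` in this crate):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1f32, 2.0], &[2])?;
    /// t.data_as_slice_mut::<f32>()?.fill(0.0);
    /// assert_eq!(t.data_as_slice::<f32>()?, &[0.0, 0.0]);
    /// ```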
    pub fn data_as_slice_mut<T: AsDLDataType>(&self) -> Result<&mut [T]> {
        let dtype = T::DL_DATA_TYPE;
        if self.dtype() != dtype {
            crate::bail!(
                crate::error::TYPE_ERROR,
                "Data type mismatch: expected {}, got {}",
                dtype.to_string(),
                self.dtype().to_string()
            );
        }
        if self.device().device_type != DLDeviceType::kDLCPU {
            crate::bail!(crate::error::RUNTIME_ERROR, "Tensor is not on CPU");
        }
        crate::ensure!(
            self.is_contiguous(),
            crate::error::RUNTIME_ERROR,
            "Tensor is not contiguous"
        );
        unsafe {
            Ok(std::slice::from_raw_parts_mut(
                self.data.dltensor.data as *mut T,
                self.numel(),
            ))
        }
    }

    /// The shape of the Tensor as a slice of extents
    pub fn shape(&self) -> &[i64] {
        unsafe { std::slice::from_raw_parts(self.data.dltensor.shape, self.ndim()) }
    }

    /// The number of dimensions
    pub fn ndim(&self) -> usize {
        self.data.dltensor.ndim as usize
    }

    /// The total number of elements
    pub fn numel(&self) -> usize {
        self.data.dltensor.numel()
    }

    /// The strides of the Tensor, in elements
    pub fn strides(&self) -> &[i64] {
        unsafe { std::slice::from_raw_parts(self.data.dltensor.strides, self.ndim()) }
    }

    /// The data type of the Tensor
    pub fn dtype(&self) -> DLDataType {
        self.data.dltensor.dtype
    }

    /// The device on which the Tensor resides
    pub fn device(&self) -> DLDevice {
        self.data.dltensor.device
    }
}

/// Tensor object that owns its storage through an [`NDAllocator`].
/// The shape and strides live in extra items allocated inline after the object.
struct TensorObjFromNDAlloc<TNDAlloc>
where
    TNDAlloc: NDAllocator,
{
    base: TensorObj,
    alloc: TNDAlloc,
}

unsafe impl<TNDAlloc: NDAllocator> ObjectCore for TensorObjFromNDAlloc<TNDAlloc> {
    const TYPE_KEY: &'static str = TensorObj::TYPE_KEY;
    fn type_index() -> i32 {
        TensorObj::type_index()
    }
    unsafe fn object_header_mut(this: &mut Self) -> &mut tvm_ffi_sys::TVMFFIObject {
        TensorObj::object_header_mut(&mut this.base)
    }
}

unsafe impl<TNDAlloc: NDAllocator> ObjectCoreWithExtraItems for TensorObjFromNDAlloc<TNDAlloc> {
    type ExtraItem = i64;
    fn extra_items_count(this: &Self) -> usize {
        // One i64 per dimension for the shape, plus one per dimension for strides.
        (this.base.dltensor.ndim * 2) as usize
    }
}

impl<TNDAlloc: NDAllocator> Drop for TensorObjFromNDAlloc<TNDAlloc> {
    fn drop(&mut self) {
        unsafe {
            self.alloc.free_data(&self.base.dltensor);
        }
    }
}

impl Tensor {
    /// Create a Tensor from an NDAllocator
    ///
    /// # Arguments
    /// * `alloc` - The NDAllocator
    /// * `shape` - The shape of the Tensor
    /// * `dtype` - The data type of the Tensor
    /// * `device` - The device of the Tensor
    ///
    /// # Returns
    /// * `Tensor` - The created Tensor
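    ///
    /// # Examples
    ///
    /// Illustrative only (assumes `f32: AsDLDataType` in this crate):
    ///
    /// ```ignore
    /// let t = Tensor::from_nd_alloc(
    ///     CPUNDAlloc {},
    ///     &[2, 3],
    ///     f32::DL_DATA_TYPE,
    ///     DLDevice::new(DLDeviceType::kDLCPU, 0),
    /// );
    /// assert_eq!(t.numel(), 6);
    /// ```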
    pub fn from_nd_alloc<TNDAlloc>(
        alloc: TNDAlloc,
        shape: &[i64],
        dtype: DLDataType,
        device: DLDevice,
    ) -> Self
    where
        TNDAlloc: NDAllocator,
    {
        let tensor_obj = TensorObjFromNDAlloc {
            base: TensorObj {
                object: Object::new(),
                dltensor: DLTensor {
                    data: std::ptr::null_mut(),
                    device,
                    ndim: shape.len() as i32,
                    dtype,
                    shape: std::ptr::null_mut(),
                    strides: std::ptr::null_mut(),
                    byte_offset: 0,
                },
            },
            alloc,
        };
        unsafe {
            let mut obj_arc = ObjectArc::new_with_extra_items(tensor_obj);
            // The extra items region holds the shape followed by the strides.
            obj_arc.base.dltensor.shape =
                TensorObjFromNDAlloc::extra_items(&obj_arc).as_ptr() as *mut i64;
            obj_arc.base.dltensor.strides = obj_arc.base.dltensor.shape.add(shape.len());
            let extra_items = TensorObjFromNDAlloc::extra_items_mut(&mut obj_arc);
            extra_items[..shape.len()].copy_from_slice(shape);
            Shape::fill_strides_from_shape(shape, &mut extra_items[shape.len()..]);
            // Take a raw pointer first so the immutable view of the prototype
            // does not overlap the mutable borrow of the allocator.
            let dltensor_ptr = &obj_arc.base.dltensor as *const DLTensor;
            obj_arc.base.dltensor.data = obj_arc.alloc.alloc_data(&*dltensor_ptr);
            Self {
                data: ObjectArc::from_raw(ObjectArc::into_raw(obj_arc) as *mut TensorObj),
            }
        }
    }
    /// Create a Tensor on CPU from a slice
    ///
    /// # Arguments
    /// * `slice` - The slice to create the Tensor from
    /// * `shape` - The shape of the Tensor
    ///
    /// # Returns
    /// * `Result<Tensor>` - The created Tensor, or an error if the slice
    ///   length does not match the shape
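    ///
    /// # Examples
    ///
    /// Illustrative only (assumes `f32: AsDLDataType` in this crate):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1f32, 2.0, 3.0, 4.0], &[2, 2])?;
    /// assert_eq!(t.shape(), &[2, 2]);
    /// assert_eq!(t.data_as_slice::<f32>()?, &[1.0, 2.0, 3.0, 4.0]);
    /// ```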
    pub fn from_slice<T: AsDLDataType>(slice: &[T], shape: &[i64]) -> Result<Self> {
        let dtype = T::DL_DATA_TYPE;
        let device = DLDevice::new(DLDeviceType::kDLCPU, 0);
        let tensor = Tensor::from_nd_alloc(CPUNDAlloc {}, shape, dtype, device);
        if tensor.numel() != slice.len() {
            crate::bail!(
                crate::error::VALUE_ERROR,
                "Slice length mismatch: shape implies {} elements, slice has {}",
                tensor.numel(),
                slice.len()
            );
        }
        tensor.data_as_slice_mut::<T>()?.copy_from_slice(slice);
        Ok(tensor)
    }
}

/// Example CPU NDAllocator
///
/// This allocator allocates data on the CPU via the global allocator,
/// aligned to `MIN_ALIGN`.
pub struct CPUNDAlloc {}

unsafe impl NDAllocator for CPUNDAlloc {
    const MIN_ALIGN: usize = 64;

    unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void {
        let size = prototype.numel() * prototype.item_size();
        let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
        let ptr = std::alloc::alloc(layout);
        ptr as *mut core::ffi::c_void
    }

    unsafe fn free_data(&mut self, tensor: &DLTensor) {
        // Reconstruct the same layout used in `alloc_data` so the
        // deallocation matches the original allocation.
        let size = tensor.numel() * tensor.item_size();
        let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
        std::alloc::dealloc(tensor.data as *mut u8, layout);
    }
}
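
// A minimal sketch of unit tests for the CPU path; assumes `f32` and `i64`
// implement `AsDLDataType` in this crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_slice_roundtrip() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&data, &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.strides(), &[3, 1]);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.data_as_slice::<f32>().unwrap(), &data);
    }

    #[test]
    fn dtype_mismatch_is_an_error() {
        let tensor = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        assert!(tensor.data_as_slice::<i64>().is_err());
    }

    #[test]
    fn length_mismatch_is_an_error() {
        assert!(Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[2, 3]).is_err());
    }
}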