tvm_ffi/collections/tensor.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
use crate::collections::shape::Shape;
use crate::derive::{Object, ObjectRef};
use crate::dtype::AsDLDataType;
use crate::dtype::DLDataTypeExt;
use crate::error::Result;
use crate::object::{Object, ObjectArc, ObjectCore, ObjectCoreWithExtraItems};
use tvm_ffi_sys::dlpack::{DLDataType, DLDevice, DLDeviceType, DLTensor};
use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;

//-----------------------------------------------------
// NDAllocator Trait
//-----------------------------------------------------
/// Trait for n-dimensional array allocators.
///
/// # Safety
/// Implementors must return pointers aligned to at least [`Self::MIN_ALIGN`]
/// and must only free pointers previously produced by `alloc_data`.
pub unsafe trait NDAllocator: 'static {
    /// The minimum alignment of the data allocated by the allocator.
    const MIN_ALIGN: usize;
    /// Allocate data for the given DLTensor.
    ///
    /// # Arguments
    /// * `prototype` - A DLTensor describing the shape, dtype, and device to
    ///   allocate for; its `data` field is not yet set.
    ///
    /// # Returns
    /// The allocated data pointer, which the caller stores into the tensor.
    unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void;

    /// Free data for the given DLTensor.
    ///
    /// # Arguments
    /// * `tensor` - The DLTensor whose `data` pointer should be freed.
    unsafe fn free_data(&mut self, tensor: &DLTensor);
}
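
// A minimal sketch of a conforming `NDAllocator` (illustrative only, not part
// of the public API): it delegates to an inner allocator and counts live
// allocations. `CPUNDAlloc` at the bottom of this file is the real CPU
// implementation shipped with this module.
#[allow(dead_code)]
struct CountingAlloc<A: NDAllocator> {
    inner: A,
    live: usize,
}

unsafe impl<A: NDAllocator> NDAllocator for CountingAlloc<A> {
    const MIN_ALIGN: usize = A::MIN_ALIGN;

    unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void {
        // Count before delegating; the inner allocator does the actual work.
        self.live += 1;
        self.inner.alloc_data(prototype)
    }

    unsafe fn free_data(&mut self, tensor: &DLTensor) {
        self.live -= 1;
        self.inner.free_data(tensor);
    }
}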

/// Extension trait for [`DLTensor`] providing the element count and the
/// per-element size in bytes.
pub trait DLTensorExt {
    /// Total number of elements (the product of all shape entries).
    fn numel(&self) -> usize;
    /// Size of one element in bytes, rounded up to a whole byte.
    fn item_size(&self) -> usize;
}

impl DLTensorExt for DLTensor {
    fn numel(&self) -> usize {
        unsafe {
            std::slice::from_raw_parts(self.shape, self.ndim as usize)
                .iter()
                .product::<i64>() as usize
        }
    }

    fn item_size(&self) -> usize {
        (self.dtype.bits as usize * self.dtype.lanes as usize + 7) / 8
    }
}
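
// Worked example of `item_size`: a float32 tensor has bits = 32 and lanes = 1,
// so (32 * 1 + 7) / 8 = 4 bytes; a 1-bit packed dtype rounds up to
// (1 * 1 + 7) / 8 = 1 byte.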

//-----------------------------------------------------
// Tensor
//-----------------------------------------------------
// TensorObj for a heap-allocated tensor
#[repr(C)]
#[derive(Object)]
#[type_key = "ffi.Tensor"]
#[type_index(TypeIndex::kTVMFFITensor)]
pub struct TensorObj {
    object: Object,
    dltensor: DLTensor,
}

/// ABI stable owned Tensor for ffi
#[repr(C)]
#[derive(ObjectRef, Clone)]
pub struct Tensor {
    data: ObjectArc<TensorObj>,
}

impl Tensor {
    /// Get the data pointer of the Tensor
    ///
    /// # Returns
    /// * `*const core::ffi::c_void` - The data pointer of the Tensor
    pub fn data_ptr(&self) -> *const core::ffi::c_void {
        self.data.dltensor.data
    }
    /// Get the mutable data pointer of the Tensor
    ///
    /// # Returns
    /// * `*mut core::ffi::c_void` - The mutable data pointer of the Tensor
    pub fn data_ptr_mut(&mut self) -> *mut core::ffi::c_void {
        self.data.dltensor.data
    }
    /// Check if the Tensor is contiguous
    ///
    /// # Returns
    /// * `bool` - True if the Tensor is laid out compactly in row-major
    ///   order, false otherwise
    pub fn is_contiguous(&self) -> bool {
        let strides = self.strides();
        let shape = self.shape();
        let mut expected_stride = 1;
        // Walk dimensions from innermost to outermost: the innermost stride
        // must be 1, and each outer stride must equal the product of all
        // inner extents.
        for i in (0..self.ndim()).rev() {
            if strides[i] != expected_stride {
                return false;
            }
            expected_stride *= shape[i];
        }
        true
    }
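
    // Worked example: a tensor of shape [2, 3, 4] is contiguous exactly when
    // its strides are [12, 4, 1]; the loop above checks 1, then 1 * 4 = 4,
    // then 4 * 3 = 12, from the innermost dimension outward.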

    /// Returns the tensor data as an immutable slice.
    ///
    /// Fails if the dtype does not match `T`, the tensor is not on CPU, or
    /// the tensor is not contiguous.
    pub fn data_as_slice<T: AsDLDataType>(&self) -> Result<&[T]> {
        let dtype = T::DL_DATA_TYPE;
        if self.dtype() != dtype {
            crate::bail!(
                crate::error::TYPE_ERROR,
                "Data type mismatch: expected {}, got {}",
                dtype.to_string(),
                self.dtype().to_string()
            );
        }
        if self.device().device_type != DLDeviceType::kDLCPU {
            crate::bail!(crate::error::RUNTIME_ERROR, "Tensor is not on CPU");
        }
        crate::ensure!(
            self.is_contiguous(),
            crate::error::RUNTIME_ERROR,
            "Tensor is not contiguous"
        );

        unsafe {
            Ok(std::slice::from_raw_parts(
                self.data.dltensor.data as *const T,
                self.numel(),
            ))
        }
    }
    /// Returns the tensor data as a mutable slice.
    ///
    /// This method takes `&self` rather than `&mut self` by design: like
    /// `std::fs::File::write`, the *metadata* of a Tensor (shape, dtype,
    /// device) is governed by Rust's ownership rules, but writing to the
    /// underlying data buffer (CPU memory or a GPU pointer) is a side-effect
    /// outside Rust's aliasing model.  Most C/CUDA kernel APIs accept a
    /// non-mut Tensor and mutate its data content, so requiring `&mut self`
    /// here would force artificial mutability annotations throughout the
    /// deep-learning stack with no real safety benefit.
    ///
    /// # Safety contract (caller responsibility)
    /// If the `Tensor` has been cloned (via `ObjectArc`), the caller must
    /// ensure no other clone is concurrently reading the data.
    #[allow(clippy::wrong_self_convention)]
    pub fn data_as_slice_mut<T: AsDLDataType>(&self) -> Result<&mut [T]> {
        let dtype = T::DL_DATA_TYPE;
        if self.dtype() != dtype {
            crate::bail!(
                crate::error::TYPE_ERROR,
                "Data type mismatch: expected {}, got {}",
                dtype.to_string(),
                self.dtype().to_string()
            );
        }
        if self.device().device_type != DLDeviceType::kDLCPU {
            crate::bail!(crate::error::RUNTIME_ERROR, "Tensor is not on CPU");
        }
        crate::ensure!(
            self.is_contiguous(),
            crate::error::RUNTIME_ERROR,
            "Tensor is not contiguous"
        );
        unsafe {
            Ok(std::slice::from_raw_parts_mut(
                self.data.dltensor.data as *mut T,
                self.numel(),
            ))
        }
    }

    /// Returns the shape as a slice of dimension extents.
    pub fn shape(&self) -> &[i64] {
        unsafe { std::slice::from_raw_parts(self.data.dltensor.shape, self.ndim()) }
    }

    /// Returns the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.data.dltensor.ndim as usize
    }

    /// Returns the total number of elements.
    pub fn numel(&self) -> usize {
        self.data.dltensor.numel()
    }

    /// Returns the strides, in elements, one per dimension.
    pub fn strides(&self) -> &[i64] {
        unsafe { std::slice::from_raw_parts(self.data.dltensor.strides, self.ndim()) }
    }

    /// Returns the data type of the elements.
    pub fn dtype(&self) -> DLDataType {
        self.data.dltensor.dtype
    }

    /// Returns the device the data resides on.
    pub fn device(&self) -> DLDevice {
        self.data.dltensor.device
    }
}
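
// Typical read/write round trip, as a sketch (assuming `f32` implements
// `AsDLDataType`, as the CPU path below requires; runs inside a function
// returning `Result`):
//
//     let t = Tensor::from_slice(&[0.0f32; 4], &[2, 2])?;
//     t.data_as_slice_mut::<f32>()?.fill(1.0);
//     assert!(t.data_as_slice::<f32>()?.iter().all(|&x| x == 1.0));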

/// A `TensorObj` bundled with the allocator that owns its data buffer, so the
/// buffer can be released when the object is dropped.
struct TensorObjFromNDAlloc<TNDAlloc>
where
    TNDAlloc: NDAllocator,
{
    base: TensorObj,
    alloc: TNDAlloc,
}

unsafe impl<TNDAlloc: NDAllocator> ObjectCore for TensorObjFromNDAlloc<TNDAlloc> {
    const TYPE_KEY: &'static str = TensorObj::TYPE_KEY;
    fn type_index() -> i32 {
        TensorObj::type_index()
    }
    unsafe fn object_header_mut(this: &mut Self) -> &mut tvm_ffi_sys::TVMFFIObject {
        TensorObj::object_header_mut(&mut this.base)
    }
}

unsafe impl<TNDAlloc: NDAllocator> ObjectCoreWithExtraItems for TensorObjFromNDAlloc<TNDAlloc> {
    type ExtraItem = i64;
    fn extra_items_count(this: &Self) -> usize {
        // ndim entries for the shape plus ndim entries for the strides.
        (this.base.dltensor.ndim * 2) as usize
    }
}
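
// The extra items live in the same heap allocation as the object itself: the
// object header and DLTensor are followed by 2 * ndim i64 values, and the
// DLTensor's `shape` and `strides` pointers point back into that tail buffer
// (see `from_nd_alloc` below).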

impl<TNDAlloc: NDAllocator> Drop for TensorObjFromNDAlloc<TNDAlloc> {
    fn drop(&mut self) {
        unsafe {
            self.alloc.free_data(&self.base.dltensor);
        }
    }
}

impl Tensor {
    /// Create a Tensor from an `NDAllocator`
    ///
    /// # Arguments
    /// * `alloc` - The NDAllocator
    /// * `shape` - The shape of the Tensor
    /// * `dtype` - The data type of the Tensor
    /// * `device` - The device of the Tensor
    ///
    /// # Returns
    /// * `Tensor` - The created Tensor
    pub fn from_nd_alloc<TNDAlloc>(
        alloc: TNDAlloc,
        shape: &[i64],
        dtype: DLDataType,
        device: DLDevice,
    ) -> Self
    where
        TNDAlloc: NDAllocator,
    {
        // Build the object with null data/shape/strides pointers first; they
        // can only be fixed up once the object has its final heap address.
        let tensor_obj = TensorObjFromNDAlloc {
            base: TensorObj {
                object: Object::new(),
                dltensor: DLTensor {
                    data: std::ptr::null_mut(),
                    device,
                    ndim: shape.len() as i32,
                    dtype,
                    shape: std::ptr::null_mut(),
                    strides: std::ptr::null_mut(),
                    byte_offset: 0,
                },
            },
            alloc,
        };
        unsafe {
            // Point shape/strides into the extra-items tail of the
            // allocation, then populate them.
            let mut obj_arc = ObjectArc::new_with_extra_items(tensor_obj);
            obj_arc.base.dltensor.shape =
                TensorObjFromNDAlloc::extra_items(&obj_arc).as_ptr() as *mut i64;
            obj_arc.base.dltensor.strides = obj_arc.base.dltensor.shape.add(shape.len());
            let extra_items = TensorObjFromNDAlloc::extra_items_mut(&mut obj_arc);
            extra_items[..shape.len()].copy_from_slice(shape);
            Shape::fill_strides_from_shape(shape, &mut extra_items[shape.len()..]);
            // Allocate the data buffer last, once the prototype is complete.
            // The raw pointer sidesteps borrowing `obj_arc` mutably and
            // immutably at the same time.
            let dltensor_ptr = &obj_arc.base.dltensor as *const DLTensor;
            obj_arc.base.dltensor.data = obj_arc.alloc.alloc_data(&*dltensor_ptr);
            Self {
                data: ObjectArc::from_raw(ObjectArc::into_raw(obj_arc) as *mut TensorObj),
            }
        }
    }
    /// Create a Tensor from a slice
    ///
    /// # Arguments
    /// * `slice` - The slice to create the Tensor from
    /// * `shape` - The shape of the Tensor
    ///
    /// # Returns
    /// * `Tensor` - The created Tensor, or an error if the product of `shape`
    ///   does not equal `slice.len()`
    pub fn from_slice<T: AsDLDataType>(slice: &[T], shape: &[i64]) -> Result<Self> {
        let dtype = T::DL_DATA_TYPE;
        let device = DLDevice::new(DLDeviceType::kDLCPU, 0);
        let tensor = Tensor::from_nd_alloc(CPUNDAlloc {}, shape, dtype, device);
        if tensor.numel() != slice.len() {
            crate::bail!(
                crate::error::VALUE_ERROR,
                "Slice length mismatch: shape implies {} elements, slice has {}",
                tensor.numel(),
                slice.len()
            );
        }
        tensor.data_as_slice_mut::<T>()?.copy_from_slice(slice);
        Ok(tensor)
    }
}

/// Example CPU NDAllocator.
/// This allocator allocates data on the CPU via the global Rust allocator.
pub struct CPUNDAlloc {}

unsafe impl NDAllocator for CPUNDAlloc {
    const MIN_ALIGN: usize = 64;

    unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void {
        let size = prototype.numel() * prototype.item_size();
        // `std::alloc::alloc` is undefined for zero-sized layouts, so hand
        // back an aligned dangling pointer for empty tensors instead.
        if size == 0 {
            return Self::MIN_ALIGN as *mut core::ffi::c_void;
        }
        let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
        std::alloc::alloc(layout) as *mut core::ffi::c_void
    }

    unsafe fn free_data(&mut self, tensor: &DLTensor) {
        let size = tensor.numel() * tensor.item_size();
        // Mirror `alloc_data`: nothing was allocated for zero-sized tensors.
        if size == 0 {
            return;
        }
        let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
        std::alloc::dealloc(tensor.data as *mut u8, layout);
    }
}
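
// A minimal smoke-test sketch of the public surface above (hedged: assumes
// `f32` and `i32` implement `AsDLDataType`, which the CPU path relies on).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_slice_round_trips_data_and_metadata() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.ndim(), 2);
        assert_eq!(t.numel(), 6);
        // Freshly allocated tensors are compact row-major: strides [3, 1].
        assert_eq!(t.strides(), &[3, 1]);
        assert!(t.is_contiguous());
        assert_eq!(
            t.data_as_slice::<f32>().unwrap(),
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn data_as_slice_rejects_wrong_dtype() {
        let t = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        // Reading f32 data as i32 must fail with a type error.
        assert!(t.data_as_slice::<i32>().is_err());
    }

    #[test]
    fn from_slice_rejects_mismatched_shape() {
        // Shape [2, 2] implies 4 elements, but only 3 are provided.
        assert!(Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[2, 2]).is_err());
    }
}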