use crate::collections::shape::Shape;
use crate::derive::{Object, ObjectRef};
use crate::dtype::AsDLDataType;
use crate::dtype::DLDataTypeExt;
use crate::error::Result;
use crate::object::{Object, ObjectArc, ObjectCore, ObjectCoreWithExtraItems};
use tvm_ffi_sys::dlpack::{DLDataType, DLDevice, DLDeviceType, DLTensor};
use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;

/// Strategy trait supplying and reclaiming the backing storage for
/// N-dimensional tensors (see [`CPUNDAlloc`] for the heap-backed CPU case).
///
/// # Safety
///
/// Implementors must return, from `alloc_data`, memory large enough for the
/// tensor described by the prototype and aligned to at least `MIN_ALIGN`,
/// and `free_data` must release exactly the storage previously handed out
/// by `alloc_data` for that same tensor.
pub unsafe trait NDAllocator: 'static {
    /// Minimum alignment, in bytes, of pointers returned by `alloc_data`.
    const MIN_ALIGN: usize;

    /// Allocate backing storage for a tensor described by `prototype`.
    ///
    /// `prototype` carries device, dtype, ndim, shape and strides; its
    /// `data` field is not yet valid at this point (it is null when called
    /// from `Tensor::from_nd_alloc`).
    ///
    /// # Safety
    /// Caller must ensure `prototype`'s shape/strides pointers are valid
    /// for `ndim` reads.
    unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void;

    /// Release storage previously produced by `alloc_data`.
    ///
    /// # Safety
    /// `tensor.data` must have come from this allocator's `alloc_data` and
    /// must not be used after this call.
    unsafe fn free_data(&mut self, tensor: &DLTensor);
}
51
/// Convenience accessors computed from a `DLTensor`'s shape and dtype.
pub trait DLTensorExt {
    /// Total element count: the product of all shape extents.
    fn numel(&self) -> usize;
    /// Size in bytes of a single element: dtype bits x lanes, rounded up
    /// to whole bytes.
    fn item_size(&self) -> usize;
}
58
59impl DLTensorExt for DLTensor {
60 fn numel(&self) -> usize {
61 unsafe {
62 std::slice::from_raw_parts(self.shape, self.ndim as usize)
63 .iter()
64 .product::<i64>() as usize
65 }
66 }
67
68 fn item_size(&self) -> usize {
69 (self.dtype.bits as usize * self.dtype.lanes as usize + 7) / 8
70 }
71}
72
/// Heap tensor object: the FFI object header immediately followed by the
/// DLPack tensor descriptor. `#[repr(C)]` fixes the field order so the
/// layout matches what the FFI side expects (header first, `DLTensor`
/// right after).
#[repr(C)]
#[derive(Object)]
#[type_key = "ffi.Tensor"]
#[type_index(TypeIndex::kTVMFFITensor)]
pub struct TensorObj {
    // FFI object header; its exact contents are defined in `crate::object`.
    object: Object,
    // DLPack descriptor: data pointer, device, dtype, ndim, shape/strides
    // pointers and byte offset.
    dltensor: DLTensor,
}
85
/// Reference-counted handle to a [`TensorObj`].
///
/// `Clone` is derived and clones the `ObjectArc`, so clones share the same
/// underlying tensor object rather than copying the data.
#[repr(C)]
#[derive(ObjectRef, Clone)]
pub struct Tensor {
    // Shared ownership of the underlying tensor object.
    data: ObjectArc<TensorObj>,
}
92
93impl Tensor {
94 pub fn data_ptr(&self) -> *const core::ffi::c_void {
99 self.data.dltensor.data
100 }
101 pub fn data_ptr_mut(&mut self) -> *mut core::ffi::c_void {
106 self.data.dltensor.data
107 }
108 pub fn is_contiguous(&self) -> bool {
113 let strides = self.strides();
114 let shape = self.shape();
115 let mut expected_stride = 1;
116 for i in (0..self.ndim()).rev() {
117 if strides[i] != expected_stride {
118 return false;
119 }
120 expected_stride *= shape[i];
121 }
122 true
123 }
124
125 pub fn data_as_slice<T: AsDLDataType>(&self) -> Result<&[T]> {
126 let dtype = T::DL_DATA_TYPE;
127 if self.dtype() != dtype {
128 crate::bail!(
129 crate::error::TYPE_ERROR,
130 "Data type mismatch {} vs {}",
131 self.dtype().to_string(),
132 dtype.to_string()
133 );
134 }
135 if self.device().device_type != DLDeviceType::kDLCPU {
136 crate::bail!(crate::error::RUNTIME_ERROR, "Tensor is not on CPU");
137 }
138 crate::ensure!(
139 self.is_contiguous(),
140 crate::error::RUNTIME_ERROR,
141 "Tensor is not contiguous"
142 );
143
144 unsafe {
145 Ok(std::slice::from_raw_parts(
146 self.data.dltensor.data as *const T,
147 self.numel(),
148 ))
149 }
150 }
151 #[allow(clippy::wrong_self_convention)]
166 pub fn data_as_slice_mut<T: AsDLDataType>(&self) -> Result<&mut [T]> {
167 let dtype = T::DL_DATA_TYPE;
168 if self.dtype() != dtype {
169 crate::bail!(
170 crate::error::TYPE_ERROR,
171 "Data type mismatch: expected {}, got {}",
172 dtype.to_string(),
173 self.dtype().to_string()
174 );
175 }
176 if self.device().device_type != DLDeviceType::kDLCPU {
177 crate::bail!(crate::error::RUNTIME_ERROR, "Tensor is not on CPU");
178 }
179 crate::ensure!(
180 self.is_contiguous(),
181 crate::error::RUNTIME_ERROR,
182 "Tensor is not contiguous"
183 );
184 unsafe {
185 Ok(std::slice::from_raw_parts_mut(
186 self.data.dltensor.data as *mut T,
187 self.numel(),
188 ))
189 }
190 }
191
192 pub fn shape(&self) -> &[i64] {
193 unsafe { std::slice::from_raw_parts(self.data.dltensor.shape, self.ndim()) }
194 }
195
196 pub fn ndim(&self) -> usize {
197 self.data.dltensor.ndim as usize
198 }
199
200 pub fn numel(&self) -> usize {
201 self.data.dltensor.numel()
202 }
203
204 pub fn strides(&self) -> &[i64] {
205 unsafe { std::slice::from_raw_parts(self.data.dltensor.strides, self.ndim()) }
206 }
207
208 pub fn dtype(&self) -> DLDataType {
209 self.data.dltensor.dtype
210 }
211
212 pub fn device(&self) -> DLDevice {
213 self.data.dltensor.device
214 }
215}
216
217struct TensorObjFromNDAlloc<TNDAlloc>
218where
219 TNDAlloc: NDAllocator,
220{
221 base: TensorObj,
222 alloc: TNDAlloc,
223}
224
/// Forward the FFI object identity of `TensorObjFromNDAlloc` to plain
/// `TensorObj`, so both present the same FFI type ("ffi.Tensor") no matter
/// which allocator type parameter is used.
unsafe impl<TNDAlloc: NDAllocator> ObjectCore for TensorObjFromNDAlloc<TNDAlloc> {
    // Same type key as the wrapped TensorObj.
    const TYPE_KEY: &'static str = TensorObj::TYPE_KEY;
    fn type_index() -> i32 {
        TensorObj::type_index()
    }
    // Delegate header access to the embedded `base` object.
    unsafe fn object_header_mut(this: &mut Self) -> &mut tvm_ffi_sys::TVMFFIObject {
        TensorObj::object_header_mut(&mut this.base)
    }
}
234
235unsafe impl<TNDAlloc: NDAllocator> ObjectCoreWithExtraItems for TensorObjFromNDAlloc<TNDAlloc> {
236 type ExtraItem = i64;
237 fn extra_items_count(this: &Self) -> usize {
238 (this.base.dltensor.ndim * 2) as usize
239 }
240}
241
impl<TNDAlloc: NDAllocator> Drop for TensorObjFromNDAlloc<TNDAlloc> {
    fn drop(&mut self) {
        // Return the data buffer to the allocator that produced it.
        // SAFETY: `base.dltensor.data` was obtained from
        // `self.alloc.alloc_data` in `Tensor::from_nd_alloc` and is released
        // exactly once, here.
        unsafe {
            self.alloc.free_data(&self.base.dltensor);
        }
    }
}
249
250impl Tensor {
251 pub fn from_nd_alloc<TNDAlloc>(
262 alloc: TNDAlloc,
263 shape: &[i64],
264 dtype: DLDataType,
265 device: DLDevice,
266 ) -> Self
267 where
268 TNDAlloc: NDAllocator,
269 {
270 let tensor_obj = TensorObjFromNDAlloc {
271 base: TensorObj {
272 object: Object::new(),
273 dltensor: DLTensor {
274 data: std::ptr::null_mut(),
275 device: device,
276 ndim: shape.len() as i32,
277 dtype: dtype,
278 shape: std::ptr::null_mut(),
279 strides: std::ptr::null_mut(),
280 byte_offset: 0,
281 },
282 },
283 alloc: alloc,
284 };
285 unsafe {
286 let mut obj_arc = ObjectArc::new_with_extra_items(tensor_obj);
287 obj_arc.base.dltensor.shape =
288 TensorObjFromNDAlloc::extra_items(&obj_arc).as_ptr() as *mut i64;
289 obj_arc.base.dltensor.strides = obj_arc.base.dltensor.shape.add(shape.len());
290 let extra_items = TensorObjFromNDAlloc::extra_items_mut(&mut obj_arc);
291 extra_items[..shape.len()].copy_from_slice(shape);
292 Shape::fill_strides_from_shape(shape, &mut extra_items[shape.len()..]);
293 let dltensor_ptr = &obj_arc.base.dltensor as *const DLTensor;
294 obj_arc.base.dltensor.data = obj_arc.alloc.alloc_data(&*dltensor_ptr);
295 Self {
296 data: ObjectArc::from_raw(ObjectArc::into_raw(obj_arc) as *mut TensorObj),
297 }
298 }
299 }
300 pub fn from_slice<T: AsDLDataType>(slice: &[T], shape: &[i64]) -> Result<Self> {
309 let dtype = T::DL_DATA_TYPE;
310 let device = DLDevice::new(DLDeviceType::kDLCPU, 0);
311 let tensor = Tensor::from_nd_alloc(CPUNDAlloc {}, shape, dtype, device);
312 if tensor.numel() != slice.len() {
313 crate::bail!(crate::error::VALUE_ERROR, "Slice length mismatch");
314 }
315 tensor.data_as_slice_mut::<T>()?.copy_from_slice(slice);
316 Ok(tensor)
317 }
318}
319
/// `NDAllocator` backed by the global Rust allocator, for CPU tensors.
pub struct CPUNDAlloc {}
323
324unsafe impl NDAllocator for CPUNDAlloc {
325 const MIN_ALIGN: usize = 64;
326
327 unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void {
328 let numel = prototype.numel() as usize;
329 let item_size = prototype.item_size();
330 let size = numel * item_size as usize;
331 let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
332 let ptr = std::alloc::alloc(layout);
333 ptr as *mut core::ffi::c_void
334 }
335
336 unsafe fn free_data(&mut self, tensor: &DLTensor) {
337 let numel = tensor.numel() as usize;
338 let item_size = tensor.item_size();
339 let size = numel * item_size;
340 let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
341 std::alloc::dealloc(tensor.data as *mut u8, layout);
342 }
343}