tvm_ffi/collections/tensor.rs

use crate::collections::shape::Shape;
use crate::derive::{Object, ObjectRef};
use crate::dtype::AsDLDataType;
use crate::dtype::DLDataTypeExt;
use crate::error::Result;
use crate::object::{Object, ObjectArc, ObjectCore, ObjectCoreWithExtraItems};
use tvm_ffi_sys::dlpack::{DLDataType, DLDevice, DLDeviceType, DLTensor};
use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;

/// Allocator interface for backing the storage of an N-dimensional tensor.
///
/// # Safety
/// Implementors must return memory that is at least `MIN_ALIGN`-aligned and
/// large enough for the tensor described by `prototype`, and must only free
/// memory previously returned by `alloc_data`.
pub unsafe trait NDAllocator: 'static {
    /// Minimum alignment (in bytes) of the allocated data pointer.
    const MIN_ALIGN: usize;

    /// Allocate storage for a tensor with the shape, dtype and device of `prototype`.
    unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void;

    /// Free storage previously allocated by `alloc_data` for `tensor`.
    unsafe fn free_data(&mut self, tensor: &DLTensor);
}

/// Convenience accessors on a raw `DLTensor`.
pub trait DLTensorExt {
    /// Total number of elements (product of the shape).
    fn numel(&self) -> usize;
    /// Size in bytes of a single element (rounded up to a whole byte).
    fn item_size(&self) -> usize;
}

impl DLTensorExt for DLTensor {
    fn numel(&self) -> usize {
        unsafe {
            std::slice::from_raw_parts(self.shape, self.ndim as usize)
                .iter()
                .product::<i64>() as usize
        }
    }

    fn item_size(&self) -> usize {
        (self.dtype.bits as usize * self.dtype.lanes as usize + 7) / 8
    }
}
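
// Illustrative note (not part of the original file): for a dtype with
// bits = 32 and lanes = 1, `item_size()` is (32 * 1 + 7) / 8 = 4 bytes, and a
// tensor with shape [2, 3] has `numel()` = 2 * 3 = 6 elements, i.e. a 24-byte
// buffer when stored contiguously with row-major strides [3, 1].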

/// Heap-allocated tensor object: an FFI object header followed by a
/// `DLTensor` descriptor.
#[repr(C)]
#[derive(Object)]
#[type_key = "ffi.Tensor"]
#[type_index(TypeIndex::kTVMFFITensor)]
pub struct TensorObj {
    object: Object,
    dltensor: DLTensor,
}

/// Reference-counted handle to a `TensorObj`.
#[repr(C)]
#[derive(ObjectRef, Clone)]
pub struct Tensor {
    data: ObjectArc<TensorObj>,
}

impl Tensor {
    /// Raw pointer to the underlying data buffer.
    pub fn data_ptr(&self) -> *const core::ffi::c_void {
        self.data.dltensor.data
    }

    /// Mutable raw pointer to the underlying data buffer.
    pub fn data_ptr_mut(&mut self) -> *mut core::ffi::c_void {
        self.data.dltensor.data
    }

    /// Returns `true` if the tensor is laid out contiguously in row-major order.
    pub fn is_contiguous(&self) -> bool {
        let strides = self.strides();
        let shape = self.shape();
        let mut expected_stride = 1;
        for i in (0..self.ndim()).rev() {
            if strides[i] != expected_stride {
                return false;
            }
            expected_stride *= shape[i];
        }
        true
    }

    /// Views the tensor data as a typed slice.
    ///
    /// Fails if the element type does not match, the tensor is not on CPU,
    /// or the layout is not contiguous.
    pub fn data_as_slice<T: AsDLDataType>(&self) -> Result<&[T]> {
        let dtype = T::DL_DATA_TYPE;
        if self.dtype() != dtype {
            crate::bail!(
                crate::error::TYPE_ERROR,
                "Data type mismatch: expected {}, got {}",
                dtype.to_string(),
                self.dtype().to_string()
            );
        }
        if self.device().device_type != DLDeviceType::kDLCPU {
            crate::bail!(crate::error::RUNTIME_ERROR, "Tensor is not on CPU");
        }
        crate::ensure!(
            self.is_contiguous(),
            crate::error::RUNTIME_ERROR,
            "Tensor is not contiguous"
        );

        unsafe {
            Ok(std::slice::from_raw_parts(
                self.data.dltensor.data as *const T,
                self.numel(),
            ))
        }
    }

    /// Views the tensor data as a mutable typed slice.
    ///
    /// Performs the same checks as [`Tensor::data_as_slice`]. This takes `&self`;
    /// the caller is responsible for ensuring exclusive access to the buffer.
    pub fn data_as_slice_mut<T: AsDLDataType>(&self) -> Result<&mut [T]> {
        let dtype = T::DL_DATA_TYPE;
        if self.dtype() != dtype {
            crate::bail!(
                crate::error::TYPE_ERROR,
                "Data type mismatch: expected {}, got {}",
                dtype.to_string(),
                self.dtype().to_string()
            );
        }
        if self.device().device_type != DLDeviceType::kDLCPU {
            crate::bail!(crate::error::RUNTIME_ERROR, "Tensor is not on CPU");
        }
        crate::ensure!(
            self.is_contiguous(),
            crate::error::RUNTIME_ERROR,
            "Tensor is not contiguous"
        );
        unsafe {
            Ok(std::slice::from_raw_parts_mut(
                self.data.dltensor.data as *mut T,
                self.numel(),
            ))
        }
    }

    pub fn shape(&self) -> &[i64] {
        unsafe { std::slice::from_raw_parts(self.data.dltensor.shape, self.ndim()) }
    }

    pub fn ndim(&self) -> usize {
        self.data.dltensor.ndim as usize
    }

    pub fn numel(&self) -> usize {
        self.data.dltensor.numel()
    }

    pub fn strides(&self) -> &[i64] {
        unsafe { std::slice::from_raw_parts(self.data.dltensor.strides, self.ndim()) }
    }

    pub fn dtype(&self) -> DLDataType {
        self.data.dltensor.dtype
    }

    pub fn device(&self) -> DLDevice {
        self.data.dltensor.device
    }
}

/// `TensorObj` whose data buffer is owned by an `NDAllocator`. The shape and
/// strides arrays are stored inline after the object as extra items.
struct TensorObjFromNDAlloc<TNDAlloc>
where
    TNDAlloc: NDAllocator,
{
    base: TensorObj,
    alloc: TNDAlloc,
}

unsafe impl<TNDAlloc: NDAllocator> ObjectCore for TensorObjFromNDAlloc<TNDAlloc> {
    const TYPE_KEY: &'static str = TensorObj::TYPE_KEY;
    fn type_index() -> i32 {
        TensorObj::type_index()
    }
    unsafe fn object_header_mut(this: &mut Self) -> &mut tvm_ffi_sys::TVMFFIObject {
        TensorObj::object_header_mut(&mut this.base)
    }
}

unsafe impl<TNDAlloc: NDAllocator> ObjectCoreWithExtraItems for TensorObjFromNDAlloc<TNDAlloc> {
    type ExtraItem = i64;
    // `ndim` entries for the shape followed by `ndim` entries for the strides.
    fn extra_items_count(this: &Self) -> usize {
        (this.base.dltensor.ndim * 2) as usize
    }
}

impl<TNDAlloc: NDAllocator> Drop for TensorObjFromNDAlloc<TNDAlloc> {
    fn drop(&mut self) {
        unsafe {
            self.alloc.free_data(&self.base.dltensor);
        }
    }
}

impl Tensor {
    /// Creates a tensor of the given shape, dtype and device, allocating the
    /// data buffer through `alloc`. Shape and strides are stored inline as
    /// extra items of the underlying object.
    pub fn from_nd_alloc<TNDAlloc>(
        alloc: TNDAlloc,
        shape: &[i64],
        dtype: DLDataType,
        device: DLDevice,
    ) -> Self
    where
        TNDAlloc: NDAllocator,
    {
        let tensor_obj = TensorObjFromNDAlloc {
            base: TensorObj {
                object: Object::new(),
                dltensor: DLTensor {
                    data: std::ptr::null_mut(),
                    device,
                    ndim: shape.len() as i32,
                    dtype,
                    shape: std::ptr::null_mut(),
                    strides: std::ptr::null_mut(),
                    byte_offset: 0,
                },
            },
            alloc,
        };
        unsafe {
            let mut obj_arc = ObjectArc::new_with_extra_items(tensor_obj);
            // Point the DLTensor's shape/strides at the inline extra-item storage.
            obj_arc.base.dltensor.shape =
                TensorObjFromNDAlloc::extra_items(&obj_arc).as_ptr() as *mut i64;
            obj_arc.base.dltensor.strides = obj_arc.base.dltensor.shape.add(shape.len());
            let extra_items = TensorObjFromNDAlloc::extra_items_mut(&mut obj_arc);
            extra_items[..shape.len()].copy_from_slice(shape);
            Shape::fill_strides_from_shape(shape, &mut extra_items[shape.len()..]);
            // Allocate the data buffer now that shape/strides are populated.
            let dltensor_ptr = &obj_arc.base.dltensor as *const DLTensor;
            obj_arc.base.dltensor.data = obj_arc.alloc.alloc_data(&*dltensor_ptr);
            Self {
                data: ObjectArc::from_raw(ObjectArc::into_raw(obj_arc) as *mut TensorObj),
            }
        }
    }

    /// Creates a CPU tensor with the given shape and copies `slice` into it.
    pub fn from_slice<T: AsDLDataType>(slice: &[T], shape: &[i64]) -> Result<Self> {
        let dtype = T::DL_DATA_TYPE;
        let device = DLDevice::new(DLDeviceType::kDLCPU, 0);
        let tensor = Tensor::from_nd_alloc(CPUNDAlloc {}, shape, dtype, device);
        if tensor.numel() != slice.len() {
            crate::bail!(crate::error::VALUE_ERROR, "Slice length mismatch");
        }
        tensor.data_as_slice_mut::<T>()?.copy_from_slice(slice);
        Ok(tensor)
    }
}

/// `NDAllocator` backed by the global Rust allocator, for CPU tensors.
pub struct CPUNDAlloc {}

unsafe impl NDAllocator for CPUNDAlloc {
    const MIN_ALIGN: usize = 64;

    unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void {
        let size = prototype.numel() * prototype.item_size();
        let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
        let ptr = std::alloc::alloc(layout);
        ptr as *mut core::ffi::c_void
    }

    unsafe fn free_data(&mut self, tensor: &DLTensor) {
        let size = tensor.numel() * tensor.item_size();
        let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
        std::alloc::dealloc(tensor.data as *mut u8, layout);
    }
}
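
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original file). Assumes `f32: AsDLDataType`,
// that `crate::error::Error` implements `Debug` (needed for `unwrap`), and that
// `Shape::fill_strides_from_shape` produces the row-major strides expected by
// `is_contiguous` above.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    /// Example `NDAllocator`: same layout logic as `CPUNDAlloc`, but the
    /// buffer is zero-initialized.
    struct ZeroedCPUNDAlloc {}

    unsafe impl NDAllocator for ZeroedCPUNDAlloc {
        const MIN_ALIGN: usize = 64;

        unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void {
            let size = prototype.numel() * prototype.item_size();
            let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
            std::alloc::alloc_zeroed(layout) as *mut core::ffi::c_void
        }

        unsafe fn free_data(&mut self, tensor: &DLTensor) {
            let size = tensor.numel() * tensor.item_size();
            let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap();
            std::alloc::dealloc(tensor.data as *mut u8, layout);
        }
    }

    #[test]
    fn from_slice_roundtrip() {
        // Build a 2x3 f32 tensor from a flat slice and read it back.
        let values: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&values, &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.data_as_slice::<f32>().unwrap(), values.as_slice());
    }

    #[test]
    fn from_nd_alloc_with_custom_allocator() {
        // The zero-initializing allocator yields an all-zero buffer.
        let tensor = Tensor::from_nd_alloc(
            ZeroedCPUNDAlloc {},
            &[4],
            f32::DL_DATA_TYPE,
            DLDevice::new(DLDeviceType::kDLCPU, 0),
        );
        assert_eq!(tensor.data_as_slice::<f32>().unwrap(), &[0.0f32; 4]);
    }
}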