tvm_ffi_sys/
dlpack.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
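// DLPack C ABI declarations.
// NOTE: these definitions are written by hand rather than generated, because the ABI
// is small and we want explicit control over details such as atomic access.
// `non_camel_case_types` is allowed below so that the `kDL*` names can mirror the
// DLPack C header.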
22#![allow(non_camel_case_types)]
23
// DLPack related declarations

/// Kind of device a DLPack tensor lives on (mirrors the C `DLDeviceType` enum).
///
/// The discriminants are fixed by the DLPack ABI and stored as a 32-bit value.
/// A Rust enum may only ever hold one of the listed discriminants, so a raw
/// integer coming from C must be validated before it is treated as this type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum DLDeviceType {
    /// Plain CPU memory.
    kDLCPU = 1,
    /// CUDA GPU memory.
    kDLCUDA = 2,
    /// Pinned CUDA host memory (e.g. from `cudaMallocHost`).
    kDLCUDAHost = 3,
    /// OpenCL device.
    kDLOpenCL = 4,
    /// Vulkan buffer.
    kDLVulkan = 7,
    /// Metal (Apple GPU) device.
    kDLMetal = 8,
    /// Verilog simulator buffer.
    kDLVPI = 9,
    /// ROCm GPU memory.
    kDLROCM = 10,
    /// Pinned ROCm host memory.
    kDLROCMHost = 11,
    /// Reserved extension device type; meaning is implementation-defined.
    kDLExtDev = 12,
    /// CUDA managed/unified memory.
    kDLCUDAManaged = 13,
    /// oneAPI unified shared memory.
    kDLOneAPI = 14,
    /// WebGPU device.
    kDLWebGPU = 15,
    /// Qualcomm Hexagon DSP.
    kDLHexagon = 16,
    /// Microsoft MAIA device.
    kDLMAIA = 17,
    /// AWS Trainium device.
    kDLTrn = 18,
}
45
46#[repr(C)]
47#[derive(Debug, Copy, Clone, PartialEq, Eq)]
48pub struct DLDevice {
49    pub device_type: DLDeviceType,
50    pub device_id: i32,
51}
52
/// Type-code component of a DLPack data type (mirrors the C `DLDataTypeCode` enum).
///
/// The discriminants are fixed by the DLPack ABI and stored in a single byte.
/// Variants are listed in discriminant order. A raw byte coming from C must be
/// validated before it is treated as this type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum DLDataTypeCode {
    /// Signed integer.
    kDLInt = 0,
    /// Unsigned integer.
    kDLUInt = 1,
    /// IEEE floating point.
    kDLFloat = 2,
    /// Opaque handle; producer and consumer must agree on what it refers to.
    kDLOpaqueHandle = 3,
    /// bfloat16.
    kDLBfloat = 4,
    /// Complex number.
    kDLComplex = 5,
    /// Boolean.
    kDLBool = 6,
    // Narrow floating-point formats: the suffix spells out the
    // exponent/mantissa split (`eXmY`) and special-value conventions.
    kDLFloat8_e3m4 = 7,
    kDLFloat8_e4m3 = 8,
    kDLFloat8_e4m3b11fnuz = 9,
    kDLFloat8_e4m3fn = 10,
    kDLFloat8_e4m3fnuz = 11,
    kDLFloat8_e5m2 = 12,
    kDLFloat8_e5m2fnuz = 13,
    kDLFloat8_e8m0fnu = 14,
    kDLFloat6_e2m3fn = 15,
    kDLFloat6_e3m2fn = 16,
    kDLFloat4_e2m1fn = 17,
}
76
/// Element data type of a DLPack tensor: type code, bit width, and lane count.
///
/// Layout-compatible with the C `DLDataType` struct.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(C)]
pub struct DLDataType {
    /// Raw type code, normally a [`DLDataTypeCode`] discriminant stored as `u8`.
    pub code: u8,
    /// Number of bits per lane.
    pub bits: u8,
    /// Number of lanes (greater than 1 for vector types).
    pub lanes: u16,
}
85
86/// DLPack tensor struct - plain C tensor object, does not manage memory
87#[repr(C)]
88#[derive(Debug, Clone, PartialEq, Eq)]
89pub struct DLTensor {
90    /// The data pointer points to the allocated data
91    pub data: *mut core::ffi::c_void,
92    /// The device of the tensor
93    pub device: DLDevice,
94    /// Number of dimensions
95    pub ndim: i32,
96    /// The data type of the pointer
97    pub dtype: DLDataType,
98    /// The shape of the tensor
99    pub shape: *mut i64,
100    /// Strides of the tensor (in number of elements, not bytes)
101    /// Can be NULL, indicating tensor is compact and row-majored
102    pub strides: *mut i64,
103    /// The offset in bytes to the beginning pointer to data
104    pub byte_offset: u64,
105}
106
107impl DLDevice {
108    pub fn new(device_type: DLDeviceType, device_id: i32) -> Self {
109        Self {
110            device_type: device_type,
111            device_id: device_id,
112        }
113    }
114}
115
116impl DLDataType {
117    pub fn new(code: DLDataTypeCode, bits: u8, lanes: u16) -> Self {
118        Self {
119            code: code as u8,
120            bits: bits,
121            lanes: lanes,
122        }
123    }
124}