tvm_ffi_sys/dlpack.rs
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
// DLPack C ABI declarations.
// NOTE: these declarations are written by hand rather than generated, because
// they are small and we need precise control over details such as atomic access.
22#![allow(non_camel_case_types)]
23
24// DLPack related declarations
/// Device type codes, mirroring the C `DLDeviceType` enum of DLPack.
///
/// Every discriminant is written out explicitly because the numeric values are
/// part of the C ABI; note the gaps (5 and 6 have no variant here).
///
/// Only construct this type from values listed below: materializing a Rust
/// enum from an unlisted integer (e.g. via `transmute` of data coming from C)
/// is undefined behavior.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum DLDeviceType {
    kDLCPU = 1,
    kDLCUDA = 2,
    kDLCUDAHost = 3,
    kDLOpenCL = 4,
    // 5 and 6 intentionally absent.
    kDLVulkan = 7,
    kDLMetal = 8,
    kDLVPI = 9,
    kDLROCM = 10,
    kDLROCMHost = 11,
    kDLExtDev = 12,
    kDLCUDAManaged = 13,
    kDLOneAPI = 14,
    kDLWebGPU = 15,
    kDLHexagon = 16,
    kDLMAIA = 17,
    kDLTrn = 18,
}
45
46#[repr(C)]
47#[derive(Debug, Copy, Clone, PartialEq, Eq)]
48pub struct DLDevice {
49 pub device_type: DLDeviceType,
50 pub device_id: i32,
51}
52
/// Element type codes, mirroring the C `DLDataTypeCode` enum of DLPack.
///
/// `DLDataType` stores this as a raw `u8` (see `DLDataType::new`). The
/// discriminants are ABI values and are therefore fixed explicitly; variants
/// are listed in ascending discriminant order.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum DLDataTypeCode {
    kDLInt = 0,
    kDLUInt = 1,
    kDLFloat = 2,
    kDLOpaqueHandle = 3,
    kDLBfloat = 4,
    kDLComplex = 5,
    kDLBool = 6,
    // 8-bit floating-point variants.
    kDLFloat8_e3m4 = 7,
    kDLFloat8_e4m3 = 8,
    kDLFloat8_e4m3b11fnuz = 9,
    kDLFloat8_e4m3fn = 10,
    kDLFloat8_e4m3fnuz = 11,
    kDLFloat8_e5m2 = 12,
    kDLFloat8_e5m2fnuz = 13,
    kDLFloat8_e8m0fnu = 14,
    // 6-bit and 4-bit floating-point variants.
    kDLFloat6_e2m3fn = 15,
    kDLFloat6_e3m2fn = 16,
    kDLFloat4_e2m1fn = 17,
}
76
/// Element data type descriptor, laid out like the C `DLDataType`.
///
/// Four bytes total: `code`, `bits`, then `lanes`, in that order.
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct DLDataType {
    /// Type code. When built with `DLDataType::new` this holds a
    /// `DLDataTypeCode` discriminant; it is stored as a raw `u8` so any value
    /// received from C is representable.
    pub code: u8,
    /// Number of bits of the type.
    pub bits: u8,
    /// Number of lanes of the type.
    pub lanes: u16,
}
86/// DLPack tensor struct - plain C tensor object, does not manage memory
87#[repr(C)]
88#[derive(Debug, Clone, PartialEq, Eq)]
89pub struct DLTensor {
90 /// The data pointer points to the allocated data
91 pub data: *mut core::ffi::c_void,
92 /// The device of the tensor
93 pub device: DLDevice,
94 /// Number of dimensions
95 pub ndim: i32,
96 /// The data type of the pointer
97 pub dtype: DLDataType,
98 /// The shape of the tensor
99 pub shape: *mut i64,
100 /// Strides of the tensor (in number of elements, not bytes)
101 /// Can be NULL, indicating tensor is compact and row-majored
102 pub strides: *mut i64,
103 /// The offset in bytes to the beginning pointer to data
104 pub byte_offset: u64,
105}
106
107impl DLDevice {
108 pub fn new(device_type: DLDeviceType, device_id: i32) -> Self {
109 Self {
110 device_type: device_type,
111 device_id: device_id,
112 }
113 }
114}
115
116impl DLDataType {
117 pub fn new(code: DLDataTypeCode, bits: u8, lanes: u16) -> Self {
118 Self {
119 code: code as u8,
120 bits: bits,
121 lanes: lanes,
122 }
123 }
124}