tvm_ffi/dtype.rs
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19use crate::error::Result;
20use crate::type_traits::AnyCompatible;
21/// Data type handling
22use tvm_ffi_sys::dlpack::{DLDataType, DLDataTypeCode};
23use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;
24use tvm_ffi_sys::{TVMFFIAny, TVMFFIByteArray, TVMFFIDataTypeFromString, TVMFFIDataTypeToString};
25
26/// Extra methods for DLDataType
27pub trait DLDataTypeExt: Sized {
28 /// Convert the DLDataType to a string representation
29 ///
30 /// # Returns
31 /// A string representation of the data type (e.g., "int32", "float64", "bool")
32 fn to_string(&self) -> crate::string::String;
33
34 /// Parse a string representation into a DLDataType
35 ///
36 /// # Arguments
37 /// * `dtype_str` - The string representation of the data type to parse
38 ///
39 /// # Returns
40 /// * `Ok(DLDataType)` - Successfully parsed data type
41 /// * `Err(Error)` - Failed to parse the string
42 ///
43 /// # Examples
44 /// ```
45 /// use tvm_ffi::{DLDataType, DLDataTypeExt};
46 ///
47 /// let dtype = DLDataType::try_from_str("int32").unwrap();
48 /// ```
49 fn try_from_str(dtype_str: &str) -> Result<Self>;
50}
51
52impl DLDataTypeExt for DLDataType {
53 fn to_string(&self) -> crate::string::String {
54 unsafe {
55 let mut ffi_any = TVMFFIAny::new();
56 crate::check_safe_call!(TVMFFIDataTypeToString(&*self, &mut ffi_any)).unwrap();
57 crate::any::Any::from_raw_ffi_any(ffi_any)
58 .try_into()
59 .unwrap()
60 }
61 }
62
63 fn try_from_str(dtype_str: &str) -> Result<Self> {
64 let mut dtype = DLDataType {
65 code: DLDataTypeCode::kDLOpaqueHandle as u8,
66 bits: 0,
67 lanes: 0,
68 };
69 unsafe {
70 let dtype_byte_array = TVMFFIByteArray::from_str(dtype_str);
71 crate::check_safe_call!(TVMFFIDataTypeFromString(&dtype_byte_array, &mut dtype))?;
72 }
73 Ok(dtype)
74 }
75}
76
77/// AnyCompatible implementation for DLDataType
78///
79/// This implementation allows DLDataType to be used with the TVM FFI Any system,
80/// enabling type-safe conversion between DLDataType and the generic Any type.
81unsafe impl AnyCompatible for DLDataType {
82 /// Get the type string identifier for DLDataType
83 ///
84 /// # Returns
85 /// The string "DataType" to match the C++ representation
86 fn type_str() -> String {
87 // make it consistent with c++ representation
88 "DataType".to_string()
89 }
90
91 /// Copy a DLDataType to an Any view
92 ///
93 /// # Arguments
94 /// * `src` - The DLDataType to copy from
95 /// * `data` - The Any view to copy to
96 unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) {
97 data.type_index = TypeIndex::kTVMFFIDataType as i32;
98 data.small_str_len = 0;
99 data.data_union.v_uint64 = 0;
100 data.data_union.v_dtype = *src;
101 }
102
103 /// Move a DLDataType into an Any
104 ///
105 /// # Arguments
106 /// * `src` - The DLDataType to move from
107 /// * `data` - The Any to move into
108 unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
109 data.type_index = TypeIndex::kTVMFFIDataType as i32;
110 data.small_str_len = 0;
111 data.data_union.v_int64 = 0;
112 data.data_union.v_dtype = src;
113 }
114
115 /// Check if an Any contains a DLDataType
116 ///
117 /// # Arguments
118 /// * `data` - The Any to check
119 ///
120 /// # Returns
121 /// `true` if the Any contains a DLDataType, `false` otherwise
122 unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
123 return data.type_index == TypeIndex::kTVMFFIDataType as i32;
124 }
125
126 /// Copy a DLDataType from an Any view (after type check)
127 ///
128 /// # Arguments
129 /// * `data` - The Any view to copy from
130 ///
131 /// # Returns
132 /// The copied DLDataType
133 ///
134 /// # Safety
135 /// The caller must ensure that `data` contains a DLDataType
136 unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
137 data.data_union.v_dtype
138 }
139
140 /// Move a DLDataType from an Any (after type check)
141 ///
142 /// # Arguments
143 /// * `data` - The Any to move from
144 ///
145 /// # Returns
146 /// The moved DLDataType
147 ///
148 /// # Safety
149 /// The caller must ensure that `data` contains a DLDataType
150 unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
151 data.data_union.v_dtype
152 }
153
154 /// Try to cast an Any view to a DLDataType
155 ///
156 /// This method supports both direct DLDataType conversion and string parsing.
157 ///
158 /// # Arguments
159 /// * `data` - The Any view to cast from
160 ///
161 /// # Returns
162 /// * `Ok(DLDataType)` - Successfully cast to DLDataType
163 /// * `Err(())` - Failed to cast (wrong type or invalid string)
164 unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
165 if data.type_index == TypeIndex::kTVMFFIDataType as i32 {
166 Ok(data.data_union.v_dtype)
167 } else if let Ok(string) = crate::string::String::try_cast_from_any_view(data) {
168 DLDataType::try_from_str(string.as_str()).map_err(|_| ())
169 } else {
170 Err(())
171 }
172 }
173}
174
175/// Trait to convert standard data types to DLDataType
176///
177/// This trait provides a way to get the corresponding DLDataType for standard Rust types.
178/// It's implemented for common integer, unsigned integer, and floating-point types.
179pub trait AsDLDataType: Copy {
180 /// The corresponding DLDataType for this type
181 const DL_DATA_TYPE: DLDataType;
182}
183
/// Generate an [`AsDLDataType`] impl for a primitive scalar type.
///
/// Arguments, in order:
/// * the Rust type that receives the impl,
/// * the `DLDataTypeCode` variant, cast to `u8` for the `code` field,
/// * the width of one scalar in bits, cast to `u8` for the `bits` field.
///
/// Every generated descriptor has exactly one lane (a scalar, not a vector).
macro_rules! impl_as_dl_data_type {
    ($rust_ty: ty, $type_code: expr, $bit_width: expr) => {
        impl AsDLDataType for $rust_ty {
            const DL_DATA_TYPE: DLDataType = DLDataType {
                lanes: 1,
                bits: $bit_width as u8,
                code: $type_code as u8,
            };
        }
    };
}
199
200impl_as_dl_data_type!(i8, DLDataTypeCode::kDLInt, 8);
201impl_as_dl_data_type!(i16, DLDataTypeCode::kDLInt, 16);
202impl_as_dl_data_type!(i32, DLDataTypeCode::kDLInt, 32);
203impl_as_dl_data_type!(i64, DLDataTypeCode::kDLInt, 64);
204impl_as_dl_data_type!(u8, DLDataTypeCode::kDLUInt, 8);
205impl_as_dl_data_type!(u16, DLDataTypeCode::kDLUInt, 16);
206impl_as_dl_data_type!(u32, DLDataTypeCode::kDLUInt, 32);
207impl_as_dl_data_type!(u64, DLDataTypeCode::kDLUInt, 64);
208impl_as_dl_data_type!(f32, DLDataTypeCode::kDLFloat, 32);
209impl_as_dl_data_type!(f64, DLDataTypeCode::kDLFloat, 64);