tvm_ffi/
dtype.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19use crate::error::Result;
20use crate::type_traits::AnyCompatible;
21/// Data type handling
22use tvm_ffi_sys::dlpack::{DLDataType, DLDataTypeCode};
23use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;
24use tvm_ffi_sys::{TVMFFIAny, TVMFFIByteArray, TVMFFIDataTypeFromString, TVMFFIDataTypeToString};
25
26/// Extra methods for DLDataType
27pub trait DLDataTypeExt: Sized {
28    /// Convert the DLDataType to a string representation
29    ///
30    /// # Returns
31    /// A string representation of the data type (e.g., "int32", "float64", "bool")
32    fn to_string(&self) -> crate::string::String;
33
34    /// Parse a string representation into a DLDataType
35    ///
36    /// # Arguments
37    /// * `dtype_str` - The string representation of the data type to parse
38    ///
39    /// # Returns
40    /// * `Ok(DLDataType)` - Successfully parsed data type
41    /// * `Err(Error)` - Failed to parse the string
42    ///
43    /// # Examples
44    /// ```
45    /// use tvm_ffi::{DLDataType, DLDataTypeExt};
46    ///
47    /// let dtype = DLDataType::try_from_str("int32").unwrap();
48    /// ```
49    fn try_from_str(dtype_str: &str) -> Result<Self>;
50}
51
52impl DLDataTypeExt for DLDataType {
53    fn to_string(&self) -> crate::string::String {
54        unsafe {
55            let mut ffi_any = TVMFFIAny::new();
56            crate::check_safe_call!(TVMFFIDataTypeToString(&*self, &mut ffi_any)).unwrap();
57            crate::any::Any::from_raw_ffi_any(ffi_any)
58                .try_into()
59                .unwrap()
60        }
61    }
62
63    fn try_from_str(dtype_str: &str) -> Result<Self> {
64        let mut dtype = DLDataType {
65            code: DLDataTypeCode::kDLOpaqueHandle as u8,
66            bits: 0,
67            lanes: 0,
68        };
69        unsafe {
70            let dtype_byte_array = TVMFFIByteArray::from_str(dtype_str);
71            crate::check_safe_call!(TVMFFIDataTypeFromString(&dtype_byte_array, &mut dtype))?;
72        }
73        Ok(dtype)
74    }
75}
76
77/// AnyCompatible implementation for DLDataType
78///
79/// This implementation allows DLDataType to be used with the TVM FFI Any system,
80/// enabling type-safe conversion between DLDataType and the generic Any type.
81unsafe impl AnyCompatible for DLDataType {
82    /// Get the type string identifier for DLDataType
83    ///
84    /// # Returns
85    /// The string "DataType" to match the C++ representation
86    fn type_str() -> String {
87        // make it consistent with c++ representation
88        "DataType".to_string()
89    }
90
91    /// Copy a DLDataType to an Any view
92    ///
93    /// # Arguments
94    /// * `src` - The DLDataType to copy from
95    /// * `data` - The Any view to copy to
96    unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) {
97        data.type_index = TypeIndex::kTVMFFIDataType as i32;
98        data.small_str_len = 0;
99        data.data_union.v_uint64 = 0;
100        data.data_union.v_dtype = *src;
101    }
102
103    /// Move a DLDataType into an Any
104    ///
105    /// # Arguments
106    /// * `src` - The DLDataType to move from
107    /// * `data` - The Any to move into
108    unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
109        data.type_index = TypeIndex::kTVMFFIDataType as i32;
110        data.small_str_len = 0;
111        data.data_union.v_int64 = 0;
112        data.data_union.v_dtype = src;
113    }
114
115    /// Check if an Any contains a DLDataType
116    ///
117    /// # Arguments
118    /// * `data` - The Any to check
119    ///
120    /// # Returns
121    /// `true` if the Any contains a DLDataType, `false` otherwise
122    unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
123        return data.type_index == TypeIndex::kTVMFFIDataType as i32;
124    }
125
126    /// Copy a DLDataType from an Any view (after type check)
127    ///
128    /// # Arguments
129    /// * `data` - The Any view to copy from
130    ///
131    /// # Returns
132    /// The copied DLDataType
133    ///
134    /// # Safety
135    /// The caller must ensure that `data` contains a DLDataType
136    unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
137        data.data_union.v_dtype
138    }
139
140    /// Move a DLDataType from an Any (after type check)
141    ///
142    /// # Arguments
143    /// * `data` - The Any to move from
144    ///
145    /// # Returns
146    /// The moved DLDataType
147    ///
148    /// # Safety
149    /// The caller must ensure that `data` contains a DLDataType
150    unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
151        data.data_union.v_dtype
152    }
153
154    /// Try to cast an Any view to a DLDataType
155    ///
156    /// This method supports both direct DLDataType conversion and string parsing.
157    ///
158    /// # Arguments
159    /// * `data` - The Any view to cast from
160    ///
161    /// # Returns
162    /// * `Ok(DLDataType)` - Successfully cast to DLDataType
163    /// * `Err(())` - Failed to cast (wrong type or invalid string)
164    unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
165        if data.type_index == TypeIndex::kTVMFFIDataType as i32 {
166            Ok(data.data_union.v_dtype)
167        } else if let Ok(string) = crate::string::String::try_cast_from_any_view(data) {
168            DLDataType::try_from_str(string.as_str()).map_err(|_| ())
169        } else {
170            Err(())
171        }
172    }
173}
174
175/// Trait to convert standard data types to DLDataType
176///
177/// This trait provides a way to get the corresponding DLDataType for standard Rust types.
178/// It's implemented for common integer, unsigned integer, and floating-point types.
179pub trait AsDLDataType: Copy {
180    /// The corresponding DLDataType for this type
181    const DL_DATA_TYPE: DLDataType;
182}
183
/// Generates an `AsDLDataType` implementation for one scalar type.
///
/// Invoked as `impl_as_dl_data_type!(ty, code, bits)`. Here `code` is a
/// `DLDataTypeCode` variant and `bits` is the width of one element. Both are cast
/// to `u8`, and the generated descriptor always has exactly one lane.
macro_rules! impl_as_dl_data_type {
    ($scalar: ty, $type_code: expr, $bit_width: expr) => {
        impl AsDLDataType for $scalar {
            const DL_DATA_TYPE: DLDataType = DLDataType {
                lanes: 1,
                code: $type_code as u8,
                bits: $bit_width as u8,
            };
        }
    };
}
199
200impl_as_dl_data_type!(i8, DLDataTypeCode::kDLInt, 8);
201impl_as_dl_data_type!(i16, DLDataTypeCode::kDLInt, 16);
202impl_as_dl_data_type!(i32, DLDataTypeCode::kDLInt, 32);
203impl_as_dl_data_type!(i64, DLDataTypeCode::kDLInt, 64);
204impl_as_dl_data_type!(u8, DLDataTypeCode::kDLUInt, 8);
205impl_as_dl_data_type!(u16, DLDataTypeCode::kDLUInt, 16);
206impl_as_dl_data_type!(u32, DLDataTypeCode::kDLUInt, 32);
207impl_as_dl_data_type!(u64, DLDataTypeCode::kDLUInt, 64);
208impl_as_dl_data_type!(f32, DLDataTypeCode::kDLFloat, 32);
209impl_as_dl_data_type!(f64, DLDataTypeCode::kDLFloat, 64);