tvm_ffi_sys/
c_api.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
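// NOTE: the C ABI bindings are written by hand rather than generated: the API
// surface is small, and we need explicit control over details such as atomic access.
// The `kTVMFFI*` variant names mirror the C header, hence `non_camel_case_types`.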
21#![allow(non_camel_case_types)]
22
23use std::ffi::c_void;
24use std::sync::atomic::AtomicU64;
25
26use crate::dlpack::DLDataType;
27use crate::dlpack::DLDevice;
28
/// Type index of FFI values and objects.
///
/// The same index space is used by `TVMFFIAny::type_index` and
/// `TVMFFIObject::type_index`. Indices below `kTVMFFIStaticObjectBegin` (64)
/// are non-object kinds (POD values, raw pointers, on-stack small values);
/// indices from 64 onward identify statically defined object types.
/// The discriminants must match the C ABI exactly.
#[repr(i32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TVMFFITypeIndex {
    /// None/nullptr value
    kTVMFFINone = 0,
    /// POD int value
    kTVMFFIInt = 1,
    /// POD bool value
    kTVMFFIBool = 2,
    /// POD float value
    kTVMFFIFloat = 3,
    /// Opaque pointer value (`void*`)
    kTVMFFIOpaquePtr = 4,
    /// DLDataType
    kTVMFFIDataType = 5,
    /// DLDevice
    kTVMFFIDevice = 6,
    /// DLTensor*
    kTVMFFIDLTensorPtr = 7,
    /// const char*
    kTVMFFIRawStr = 8,
    /// TVMFFIByteArray*
    kTVMFFIByteArrayPtr = 9,
    /// R-value reference to ObjectRef
    kTVMFFIObjectRValueRef = 10,
    /// Small string stored inline (on stack)
    kTVMFFISmallStr = 11,
    /// Small bytes stored inline (on stack)
    kTVMFFISmallBytes = 12,
    /// Start of the statically defined object type indices.
    kTVMFFIStaticObjectBegin = 64,
    /// String object, layout = { TVMFFIObject, TVMFFIByteArray, ... }
    kTVMFFIStr = 65,
    /// Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... }
    kTVMFFIBytes = 66,
    /// Error object.
    kTVMFFIError = 67,
    /// Function object.
    kTVMFFIFunction = 68,
    /// Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... }
    kTVMFFIShape = 69,
    /// Tensor object, layout = { TVMFFIObject, DLTensor, ... }
    kTVMFFITensor = 70,
    /// Array object.
    kTVMFFIArray = 71,
    //----------------------------------------------------------------
    // more complex objects
    //----------------------------------------------------------------
    /// Map object.
    kTVMFFIMap = 72,
    /// Runtime dynamic loaded module object.
    kTVMFFIModule = 73,
    /// Opaque python object.
    kTVMFFIOpaquePyObject = 74,
}
85
/// Bit masks, presumably combined into the `flags` argument of
/// `TVMFFIObjectDeleter` to say which part of an object (the strong part,
/// the weak part, or both) is being released — confirm against the C header.
#[repr(i32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TVMFFIObjectDeleterFlagBitMask {
    /// Bit 0: the strong-reference part.
    kTVMFFIObjectDeleterFlagBitMaskStrong = 0b01,
    /// Bit 1: the weak-reference part.
    kTVMFFIObjectDeleterFlagBitMaskWeak = 0b10,
    /// Both bits: strong and weak parts together.
    kTVMFFIObjectDeleterFlagBitMaskBoth = 0b11,
}
93
/// Opaque handle to an FFI object, as passed across the C API.
pub type TVMFFIObjectHandle = *mut c_void;
/// Deleter callback stored in `TVMFFIObject::deleter`.
///
/// Receives the object pointer and an `i32` flags bitmask (presumably built from
/// `TVMFFIObjectDeleterFlagBitMask` values — confirm against the C header).
pub type TVMFFIObjectDeleter = unsafe extern "C" fn(self_ptr: *mut c_void, flags: i32);
97
// Layout of a combined reference count: the strong count occupies the low 32
// bits and the weak count the high 32 bits, so a single u64 tracks both.

/// Mask selecting the low 32 bits (the strong-count half).
pub const COMBINED_REF_COUNT_MASK_U32: u64 = u32::MAX as u64; // lossless widening
/// Adds one strong reference.
pub const COMBINED_REF_COUNT_STRONG_ONE: u64 = 1;
/// Adds one weak reference: the first bit above the strong-count half.
pub const COMBINED_REF_COUNT_WEAK_ONE: u64 = COMBINED_REF_COUNT_MASK_U32 + 1;
/// Adds one strong and one weak reference at once (the halves never overlap).
pub const COMBINED_REF_COUNT_BOTH_ONE: u64 =
    COMBINED_REF_COUNT_WEAK_ONE + COMBINED_REF_COUNT_STRONG_ONE;
104
105#[repr(C)]
106pub struct TVMFFIObject {
107    pub combined_ref_count: AtomicU64,
108    pub type_index: i32,
109    pub __padding: u32,
110    pub deleter: Option<TVMFFIObjectDeleter>,
111    // private padding to ensure 8 bytes alignment
112    #[cfg(target_pointer_width = "32")]
113    __padding: u32,
114}
115
116impl TVMFFIObject {
117    pub fn new() -> Self {
118        Self {
119            combined_ref_count: AtomicU64::new(0),
120            type_index: 0,
121            __padding: 0,
122            deleter: None,
123        }
124    }
125}
126
/// Payload union of `TVMFFIAny` (its `data_union` field), 8 bytes.
///
/// Which member is valid is determined by the enclosing `TVMFFIAny::type_index`.
#[repr(C)]
#[derive(Copy, Clone)]
pub union TVMFFIAnyDataUnion {
    /// Integers (e.g. `kTVMFFIInt`)
    pub v_int64: i64,
    /// Floating-point numbers (e.g. `kTVMFFIFloat`)
    pub v_float64: f64,
    /// Typeless pointers (e.g. `kTVMFFIOpaquePtr`)
    pub v_ptr: *mut c_void,
    /// Raw C-string (`kTVMFFIRawStr`)
    ///
    /// NOTE(review): `std::ffi::c_char` may be the more portable spelling, since
    /// C `char` signedness varies by target; changing it alters the public type.
    pub v_c_str: *const i8,
    /// Ref counted objects (object type indices)
    pub v_obj: *mut TVMFFIObject,
    /// Data type (`kTVMFFIDataType`)
    pub v_dtype: DLDataType,
    /// Device (`kTVMFFIDevice`)
    pub v_device: DLDevice,
    /// Inline bytes of a small string/bytes value (`kTVMFFISmallStr` /
    /// `kTVMFFISmallBytes`); the length is presumably `TVMFFIAny::small_str_len`.
    pub v_bytes: [u8; 8],
    /// uint64 repr mainly used for hashing
    pub v_uint64: u64,
}
150
/// TVM FFI Any value: a tagged union that can hold values of various types.
///
/// `type_index` selects which member of `data_union` is meaningful. Note that
/// the derived `Copy` is a plain bitwise copy: it does not adjust the reference
/// count of a held object.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct TVMFFIAny {
    /// Type index of the held value.
    /// The type index of Object and Any are shared in FFI.
    pub type_index: i32,
    /// Length of an inline small string/bytes value, or zero padding otherwise.
    pub small_str_len: u32,
    /// Payload (8 bytes), interpreted according to `type_index`.
    pub data_union: TVMFFIAnyDataUnion,
}
163
164impl TVMFFIAny {
165    /// create a new instance of TVMFFIAny that represents None
166    pub fn new() -> Self {
167        Self {
168            type_index: TVMFFITypeIndex::kTVMFFINone as i32,
169            small_str_len: 0,
170            data_union: TVMFFIAnyDataUnion { v_int64: 0 },
171        }
172    }
173}
174
/// Non-owning view of a byte buffer (pointer + length), used by String and Bytes.
///
/// The view neither owns nor frees `data`; whoever builds it must keep the
/// underlying bytes alive and unmodified for as long as the view is used.
#[repr(C)]
pub struct TVMFFIByteArray {
    /// Pointer to the first byte; may be null when `size` is 0.
    pub data: *const u8,
    /// Number of bytes at `data`.
    pub size: usize,
}

impl TVMFFIByteArray {
    /// Creates a view from a raw pointer and a length; nothing is validated.
    pub fn new(data: *const u8, size: usize) -> Self {
        Self { data, size }
    }

    /// Borrows the bytes as a `&str`.
    ///
    /// An empty view (`size == 0`) returns `""` without touching `data`, so a
    /// `{ null, 0 }` view is handled safely.
    ///
    /// This method is safe to call, but it trusts the view's contents. For a
    /// non-empty view, `data` must point to `size` readable bytes of valid UTF-8
    /// that outlive the returned `&str`. Pointer nullness and UTF-8 validity are
    /// asserted in debug builds only.
    pub fn as_str(&self) -> &str {
        if self.size == 0 {
            // `slice::from_raw_parts` demands a non-null pointer even for a
            // zero-length slice, so never hand it a possibly-null `data`.
            return "";
        }
        debug_assert!(!self.data.is_null(), "TVMFFIByteArray: null data with non-zero size");
        // SAFETY: `data` is non-null and, per this type's contract, points to
        // `size` initialized bytes that stay valid while `self` is borrowed.
        let bytes = unsafe { std::slice::from_raw_parts(self.data, self.size) };
        debug_assert!(std::str::from_utf8(bytes).is_ok(), "TVMFFIByteArray: invalid UTF-8");
        // SAFETY: the view's contract requires UTF-8 content (checked in debug builds).
        unsafe { std::str::from_utf8_unchecked(bytes) }
    }

    /// Creates a view over the bytes of `data` without tracking its lifetime.
    ///
    /// # Arguments
    /// * `data` - The string to create the TVMFFIByteArray from.
    ///
    /// # Returns
    /// * `TVMFFIByteArray` - A view pointing at `data`'s bytes.
    ///
    /// # Safety
    ///
    /// The returned value does not borrow `data`. The caller must ensure the
    /// string stays alive and unmodified for every use of the returned view.
    pub unsafe fn from_str(data: &str) -> Self {
        Self {
            data: data.as_ptr(),
            size: data.len(),
        }
    }
}
212
/// Signature of the C-ABI "safe call" entry point of an FFI function.
///
/// Receives the function's `handle`, a pointer to an array of `num_args`
/// `TVMFFIAny` arguments, and an out-pointer for the result. The `i32` return is
/// a status code (NOTE(review): presumably 0 on success — confirm against the C header).
pub type TVMFFISafeCallType = unsafe extern "C" fn(
    handle: *mut c_void,
    args: *const TVMFFIAny,
    num_args: i32,
    result: *mut TVMFFIAny,
) -> i32;
220
/// Function cell: the call entry points of a function object (presumably laid
/// out after the object header, like the error and shape cells).
#[repr(C)]
pub struct TVMFFIFunctionCell {
    /// A C API compatible call with exception catching.
    pub safe_call: TVMFFISafeCallType,
    /// Opaque pointer to the C++-side call; presumably only meaningful to C++ code.
    pub cxx_call: *mut c_void,
}

// NOTE(review): these impls declare the cell safe to send and share across
// threads even though it holds a raw pointer. That is only sound if the callee
// behind `safe_call`/`cxx_call` is thread-safe — confirm the C ABI guarantees it.
unsafe impl Send for TVMFFIFunctionCell {}
unsafe impl Sync for TVMFFIFunctionCell {}
231
/// How a new backtrace is combined with an error's existing one (presumably the
/// `update_mode` argument of `TVMFFIErrorCell::update_backtrace` — confirm).
#[repr(i32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TVMFFIBacktraceUpdateMode {
    /// Replace the existing backtrace.
    kTVMFFIBacktraceUpdateModeReplace = 0,
    /// Append to the existing backtrace.
    kTVMFFIBacktraceUpdateModeAppend = 1,
}
238
/// Error cell, laid out after the object header in an error object.
#[repr(C)]
pub struct TVMFFIErrorCell {
    /// Error kind string.
    pub kind: TVMFFIByteArray,
    /// Error message string.
    pub message: TVMFFIByteArray,
    /// Backtrace text.
    pub backtrace: TVMFFIByteArray,
    /// Callback that updates the backtrace of the error at `self_ptr`;
    /// `update_mode` is presumably a `TVMFFIBacktraceUpdateMode` value.
    pub update_backtrace: unsafe extern "C" fn(
        self_ptr: *mut c_void,
        backtrace: *const TVMFFIByteArray,
        update_mode: i32,
    ),
}
251
/// Shape cell, laid out after the object header in a shape object.
#[repr(C)]
pub struct TVMFFIShapeCell {
    /// Pointer to the `size` dimension values of the shape.
    pub data: *const i64,
    /// Number of dimensions.
    pub size: usize,
}
258
/// Field getter function pointer type.
///
/// Reads a reflected field (`field` presumably points at the field inside the
/// object) and writes its value to `result`; returns an `i32` status code.
pub type TVMFFIFieldGetter =
    unsafe extern "C" fn(field: *mut c_void, result: *mut TVMFFIAny) -> i32;

/// Field setter function pointer type.
///
/// Writes `value` into a reflected field (`field` presumably points at the field
/// inside the object); returns an `i32` status code.
pub type TVMFFIFieldSetter =
    unsafe extern "C" fn(field: *mut c_void, value: *const TVMFFIAny) -> i32;
266
/// Reflection information about one field of an object type (optional).
#[repr(C)]
pub struct TVMFFIFieldInfo {
    /// The name of the field
    pub name: TVMFFIByteArray,
    /// The docstring about the field
    pub doc: TVMFFIByteArray,
    /// The metadata of the field as a JSON string
    pub metadata: TVMFFIByteArray,
    /// Bitmask flags of the field
    pub flags: i64,
    /// The size of the field (presumably in bytes)
    pub size: i64,
    /// The alignment of the field (presumably in bytes)
    pub alignment: i64,
    /// The offset of the field within the object (presumably in bytes)
    pub offset: i64,
    /// The getter to access the field
    pub getter: Option<TVMFFIFieldGetter>,
    /// The setter to access the field.
    /// It is provided even when the field is read-only, so that
    /// (de)serialization can still assign it.
    pub setter: Option<TVMFFIFieldSetter>,
    /// The default value of the field, held as an AnyView;
    /// only valid when `flags` has `kTVMFFIFieldFlagBitMaskHasDefault` set.
    pub default_value: TVMFFIAny,
    /// Static type index recording the declared type kind of the field
    pub field_static_type_index: i32,
}
295
/// Object creator function pointer type.
///
/// Creates a new empty instance of a type and writes its handle to `result`;
/// returns an `i32` status code.
pub type TVMFFIObjectCreator = unsafe extern "C" fn(result: *mut TVMFFIObjectHandle) -> i32;
298
/// Method information that can appear in a type's reflection table.
#[repr(C)]
pub struct TVMFFIMethodInfo {
    /// The name of the method
    pub name: TVMFFIByteArray,
    /// The docstring about the method
    pub doc: TVMFFIByteArray,
    /// Optional metadata of the method as a JSON string
    pub metadata: TVMFFIByteArray,
    /// Bitmask flags of the method
    pub flags: i64,
    /// The method wrapped as ffi::Function, stored as an AnyView.
    /// For instance methods, the first argument is always `self`.
    pub method: TVMFFIAny,
}
314
/// Extra information about an object type that can be used for reflection.
///
/// This information is optional; when present it enables reflection-based
/// creation of objects of the type.
#[repr(C)]
pub struct TVMFFITypeMetadata {
    /// The docstring about the object type
    pub doc: TVMFFIByteArray,
    /// An optional function that can create a new empty instance of the type
    pub creator: Option<TVMFFIObjectCreator>,
    /// Total size of the object struct, if it is fixed and known.
    ///
    /// This field is optional and is 0 when not registered.
    pub total_size: i32,
    /// Optional metadata for structural equality/hashing (an `i32` code whose
    /// meaning is defined on the C side)
    pub structural_eq_hash_kind: i32,
}
332
/// Column array that stores one extra attribute for all types.
///
/// Each column can be looked up by type index. Like C++ type traits, a column is
/// not inherited: the entry for type T does not include attributes registered on
/// T's base classes.
#[repr(C)]
pub struct TVMFFITypeAttrColumn {
    /// The entries of the column, indexed by type index
    pub data: *const TVMFFIAny,
    /// The number of entries in `data`
    pub size: usize,
}
345
/// Runtime type information for object type checking and reflection.
#[repr(C)]
pub struct TVMFFITypeInfo {
    /// The runtime type index.
    /// It can be allocated at runtime if the type is dynamic.
    pub type_index: i32,
    /// Number of parent types in the type hierarchy
    pub type_depth: i32,
    /// The unique type key identifying the type
    pub type_key: TVMFFIByteArray,
    /// Ancestors of this type: `type_acenstors[d]` points to the type info of the
    /// ancestor at depth `d` (presumably for `d` in `0..type_depth`). The
    /// ancestor's type index is available through that pointer.
    /// Multiple inheritance is not allowed, so the hierarchy stays a tree.
    /// (The field name is spelled `type_acenstors`, sic.)
    pub type_acenstors: *const *const TVMFFITypeInfo,
    /// Cached hash value of the type key, used for consistent structural hashing
    pub type_key_hash: u64,
    /// Number of reflection-accessible fields (entries in `fields`)
    pub num_fields: i32,
    /// Number of reflection-accessible methods (entries in `methods`)
    pub num_methods: i32,
    /// The reflection field information, `num_fields` entries
    pub fields: *const TVMFFIFieldInfo,
    /// The reflection method information, `num_methods` entries
    pub methods: *const TVMFFIMethodInfo,
    /// The extra information of the type
    pub metadata: *const TVMFFITypeMetadata,
}
373
// Functions exported by the TVM FFI C library.
//
// NOTE(review): the `i32` returned by most functions is a status code —
// presumably 0 on success, with details retrievable via
// `TVMFFIErrorMoveFromRaised`; confirm against the C header.
// NOTE(review): C strings are typed `*const i8` here; `std::ffi::c_char` may be
// the more portable spelling, since C `char` signedness varies by target.
unsafe extern "C" {
    /// Looks up the type index registered for `type_key`, written to `out_tindex`.
    pub fn TVMFFITypeKeyToIndex(type_key: *const TVMFFIByteArray, out_tindex: *mut i32) -> i32;
    /// Fetches the global function registered under `name` into `out`.
    pub fn TVMFFIFunctionGetGlobal(
        name: *const TVMFFIByteArray,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Registers `f` as the global function `name`; `can_override` (presumably a
    /// boolean flag) controls whether an existing registration may be replaced.
    pub fn TVMFFIFunctionSetGlobal(
        name: *const TVMFFIByteArray,
        f: TVMFFIObjectHandle,
        can_override: i32,
    ) -> i32;
    /// Creates a function object from `self_ptr` and its `safe_call` entry point,
    /// written to `out`; `deleter`, if given, is presumably called on `self_ptr`
    /// when the function object is freed.
    pub fn TVMFFIFunctionCreate(
        self_ptr: *mut c_void,
        safe_call: TVMFFISafeCallType,
        deleter: Option<unsafe extern "C" fn(*mut c_void)>,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Converts `any_view` into an owned value written to `out`.
    pub fn TVMFFIAnyViewToOwnedAny(any_view: *const TVMFFIAny, out: *mut TVMFFIAny) -> i32;
    /// Calls `func` with `num_args` arguments from `args`, writing the result to `result`.
    pub fn TVMFFIFunctionCall(
        func: TVMFFIObjectHandle,
        args: *const TVMFFIAny,
        num_args: i32,
        result: *mut TVMFFIAny,
    ) -> i32;
    /// Moves the currently raised error object into `result`.
    pub fn TVMFFIErrorMoveFromRaised(result: *mut TVMFFIObjectHandle);
    /// Sets `error` as the currently raised error.
    pub fn TVMFFIErrorSetRaised(error: TVMFFIObjectHandle);
    /// Raises a new error built from the `kind` and `message` C strings.
    pub fn TVMFFIErrorSetRaisedFromCStr(kind: *const i8, message: *const i8);
    /// Creates an error object from kind, message and backtrace, written to `out`.
    pub fn TVMFFIErrorCreate(
        kind: *const TVMFFIByteArray,
        message: *const TVMFFIByteArray,
        backtrace: *const TVMFFIByteArray,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Imports a DLPack tensor (`from`, presumably a `DLManagedTensor*`) as a
    /// Tensor object written to `out`.
    pub fn TVMFFITensorFromDLPack(
        from: *mut c_void,
        require_alignment: i32,
        require_contiguous: i32,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Exports the Tensor object `from` to DLPack; `out` receives the tensor pointer.
    pub fn TVMFFITensorToDLPack(from: TVMFFIObjectHandle, out: *mut *mut c_void) -> i32;
    /// Versioned-DLPack variant of `TVMFFITensorFromDLPack`.
    pub fn TVMFFITensorFromDLPackVersioned(
        from: *mut c_void,
        require_alignment: i32,
        require_contiguous: i32,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Versioned-DLPack variant of `TVMFFITensorToDLPack`.
    pub fn TVMFFITensorToDLPackVersioned(from: TVMFFIObjectHandle, out: *mut *mut c_void) -> i32;
    /// Creates a String from `input`, written to `out` as an Any.
    pub fn TVMFFIStringFromByteArray(input: *const TVMFFIByteArray, out: *mut TVMFFIAny) -> i32;
    /// Creates a Bytes value from `input`, written to `out` as an Any.
    pub fn TVMFFIBytesFromByteArray(input: *const TVMFFIByteArray, out: *mut TVMFFIAny) -> i32;
    /// Parses the data type string `str` into `out`.
    pub fn TVMFFIDataTypeFromString(str: *const TVMFFIByteArray, out: *mut DLDataType) -> i32;
    /// Formats `dtype` as a string, written to `out` as an Any.
    pub fn TVMFFIDataTypeToString(dtype: *const DLDataType, out: *mut TVMFFIAny) -> i32;
    /// Returns a backtrace for the given source location as a byte array.
    /// NOTE(review): ownership/lifetime of the returned pointer is not visible here.
    pub fn TVMFFITraceback(
        filename: *const i8,
        lineno: i32,
        func: *const i8,
        cross_ffi_boundary: i32,
    ) -> *const TVMFFIByteArray;
    /// Returns the type info registered for `type_index`.
    pub fn TVMFFIGetTypeInfo(type_index: i32) -> *const TVMFFITypeInfo;
    /// Testing hook; NOTE(review): its purpose is not visible from this file.
    pub fn TVMFFITestingDummyTarget() -> i32;
}