tvm_ffi_sys/
c_api.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19// NOTE: we manually write the C ABI as they are reasonably minimal
20// and we need to ensure clear control of the atomic access etc.
21#![allow(non_camel_case_types)]
22
23use std::ffi::c_void;
24use std::sync::atomic::AtomicU64;
25
26use crate::dlpack::DLDataType;
27use crate::dlpack::DLDevice;
28
/// Type index shared by FFI objects and `TVMFFIAny` values.
///
/// The discriminant is stored as an `i32` (`#[repr(i32)]`). Indices below
/// `kTVMFFIStaticObjectBegin` (64) name the value kinds listed first; indices
/// from 64 upward name the statically defined object types.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TVMFFITypeIndex {
    /// `None` / null pointer.
    kTVMFFINone = 0,
    /// Plain integer value.
    kTVMFFIInt = 1,
    /// Plain boolean value.
    kTVMFFIBool = 2,
    /// Plain floating-point value.
    kTVMFFIFloat = 3,
    /// Opaque pointer.
    kTVMFFIOpaquePtr = 4,
    /// `DLDataType` value.
    kTVMFFIDataType = 5,
    /// `DLDevice` value.
    kTVMFFIDevice = 6,
    /// `DLTensor*` pointer.
    kTVMFFIDLTensorPtr = 7,
    /// Raw C string (`const char*`).
    kTVMFFIRawStr = 8,
    /// `TVMFFIByteArray*` pointer.
    kTVMFFIByteArrayPtr = 9,
    /// R-value reference to an ObjectRef.
    kTVMFFIObjectRValueRef = 10,
    /// Small string kept on the stack.
    kTVMFFISmallStr = 11,
    /// Small byte sequence kept on the stack.
    kTVMFFISmallBytes = 12,
    /// First index of the statically defined object types.
    kTVMFFIStaticObjectBegin = 64,
    /// String object; layout = { TVMFFIObject, TVMFFIByteArray, ... }.
    kTVMFFIStr = 65,
    /// Bytes object; layout = { TVMFFIObject, TVMFFIByteArray, ... }.
    kTVMFFIBytes = 66,
    /// Error object.
    kTVMFFIError = 67,
    /// Function object.
    kTVMFFIFunction = 68,
    /// Shape object; layout = { TVMFFIObject, { const int64_t*, size_t }, ... }.
    kTVMFFIShape = 69,
    /// Tensor object; layout = { TVMFFIObject, DLTensor, ... }.
    kTVMFFITensor = 70,
    /// Array object.
    kTVMFFIArray = 71,
    //----------------------------------------------------------------
    // more complex objects
    //----------------------------------------------------------------
    /// Map object.
    kTVMFFIMap = 72,
    /// Module object loaded dynamically at runtime.
    kTVMFFIModule = 73,
    /// Opaque Python object.
    kTVMFFIOpaquePyObject = 74,
}
85
/// Bit flags for the `flags` argument of an object deleter.
///
/// NOTE(review): presumably `Strong`/`Weak` select which part of the combined
/// reference count is being released — confirm against the C header.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TVMFFIObjectDeleterFlagBitMask {
    /// Bit 0.
    kTVMFFIObjectDeleterFlagBitMaskStrong = 0b01,
    /// Bit 1.
    kTVMFFIObjectDeleterFlagBitMaskWeak = 0b10,
    /// Bits 0 and 1 together.
    kTVMFFIObjectDeleterFlagBitMaskBoth = 0b11,
}
93
/// Handle to an object as seen from the C API: an untyped pointer.
pub type TVMFFIObjectHandle = *mut c_void;
/// Signature of an object deleter callback.
///
/// * `self_ptr` - untyped pointer to the object being released.
/// * `flags` - NOTE(review): presumably a combination of
///   `TVMFFIObjectDeleterFlagBitMask` values — confirm against the C header.
pub type TVMFFIObjectDeleter = unsafe extern "C" fn(self_ptr: *mut c_void, flags: i32);
97
// Constants for working with the combined 64-bit reference count: one strong
// reference is a unit in the low 32 bits, one weak reference a unit in the
// high 32 bits.

/// Mask selecting the low 32 bits of the combined reference count.
pub const COMBINED_REF_COUNT_MASK_U32: u64 = 0xFFFF_FFFF;
/// Value representing one strong reference.
pub const COMBINED_REF_COUNT_STRONG_ONE: u64 = 1;
/// Value representing one weak reference.
pub const COMBINED_REF_COUNT_WEAK_ONE: u64 = 1 << 32;
/// Value representing one strong and one weak reference together.
pub const COMBINED_REF_COUNT_BOTH_ONE: u64 =
    COMBINED_REF_COUNT_WEAK_ONE + COMBINED_REF_COUNT_STRONG_ONE;
104
105#[repr(C)]
106pub struct TVMFFIObject {
107    pub combined_ref_count: AtomicU64,
108    pub type_index: i32,
109    pub __padding: u32,
110    pub deleter: Option<TVMFFIObjectDeleter>,
111    // private padding to ensure 8 bytes alignment
112    #[cfg(target_pointer_width = "32")]
113    __padding: u32,
114}
115
116impl TVMFFIObject {
117    pub fn new() -> Self {
118        Self {
119            combined_ref_count: AtomicU64::new(0),
120            type_index: 0,
121            __padding: 0,
122            deleter: None,
123        }
124    }
125}
126
/// Second union in TVMFFIAny - 8 bytes
///
/// `#[repr(C)]` union: every member overlays the same storage. Which member
/// is valid is not recorded here; NOTE(review): presumably it is selected by
/// the `type_index` of the enclosing `TVMFFIAny` — confirm.
#[repr(C)]
#[derive(Copy, Clone)]
pub union TVMFFIAnyDataUnion {
    /// Integers
    pub v_int64: i64,
    /// Floating-point numbers
    pub v_float64: f64,
    /// Typeless pointers
    pub v_ptr: *mut c_void,
    /// Raw C-string
    // NOTE(review): C `char` is unsigned on some targets, where
    // `std::ffi::c_char` is `u8` rather than `i8` — confirm `i8` is intended.
    pub v_c_str: *const i8,
    /// Ref counted objects
    pub v_obj: *mut TVMFFIObject,
    /// Data type
    pub v_dtype: DLDataType,
    /// Device
    pub v_device: DLDevice,
    /// Small string
    pub v_bytes: [u8; 8],
    /// uint64 repr mainly used for hashing
    pub v_uint64: u64,
}
150
151/// TVM FFI Any value - a union type that can hold various data types
152#[repr(C)]
153#[derive(Copy, Clone)]
154pub struct TVMFFIAny {
155    /// Type index of the object.
156    /// The type index of Object and Any are shared in FFI.
157    pub type_index: i32,
158    /// small string length or zero padding
159    pub small_str_len: u32,
160    /// data union - 8 bytes
161    pub data_union: TVMFFIAnyDataUnion,
162}
163
164impl TVMFFIAny {
165    /// create a new instance of TVMFFIAny that represents None
166    pub fn new() -> Self {
167        Self {
168            type_index: TVMFFITypeIndex::kTVMFFINone as i32,
169            small_str_len: 0,
170            data_union: TVMFFIAnyDataUnion { v_int64: 0 },
171        }
172    }
173}
174
/// Byte array data structure used by String and Bytes.
///
/// A plain (pointer, length) pair; it carries no lifetime information about
/// the bytes it points to.
#[repr(C)]
pub struct TVMFFIByteArray {
    /// Pointer to the first byte.
    pub data: *const u8,
    /// Number of bytes.
    pub size: usize,
}

impl TVMFFIByteArray {
    /// Creates a byte array from a raw pointer and a length.
    ///
    /// # Arguments
    /// * `data` - Pointer to the first byte.
    /// * `size` - Number of bytes.
    pub fn new(data: *const u8, size: usize) -> Self {
        Self { data, size }
    }

    /// Convert the TVMFFIByteArray to a str view
    ///
    /// A null `data` pointer or a zero `size` yields the empty string without
    /// touching the pointer. Otherwise the bytes are reinterpreted as UTF-8
    /// without validation, so `data` must point to `size` readable bytes of
    /// valid UTF-8 that outlive the returned reference.
    ///
    /// # Returns
    /// * `&str` - The converted str view.
    pub fn as_str(&self) -> &str {
        // `slice::from_raw_parts` requires a non-null pointer even for a
        // zero-length slice, so the empty cases are answered up front.
        if self.data.is_null() || self.size == 0 {
            return "";
        }
        // SAFETY: `data` is non-null here; the creator of this byte array is
        // responsible for it pointing to `size` live bytes of valid UTF-8.
        unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.data, self.size)) }
    }

    /// Creates a TVMFFIByteArray that points at the bytes of `data`.
    ///
    /// # Arguments
    /// * `data` - The string to create the TVMFFIByteArray from.
    ///
    /// # Returns
    /// * `TVMFFIByteArray` - The created TVMFFIByteArray.
    ///
    /// # Safety
    /// The lifetime of `data` is not tracked by the result. The caller must
    /// ensure that the string stays valid for as long as the returned
    /// TVMFFIByteArray (or any copy of its pointer) is used.
    pub unsafe fn from_str(data: &str) -> Self {
        Self {
            data: data.as_ptr(),
            size: data.len(),
        }
    }
}
212
/// Safe call type for function ABI
///
/// * `handle` - untyped pointer passed through to the callee.
/// * `args` - pointer to the first of `num_args` argument values.
/// * `num_args` - number of arguments.
/// * `result` - location the callee writes its result to.
///
/// Returns an `i32` status; NOTE(review): presumably 0 on success and
/// non-zero on error — confirm against the C header.
pub type TVMFFISafeCallType = unsafe extern "C" fn(
    handle: *mut c_void,
    args: *const TVMFFIAny,
    num_args: i32,
    result: *mut TVMFFIAny,
) -> i32;
220
/// Function cell
///
/// `#[repr(C)]`: the field order is part of the C ABI.
#[repr(C)]
pub struct TVMFFIFunctionCell {
    /// A C API compatible call with exception catching.
    pub safe_call: TVMFFISafeCallType,
    /// Opaque pointer. NOTE(review): presumably the C++ call entry, which is
    /// not invoked from Rust — confirm against the C header.
    pub cxx_call: *mut c_void,
}

// SAFETY(review): the cell holds only a function pointer and an opaque raw
// pointer. These impls assert that sharing and sending the cell across threads
// is sound; that guarantee comes from the C API contract, not from anything
// visible in this file — confirm.
unsafe impl Send for TVMFFIFunctionCell {}
unsafe impl Sync for TVMFFIFunctionCell {}
231
/// Mode selector for updating a backtrace, stored as an `i32`.
///
/// NOTE(review): presumably the value passed as `update_mode` to
/// `TVMFFIErrorCell::update_backtrace` — confirm against the C header.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TVMFFIBacktraceUpdateMode {
    /// Replace mode (value 0).
    kTVMFFIBacktraceUpdateModeReplace = 0,
    /// Append mode (value 1).
    kTVMFFIBacktraceUpdateModeAppend = 1,
}
238
/// Error cell used in error object following header.
///
/// `#[repr(C)]`: the field order is part of the C ABI.
#[repr(C)]
pub struct TVMFFIErrorCell {
    /// The kind of the error, as a byte array.
    pub kind: TVMFFIByteArray,
    /// The error message, as a byte array.
    pub message: TVMFFIByteArray,
    /// The backtrace, as a byte array.
    pub backtrace: TVMFFIByteArray,
    /// Callback that updates the backtrace of the error at `self_ptr`.
    /// NOTE(review): `update_mode` is presumably a `TVMFFIBacktraceUpdateMode`
    /// value — confirm against the C header.
    pub update_backtrace: unsafe extern "C" fn(
        self_ptr: *mut c_void,
        backtrace: *const TVMFFIByteArray,
        update_mode: i32,
    ),
}
251
/// Shape cell used in shape object following header.
///
/// A plain (pointer, length) pair; `#[repr(C)]`, so the field order is part of
/// the C ABI.
#[repr(C)]
pub struct TVMFFIShapeCell {
    /// Pointer to the first `i64` element.
    pub data: *const i64,
    /// Number of elements.
    pub size: usize,
}
258
/// Field getter function pointer type
///
/// Takes an untyped pointer to the field and a location to write the value
/// to; returns an `i32` status (NOTE(review): presumably 0 on success —
/// confirm against the C header).
pub type TVMFFIFieldGetter =
    unsafe extern "C" fn(field: *mut c_void, result: *mut TVMFFIAny) -> i32;

/// Field setter function pointer type
///
/// Takes an untyped pointer to the field and a pointer to the value to store;
/// returns an `i32` status (NOTE(review): presumably 0 on success — confirm
/// against the C header).
pub type TVMFFIFieldSetter =
    unsafe extern "C" fn(field: *mut c_void, value: *const TVMFFIAny) -> i32;
266
/// Information support for optional object reflection
///
/// Describes one reflected field. `#[repr(C)]`: the field order is part of the
/// C ABI.
#[repr(C)]
pub struct TVMFFIFieldInfo {
    /// The name of the field
    pub name: TVMFFIByteArray,
    /// The docstring about the field
    pub doc: TVMFFIByteArray,
    /// The metadata of the field in JSON string
    pub metadata: TVMFFIByteArray,
    /// bitmask flags of the field
    pub flags: i64,
    /// The size of the field
    pub size: i64,
    /// The alignment of the field
    pub alignment: i64,
    /// The offset of the field
    pub offset: i64,
    /// The getter to access the field, or `None` when absent
    pub getter: Option<TVMFFIFieldGetter>,
    /// The setter to access the field, or `None` when absent.
    /// The setter is set even if the field is readonly for serialization
    pub setter: Option<TVMFFIFieldSetter>,
    /// The default value or factory of the field, this field holds AnyView.
    /// Valid when flags set kTVMFFIFieldFlagBitMaskHasDefault.
    /// When kTVMFFIFieldFlagBitMaskDefaultFromFactory is also set,
    /// this is a callable factory function () -> Any.
    pub default_value_or_factory: TVMFFIAny,
    /// Records the compile-time static type kind of the field.
    pub field_static_type_index: i32,
}
297
/// Object creator function pointer type
///
/// Writes a handle to `result` and returns an `i32` status
/// (NOTE(review): presumably 0 on success — confirm against the C header).
pub type TVMFFIObjectCreator = unsafe extern "C" fn(result: *mut TVMFFIObjectHandle) -> i32;
300
/// Method information that can appear in reflection table
///
/// `#[repr(C)]`: the field order is part of the C ABI.
#[repr(C)]
pub struct TVMFFIMethodInfo {
    /// The name of the method
    pub name: TVMFFIByteArray,
    /// The docstring about the method
    pub doc: TVMFFIByteArray,
    /// Optional metadata of the method in JSON string
    pub metadata: TVMFFIByteArray,
    /// bitmask flags of the method
    pub flags: i64,
    /// The method wrapped as ffi::Function, stored as AnyView
    /// The first argument to the method is always the self for instance methods
    pub method: TVMFFIAny,
}
316
/// Extra information of object type that can be used for reflection
///
/// This information is optional and can be used to enable reflection based
/// creation of the object.
///
/// `#[repr(C)]`: the field order is part of the C ABI.
#[repr(C)]
pub struct TVMFFITypeMetadata {
    /// The docstring about the object
    pub doc: TVMFFIByteArray,
    /// An optional function that can create a new empty instance of the type;
    /// `None` when no creator is available
    pub creator: Option<TVMFFIObjectCreator>,
    /// Total size of the object struct, if it is fixed and known
    ///
    /// This field is optional and is set to 0 if not registered.
    pub total_size: i32,
    /// Optional meta-data for structural eq/hash
    pub structural_eq_hash_kind: i32,
}
334
/// Column array that stores extra attributes about types
///
/// The attributes are stored in a column array that can be looked up by type
/// index. Note that a TypeAttr behaves like type_traits, so the column for a
/// type does not contain attributes from its base classes.
///
/// `#[repr(C)]`: the field order is part of the C ABI.
#[repr(C)]
pub struct TVMFFITypeAttrColumn {
    /// The data of the column, indexed by (type_index - begin_index).
    pub data: *const TVMFFIAny,
    /// The number of elements in the data array.
    /// The column covers type indices [begin_index, begin_index + size).
    pub size: i32,
    /// The starting type index of the column data.
    /// Lookup: if begin_index <= type_index < begin_index + size,
    /// the entry is data[(type_index - begin_index) as usize].
    pub begin_index: i32,
}
352
/// Runtime type information for object type checking
///
/// `#[repr(C)]`: the field order is part of the C ABI.
#[repr(C)]
pub struct TVMFFITypeInfo {
    /// The runtime type index
    /// It can be allocated during runtime if the type is dynamic
    pub type_index: i32,
    /// number of parent types in the type hierarchy
    pub type_depth: i32,
    /// the unique type key to identify the type
    pub type_key: TVMFFIByteArray,
    /// `type_acenstors[depth]` points to the type info of the ancestor at that
    /// depth level.
    /// To keep things simple, we do not allow multiple inheritance so the
    /// hierarchy stays as a tree
    pub type_acenstors: *const *const TVMFFITypeInfo,
    /// Cached hash value of the type key, used for consistent structural hashing
    pub type_key_hash: u64,
    /// number of reflection accessible fields
    pub num_fields: i32,
    /// number of reflection accessible methods
    pub num_methods: i32,
    /// The reflection field information
    /// (NOTE(review): presumably an array of `num_fields` entries — confirm)
    pub fields: *const TVMFFIFieldInfo,
    /// The reflection method information
    /// (NOTE(review): presumably an array of `num_methods` entries — confirm)
    pub methods: *const TVMFFIMethodInfo,
    /// The extra information of the type
    pub metadata: *const TVMFFITypeMetadata,
}
380
// Raw declarations of the C API entry points. Nothing here is checked by the
// Rust compiler beyond the signatures; callers must uphold the C-side contract.
//
// NOTE(review): the `i32` return values presumably follow the C API's
// status-code convention (0 on success, non-zero on error) — confirm against
// the C header.
// NOTE(review): C-string parameters are declared `*const i8`; C `char` is
// unsigned on some targets (where `std::ffi::c_char` is `u8`) — confirm `i8`
// is intended.
unsafe extern "C" {
    /// Maps a type key to its type index, written to `out_tindex`.
    pub fn TVMFFITypeKeyToIndex(type_key: *const TVMFFIByteArray, out_tindex: *mut i32) -> i32;
    /// Gets the global function registered under `name`; the handle is
    /// written to `out`.
    pub fn TVMFFIFunctionGetGlobal(
        name: *const TVMFFIByteArray,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Registers function `f` globally under `name`. NOTE(review):
    /// `can_override` is presumably a boolean flag (0/1) — confirm.
    pub fn TVMFFIFunctionSetGlobal(
        name: *const TVMFFIByteArray,
        f: TVMFFIObjectHandle,
        can_override: i32,
    ) -> i32;
    /// Creates a function object from `self_ptr` and `safe_call`, with an
    /// optional `deleter`; the new handle is written to `out`.
    /// NOTE(review): `deleter` is presumably invoked on `self_ptr` when the
    /// function object is destroyed — confirm.
    pub fn TVMFFIFunctionCreate(
        self_ptr: *mut c_void,
        safe_call: TVMFFISafeCallType,
        deleter: Option<unsafe extern "C" fn(*mut c_void)>,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Converts the view `any_view` into an owned value written to `out`.
    pub fn TVMFFIAnyViewToOwnedAny(any_view: *const TVMFFIAny, out: *mut TVMFFIAny) -> i32;
    /// Calls `func` with `num_args` arguments starting at `args`; the result
    /// is written to `result`.
    pub fn TVMFFIFunctionCall(
        func: TVMFFIObjectHandle,
        args: *const TVMFFIAny,
        num_args: i32,
        result: *mut TVMFFIAny,
    ) -> i32;
    /// Moves the currently raised error into `result`.
    pub fn TVMFFIErrorMoveFromRaised(result: *mut TVMFFIObjectHandle);
    /// Sets `error` as the raised error.
    pub fn TVMFFIErrorSetRaised(error: TVMFFIObjectHandle);
    /// Sets the raised error from C strings `kind` and `message`.
    pub fn TVMFFIErrorSetRaisedFromCStr(kind: *const i8, message: *const i8);
    /// Creates an error object from `kind`, `message` and `backtrace`; the
    /// new handle is written to `out`.
    pub fn TVMFFIErrorCreate(
        kind: *const TVMFFIByteArray,
        message: *const TVMFFIByteArray,
        backtrace: *const TVMFFIByteArray,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Creates a tensor object from a DLPack value `from`; the handle is
    /// written to `out`. NOTE(review): `from` is presumably a
    /// `DLManagedTensor*` — confirm.
    pub fn TVMFFITensorFromDLPack(
        from: *mut c_void,
        require_alignment: i32,
        require_contiguous: i32,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Exports tensor `from` as a DLPack value; the pointer is written to `out`.
    pub fn TVMFFITensorToDLPack(from: TVMFFIObjectHandle, out: *mut *mut c_void) -> i32;
    /// Versioned counterpart of `TVMFFITensorFromDLPack`. NOTE(review):
    /// `from` is presumably a `DLManagedTensorVersioned*` — confirm.
    pub fn TVMFFITensorFromDLPackVersioned(
        from: *mut c_void,
        require_alignment: i32,
        require_contiguous: i32,
        out: *mut TVMFFIObjectHandle,
    ) -> i32;
    /// Versioned counterpart of `TVMFFITensorToDLPack`.
    pub fn TVMFFITensorToDLPackVersioned(from: TVMFFIObjectHandle, out: *mut *mut c_void) -> i32;
    /// Creates a string value from the bytes in `input`, written to `out`.
    pub fn TVMFFIStringFromByteArray(input: *const TVMFFIByteArray, out: *mut TVMFFIAny) -> i32;
    /// Creates a bytes value from the bytes in `input`, written to `out`.
    pub fn TVMFFIBytesFromByteArray(input: *const TVMFFIByteArray, out: *mut TVMFFIAny) -> i32;
    /// Parses the data type named by `str` into `out`.
    pub fn TVMFFIDataTypeFromString(str: *const TVMFFIByteArray, out: *mut DLDataType) -> i32;
    /// Converts `dtype` to its string form, written to `out`.
    pub fn TVMFFIDataTypeToString(dtype: *const DLDataType, out: *mut TVMFFIAny) -> i32;
    /// Produces a traceback for the given source location and returns a
    /// pointer to it as a byte array. NOTE(review): ownership and lifetime of
    /// the returned pointer are not visible here — confirm before retaining it.
    pub fn TVMFFITraceback(
        filename: *const i8,
        lineno: i32,
        func: *const i8,
        cross_ffi_boundary: i32,
    ) -> *const TVMFFIByteArray;
    /// Returns a pointer to the type information registered for `type_index`.
    pub fn TVMFFIGetTypeInfo(type_index: i32) -> *const TVMFFITypeInfo;
    /// NOTE(review): presumably a dummy symbol used for testing — confirm.
    pub fn TVMFFITestingDummyTarget() -> i32;
}