tvm_ffi/
function.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19use crate::any::{Any, AnyView};
20use crate::derive::{Object, ObjectRef};
21use crate::error::{Error, Result};
22use crate::function_internal::{AsPackedCallable, TupleAsPackedArgs};
23use crate::object::{Object, ObjectArc, ObjectCore};
24use tvm_ffi_sys::{
25    TVMFFIAny, TVMFFIByteArray, TVMFFIFunctionCell, TVMFFIFunctionCreate, TVMFFIFunctionGetGlobal,
26    TVMFFIFunctionSetGlobal, TVMFFIObjectHandle, TVMFFISafeCallType, TVMFFITypeIndex,
27};
28
/// Function object.
///
/// Object layout registered under the type key `ffi.Function`. The struct is
/// `#[repr(C)]`, so the object header comes first and is followed by the
/// function cell holding the call entry points.
#[repr(C)]
#[derive(Object)]
#[type_key = "ffi.Function"]
#[type_index(TVMFFITypeIndex::kTVMFFIFunction)]
pub struct FunctionObj {
    // Object header; kept as the first field of the `#[repr(C)]` layout.
    object: Object,
    // Call entry points of the function (`safe_call` and `cxx_call`).
    cell: TVMFFIFunctionCell,
}
38
/// Function reference class
///
/// Holds an `ObjectArc` pointing at a [`FunctionObj`]; cloning the reference
/// clones the `ObjectArc`.
#[derive(Clone, ObjectRef)]
pub struct Function {
    // Handle to the underlying function object.
    data: ObjectArc<FunctionObj>,
}
44
//------------------------------------------------------------------------
// CallbackFunctionObjImpl
//------------------------------------------------------------------------
/// Special helper class that holds a generic callback state as an object.
///
/// Logically this impl can be viewed as a `FunctionObj`: the struct is
/// `#[repr(C)]` and the `FunctionObj` is its first field.
/// We create an `ObjectArc<CallbackFunctionObjImpl<F>>` so the deleter
/// can correctly delete the entire object, including the callback part,
/// and then convert it to an `ObjectArc<FunctionObj>` to be used as a function.
#[repr(C)]
struct CallbackFunctionObjImpl<F: Fn(&[AnyView]) -> Result<Any> + 'static> {
    // Must remain the first field: `Function::from_packed` casts a pointer to
    // this struct into a pointer to `FunctionObj`.
    function: FunctionObj,
    // The user callback invoked through the `safe_call` entry point.
    callback: F,
}
58
59impl<F: Fn(&[AnyView]) -> Result<Any> + 'static> CallbackFunctionObjImpl<F> {
60    pub fn from_callback(callback: F) -> Self {
61        Self {
62            function: FunctionObj {
63                object: Object::new(),
64                cell: TVMFFIFunctionCell {
65                    // specfic callback for F
66                    safe_call: Self::invoke_callback,
67                    cxx_call: std::ptr::null_mut(),
68                },
69            },
70            callback,
71        }
72    }
73
74    unsafe extern "C" fn invoke_callback(
75        handle: *mut std::ffi::c_void,
76        args: *const TVMFFIAny,
77        num_args: i32,
78        result: *mut TVMFFIAny,
79    ) -> i32 {
80        let this = &*(handle as *mut Self);
81        let packed_args = std::slice::from_raw_parts(args as *const AnyView, num_args as usize);
82        let ret_value = (this.callback)(packed_args);
83        match ret_value {
84            Ok(value) => {
85                *result = Any::into_raw_ffi_any(value);
86                0
87            }
88            Err(error) => {
89                Error::set_raised(&error);
90                -1
91            }
92        }
93    }
94}
95
// The callback holder presents itself as a `FunctionObj`: it reuses the
// type key and type index of `FunctionObj` and exposes the object header of
// its embedded `function` field.
unsafe impl<F: Fn(&[AnyView]) -> Result<Any> + 'static> ObjectCore for CallbackFunctionObjImpl<F> {
    // Same type key as `FunctionObj`.
    const TYPE_KEY: &'static str = FunctionObj::TYPE_KEY;
    // Same type index as `FunctionObj`.
    fn type_index() -> i32 {
        FunctionObj::type_index()
    }
    // The header lives inside the embedded `FunctionObj`.
    unsafe fn object_header_mut(this: &mut Self) -> &mut tvm_ffi_sys::TVMFFIObject {
        FunctionObj::object_header_mut(&mut this.function)
    }
}
105
106impl Function {
107    /// Call the function in packed format.
108    pub fn call_packed(&self, packed_args: &[AnyView]) -> Result<Any> {
109        unsafe {
110            let packed_args_ptr = packed_args.as_ptr() as *const TVMFFIAny;
111            let mut result = Any::new();
112            let ret_code = (self.data.cell.safe_call)(
113                ObjectArc::as_raw(&self.data) as *mut FunctionObj as *mut std::ffi::c_void,
114                packed_args_ptr,
115                packed_args.len() as i32,
116                Any::as_data_ptr(&mut result),
117            );
118            if ret_code == 0 {
119                Ok(result)
120            } else {
121                Err(Error::from_raised())
122            }
123        }
124    }
125
126    pub fn call_tuple<TupleType>(&self, tuple_args: TupleType) -> Result<Any>
127    where
128        TupleType: TupleAsPackedArgs,
129    {
130        // This is a workaround for Rust's requirement that stack allocation size
131        // must be known at compile time for generic types.
132        // While we know args_len is a constant, Rust doesn't allow us to directly
133        // declare [AnyView::new(); args_len] in generic contexts.
134        //
135        // We use a small vector optimization pattern:
136        // 1. First allocate a small stack buffer (stack_args)
137        // 2. If args_len exceeds STACK_LEN, allocate a heap buffer (heap_args)
138        // 3. Use the appropriate buffer based on size
139        //
140        // Since args_len is a compile-time constant, the compiler should optimize
141        // away the unused branch, making this approach efficient.
142        const STACK_LEN: usize = 4;
143        let mut stack_args = [AnyView::new(); STACK_LEN];
144        let mut heap_args = Vec::<AnyView>::new();
145        let args_len = <TupleType as TupleAsPackedArgs>::LEN;
146        // get packed arguments
147        let packed_args: &mut [AnyView] = if args_len <= STACK_LEN {
148            &mut stack_args[..args_len]
149        } else {
150            heap_args.resize(args_len, AnyView::new());
151            &mut heap_args[..args_len]
152        };
153        (&tuple_args).fill_any_view(packed_args);
154        self.call_packed(packed_args)
155    }
156    /// Call function with compile-time known argument count
157    /// This is an optimized version of call_tuple for when the argument count
158    /// is known at compile time, avoiding the small vector optimization overhead.
159    ///
160    /// # Arguments
161    /// * `tuple_args` - The tuple arguments
162    ///
163    /// # Returns
164    /// * `Any` - The result
165    pub fn call_tuple_with_len<const LEN: usize, TupleType>(
166        &self,
167        tuple_args: TupleType,
168    ) -> Result<Any>
169    where
170        TupleType: TupleAsPackedArgs,
171    {
172        let mut packed_args = [AnyView::new(); LEN];
173        (&tuple_args).fill_any_view(&mut packed_args);
174        self.call_packed(&packed_args)
175    }
176    /// Get global function by name
177    /// This function will throw an error if the function is not found.
178    ///
179    /// # Arguments
180    /// * `name` - The name of the function
181    ///
182    /// # Returns
183    /// * `Function` - The global function
184    pub fn get_global(name: &str) -> Result<Function> {
185        unsafe {
186            let name_arg = TVMFFIByteArray::from_str(name);
187            let mut result: TVMFFIObjectHandle = ::std::ptr::null_mut();
188            crate::check_safe_call!(TVMFFIFunctionGetGlobal(&name_arg, &mut result))?;
189            if result.is_null() {
190                crate::bail!(crate::error::RUNTIME_ERROR, "Function {} not found", name);
191            }
192            Ok(Self {
193                data: ObjectArc::<FunctionObj>::from_raw(result as *mut FunctionObj),
194            })
195        }
196    }
197
198    /// Register a function as a global function
199    /// # Arguments
200    /// * `name` - The name of the function
201    /// * `func` - The function to register
202    ///
203    /// # Returns
204    /// * `Result<()>` - The result of the registration
205    pub fn register_global(name: &str, func: Function) -> Result<()> {
206        unsafe {
207            let name_arg = TVMFFIByteArray::from_str(name);
208            let can_override = 0;
209            crate::check_safe_call!(TVMFFIFunctionSetGlobal(
210                &name_arg,
211                ObjectArc::as_raw(&func.data) as *mut FunctionObj as TVMFFIObjectHandle,
212                can_override
213            ))?;
214            Ok(())
215        }
216    }
217    /// Construct a function from a packed function
218    /// # Arguments
219    /// * `func` - The packed function in signature of `Fn(&[AnyView]) -> Result<Any>`
220    ///
221    /// # Returns
222    /// * `Function` - The function
223    pub fn from_packed<F>(func: F) -> Self
224    where
225        F: Fn(&[AnyView]) -> Result<Any> + 'static,
226    {
227        unsafe {
228            let callback_arc = ObjectArc::new(CallbackFunctionObjImpl::from_callback(func));
229            let func_arc = ObjectArc::<FunctionObj>::from_raw(
230                ObjectArc::into_raw(callback_arc) as *mut FunctionObj
231            );
232            Self { data: func_arc }
233        }
234    }
235
236    /// Construct a function from a typed function
237    /// # Arguments
238    /// * `func` - The typed function with function signature of `F(T0, T1, ...) -> Result<O>`
239    ///
240    /// # Returns
241    /// * `Function` - The function
242    pub fn from_typed<F, I, O>(func: F) -> Self
243    where
244        F: AsPackedCallable<I, O> + 'static,
245    {
246        let closure = move |packed_args: &[AnyView]| -> Result<Any> {
247            let ret_value = func.call_packed(packed_args)?;
248            Ok(ret_value)
249        };
250        Self::from_packed(closure)
251    }
252
253    pub fn from_extern_c(
254        handle: *mut std::ffi::c_void,
255        safe_call: TVMFFISafeCallType,
256        deleter: Option<unsafe extern "C" fn(*mut std::ffi::c_void)>,
257    ) -> Self {
258        unsafe {
259            let mut out_handle: TVMFFIObjectHandle = std::ptr::null_mut();
260            crate::check_safe_call!(TVMFFIFunctionCreate(
261                handle,
262                safe_call,
263                deleter,
264                &mut out_handle
265            ))
266            .unwrap();
267            Self {
268                data: ObjectArc::<FunctionObj>::from_raw(out_handle as *mut FunctionObj),
269            }
270        }
271    }
272}