1use crate::any::{Any, AnyView};
20use crate::derive::{Object, ObjectRef};
21use crate::error::{Error, Result};
22use crate::function_internal::{AsPackedCallable, TupleAsPackedArgs};
23use crate::object::{Object, ObjectArc, ObjectCore};
24use tvm_ffi_sys::{
25 TVMFFIAny, TVMFFIByteArray, TVMFFIFunctionCell, TVMFFIFunctionCreate, TVMFFIFunctionGetGlobal,
26 TVMFFIFunctionSetGlobal, TVMFFIObjectHandle, TVMFFISafeCallType, TVMFFITypeIndex,
27};
28
/// Object layout backing an `ffi.Function`.
///
/// `#[repr(C)]` fixes the field order: the object header comes first,
/// followed by the function cell. `CallbackFunctionObjImpl` embeds this type
/// as its first field and relies on that prefix layout.
/// NOTE(review): presumably mirrors the C-side function object layout — confirm.
#[repr(C)]
#[derive(Object)]
#[type_key = "ffi.Function"]
#[type_index(TVMFFITypeIndex::kTVMFFIFunction)]
pub struct FunctionObj {
    /// Common object header (exposed through `object_header_mut`).
    object: Object,
    /// Call entry points for this function (`safe_call`, `cxx_call`).
    cell: TVMFFIFunctionCell,
}
38
/// Handle to an FFI function object (`FunctionObj`).
///
/// Cloning a `Function` clones the inner `ObjectArc`, not the function
/// itself (presumably a reference-count increment — confirm in `ObjectArc`).
#[derive(Clone, ObjectRef)]
pub struct Function {
    /// Shared pointer to the underlying function object.
    data: ObjectArc<FunctionObj>,
}
44
/// Function object whose calls are dispatched to a Rust closure.
///
/// `#[repr(C)]` with `function` as the first field means a pointer to this
/// struct is also a valid pointer to its `FunctionObj` prefix;
/// `Function::from_packed` reinterprets the pointer this way, and
/// `invoke_callback` casts the handle back to `Self`.
#[repr(C)]
struct CallbackFunctionObjImpl<F: Fn(&[AnyView]) -> Result<Any> + 'static> {
    /// Embedded FFI function object; must remain the first field.
    function: FunctionObj,
    /// The Rust closure invoked by `invoke_callback`.
    callback: F,
}
58
59impl<F: Fn(&[AnyView]) -> Result<Any> + 'static> CallbackFunctionObjImpl<F> {
60 pub fn from_callback(callback: F) -> Self {
61 Self {
62 function: FunctionObj {
63 object: Object::new(),
64 cell: TVMFFIFunctionCell {
65 safe_call: Self::invoke_callback,
67 cxx_call: std::ptr::null_mut(),
68 },
69 },
70 callback,
71 }
72 }
73
74 unsafe extern "C" fn invoke_callback(
75 handle: *mut std::ffi::c_void,
76 args: *const TVMFFIAny,
77 num_args: i32,
78 result: *mut TVMFFIAny,
79 ) -> i32 {
80 let this = &*(handle as *mut Self);
81 let packed_args = std::slice::from_raw_parts(args as *const AnyView, num_args as usize);
82 let ret_value = (this.callback)(packed_args);
83 match ret_value {
84 Ok(value) => {
85 *result = Any::into_raw_ffi_any(value);
86 0
87 }
88 Err(error) => {
89 Error::set_raised(&error);
90 -1
91 }
92 }
93 }
94}
95
96unsafe impl<F: Fn(&[AnyView]) -> Result<Any> + 'static> ObjectCore for CallbackFunctionObjImpl<F> {
97 const TYPE_KEY: &'static str = FunctionObj::TYPE_KEY;
98 fn type_index() -> i32 {
99 FunctionObj::type_index()
100 }
101 unsafe fn object_header_mut(this: &mut Self) -> &mut tvm_ffi_sys::TVMFFIObject {
102 FunctionObj::object_header_mut(&mut this.function)
103 }
104}
105
106impl Function {
107 pub fn call_packed(&self, packed_args: &[AnyView]) -> Result<Any> {
109 unsafe {
110 let packed_args_ptr = packed_args.as_ptr() as *const TVMFFIAny;
111 let mut result = Any::new();
112 let ret_code = (self.data.cell.safe_call)(
113 ObjectArc::as_raw(&self.data) as *mut FunctionObj as *mut std::ffi::c_void,
114 packed_args_ptr,
115 packed_args.len() as i32,
116 Any::as_data_ptr(&mut result),
117 );
118 if ret_code == 0 {
119 Ok(result)
120 } else {
121 Err(Error::from_raised())
122 }
123 }
124 }
125
126 pub fn call_tuple<TupleType>(&self, tuple_args: TupleType) -> Result<Any>
127 where
128 TupleType: TupleAsPackedArgs,
129 {
130 const STACK_LEN: usize = 4;
143 let mut stack_args = [AnyView::new(); STACK_LEN];
144 let mut heap_args = Vec::<AnyView>::new();
145 let args_len = <TupleType as TupleAsPackedArgs>::LEN;
146 let packed_args: &mut [AnyView] = if args_len <= STACK_LEN {
148 &mut stack_args[..args_len]
149 } else {
150 heap_args.resize(args_len, AnyView::new());
151 &mut heap_args[..args_len]
152 };
153 (&tuple_args).fill_any_view(packed_args);
154 self.call_packed(packed_args)
155 }
156 pub fn call_tuple_with_len<const LEN: usize, TupleType>(
166 &self,
167 tuple_args: TupleType,
168 ) -> Result<Any>
169 where
170 TupleType: TupleAsPackedArgs,
171 {
172 let mut packed_args = [AnyView::new(); LEN];
173 (&tuple_args).fill_any_view(&mut packed_args);
174 self.call_packed(&packed_args)
175 }
176 pub fn get_global(name: &str) -> Result<Function> {
185 unsafe {
186 let name_arg = TVMFFIByteArray::from_str(name);
187 let mut result: TVMFFIObjectHandle = ::std::ptr::null_mut();
188 crate::check_safe_call!(TVMFFIFunctionGetGlobal(&name_arg, &mut result))?;
189 if result.is_null() {
190 crate::bail!(crate::error::RUNTIME_ERROR, "Function {} not found", name);
191 }
192 Ok(Self {
193 data: ObjectArc::<FunctionObj>::from_raw(result as *mut FunctionObj),
194 })
195 }
196 }
197
198 pub fn register_global(name: &str, func: Function) -> Result<()> {
206 unsafe {
207 let name_arg = TVMFFIByteArray::from_str(name);
208 let can_override = 0;
209 crate::check_safe_call!(TVMFFIFunctionSetGlobal(
210 &name_arg,
211 ObjectArc::as_raw(&func.data) as *mut FunctionObj as TVMFFIObjectHandle,
212 can_override
213 ))?;
214 Ok(())
215 }
216 }
217 pub fn from_packed<F>(func: F) -> Self
224 where
225 F: Fn(&[AnyView]) -> Result<Any> + 'static,
226 {
227 unsafe {
228 let callback_arc = ObjectArc::new(CallbackFunctionObjImpl::from_callback(func));
229 let func_arc = ObjectArc::<FunctionObj>::from_raw(
230 ObjectArc::into_raw(callback_arc) as *mut FunctionObj
231 );
232 Self { data: func_arc }
233 }
234 }
235
236 pub fn from_typed<F, I, O>(func: F) -> Self
243 where
244 F: AsPackedCallable<I, O> + 'static,
245 {
246 let closure = move |packed_args: &[AnyView]| -> Result<Any> {
247 let ret_value = func.call_packed(packed_args)?;
248 Ok(ret_value)
249 };
250 Self::from_packed(closure)
251 }
252
253 pub fn from_extern_c(
254 handle: *mut std::ffi::c_void,
255 safe_call: TVMFFISafeCallType,
256 deleter: Option<unsafe extern "C" fn(*mut std::ffi::c_void)>,
257 ) -> Self {
258 unsafe {
259 let mut out_handle: TVMFFIObjectHandle = std::ptr::null_mut();
260 crate::check_safe_call!(TVMFFIFunctionCreate(
261 handle,
262 safe_call,
263 deleter,
264 &mut out_handle
265 ))
266 .unwrap();
267 Self {
268 data: ObjectArc::<FunctionObj>::from_raw(out_handle as *mut FunctionObj),
269 }
270 }
271 }
272}