Skip to main content

tvm_ffi/
macros.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
// Re-export `paste` under the macro namespace so downstream crates do not need to declare it as a dependency.
20pub use paste;
21// ----------------------------------------------------------------------------
22// Macros for error handling
23// ----------------------------------------------------------------------------
24
/// Expand to the path of the enclosing function as a `&'static str`.
///
/// The name is derived from `std::any::type_name` of a probe function `f` declared
/// inside the caller, with the trailing `::f` segment removed. Note that the exact
/// format of `type_name` output is not guaranteed by the standard library.
///
/// # Usage
/// `let name: &str = function_name!();`
#[macro_export]
macro_rules! function_name {
    () => {{
        // Probe item nested in the caller; its type name is `<enclosing path>::f`.
        fn f() {}
        fn name_of<T>(_: T) -> &'static str {
            std::any::type_name::<T>()
        }
        let full = name_of(f);
        full.strip_suffix("::f").unwrap_or(full)
    }};
}
42
/// Convert the integer return code of a safe call into a `Result`.
///
/// `$expr` is evaluated exactly once. A code of `0` yields `Ok(())`; any other
/// value yields `Err($crate::error::Error::from_raised())` (presumably the error most
/// recently raised on the FFI side — TODO confirm against `Error::from_raised`).
///
/// Unlike `bail!`, this macro does not append file/line information to the traceback.
///
/// # Usage
/// `check_safe_call!(function(args))?;`
#[macro_export]
macro_rules! check_safe_call {
    ($expr:expr) => {{
        let status = $expr;
        if status != 0 {
            Err($crate::error::Error::from_raised())
        } else {
            Ok(())
        }
    }};
}
65
/// Return early from the enclosing function with an error that records where it was raised.
///
/// Formats the message with `format!`, builds a `$crate::error::Error` of the given kind
/// whose traceback context is `  File "<file>", line <line>, in <function>\n`, and then
/// executes `return Err(error)`. The enclosing function must therefore return a
/// `Result` whose error type is `Error`.
///
/// # Arguments
/// * `error_kind` - The kind of the error (forwarded to `Error::new`)
/// * `fmt` - The format string of the error message
/// * `args` - Optional format arguments; both positional (`x`) and named (`name = x`)
///   arguments are supported, and a trailing comma is accepted
///
/// # Usage
/// `bail!(kind, "unexpected value {}", v);`
#[macro_export]
macro_rules! bail {
    ($error_kind:expr, $fmt:expr $(, $($args:tt)*)?) => {{
        let context = format!(
            "  File \"{}\", line {}, in {}\n",
            file!(),
            line!(),
            $crate::function_name!()
        );
        // Arguments are forwarded as raw token trees: an `expr` matcher would turn
        // `name = value` into an assignment expression and break named arguments.
        return Err($crate::error::Error::new(
            $error_kind,
            &format!($fmt $(, $($args)*)?),
            &context,
        ));
    }};
}
89
/// Return early with an error (with file/line info attached) unless a condition holds.
///
/// If `cond` evaluates to `false`, the message is formatted and `bail!` is invoked,
/// which `return`s `Err(Error)` from the enclosing function. If `cond` is `true`,
/// nothing happens.
///
/// # Arguments
/// * `cond` - The condition that must hold
/// * `error_kind` - The kind of the error raised when `cond` is false
/// * `fmt` - The format string of the error message
/// * `args` - Optional format arguments; both positional and named (`name = x`)
///   arguments are supported, and a trailing comma is accepted
///
/// # Usage
/// `ensure!(x > 0, kind, "expected a positive value, got {}", x);`
#[macro_export]
macro_rules! ensure {
    ($cond:expr, $error_kind:expr, $fmt:expr $(, $($args:tt)*)?) => {{
        if !$cond {
            // Format here and hand `bail!` one pre-built message, so named arguments reach
            // `format!` as raw tokens instead of being re-parsed as assignment expressions.
            $crate::bail!($error_kind, "{}", format!($fmt $(, $($args)*)?));
        }
    }};
}
109
/// Append the caller's file/line/function to the traceback of an `Err`, passing `Ok` through.
///
/// # Arguments
/// * `result` - A `Result<T, Error>` expression; it is evaluated exactly once
///
/// # Returns
/// * `Result<T, Error>` - `Ok(value)` unchanged, or the error with a
///   `  File "<file>", line <line>, in <function>\n` entry appended via
///   `Error::with_appended_backtrace`
#[macro_export]
macro_rules! attach_context {
    ($result:expr) => {{
        match $result {
            Ok(value) => Ok(value),
            Err(error) => {
                let context = format!(
                    "  File \"{}\", line {}, in {}\n",
                    file!(),
                    line!(),
                    $crate::function_name!()
                );
                // Fully qualified so callers do not need `Error` in scope.
                Err($crate::error::Error::with_appended_backtrace(error, &context))
            }
        }
    }};
}
137
138// ----------------------------------------------------------------------------
139// Macros for any definitions
140// ----------------------------------------------------------------------------
141
/// Implement `TryFrom<AnyView<'_>>` and `TryFrom<Any>` for each listed concrete type.
///
/// Both conversions delegate to `$crate::any::TryFromTemp<$t>` and extract the converted
/// value with `TryFromTemp::into_value`; the error type is `$crate::error::Error`.
/// A trailing comma after the last type is accepted.
#[macro_export]
macro_rules! impl_try_from_any {
    ($($t:ty),* $(,)?) => {
        $(
            impl<'a> TryFrom<$crate::any::AnyView<'a>> for $t {
                type Error = $crate::error::Error;
                #[inline(always)]
                fn try_from(value: $crate::any::AnyView<'a>) -> Result<Self, Self::Error> {
                    <$crate::any::TryFromTemp<$t>>::try_from(value)
                        .map($crate::any::TryFromTemp::<$t>::into_value)
                }
            }

            impl TryFrom<$crate::any::Any> for $t {
                type Error = $crate::error::Error;
                #[inline(always)]
                fn try_from(value: $crate::any::Any) -> Result<Self, Self::Error> {
                    <$crate::any::TryFromTemp<$t>>::try_from(value)
                        .map($crate::any::TryFromTemp::<$t>::into_value)
                }
            }
        )*
    };
}
172
/// Implement `TryFrom<AnyView<'_>>` and `TryFrom<Any>` for a single-parameter generic type
/// such as `Option<T>`.
///
/// Both conversions delegate to `$crate::any::TryFromTemp<$generic_type<$param>>` and extract
/// the converted value with `TryFromTemp::into_value`, mirroring `impl_try_from_any!`.
///
/// # Arguments
/// * `$generic_type<$param>` - The generic type and the name of its type parameter. The
///   parameter may have any name (e.g. `Option<T>` or `Option<U>`); it is bounded by
///   `AnyCompatible`.
///
/// NOTE(review): `AnyCompatible` is written unqualified, so it must be in scope at the
/// invocation site — TODO confirm its module path and qualify it with `$crate::`.
#[macro_export]
macro_rules! impl_try_from_any_for_parametric {
    ($generic_type:ident<$param:ident>) => {
        impl<'a, $param: AnyCompatible> TryFrom<$crate::any::AnyView<'a>>
            for $generic_type<$param>
        {
            type Error = $crate::error::Error;
            #[inline(always)]
            fn try_from(value: $crate::any::AnyView<'a>) -> Result<Self, Self::Error> {
                // The helper type is spelled out inline: a nested `type` alias cannot name
                // the impl's generic parameter, which previously forced `$param` to be `T`.
                <$crate::any::TryFromTemp<$generic_type<$param>>>::try_from(value)
                    .map($crate::any::TryFromTemp::<$generic_type<$param>>::into_value)
            }
        }

        impl<$param: AnyCompatible> TryFrom<$crate::any::Any> for $generic_type<$param> {
            type Error = $crate::error::Error;
            #[inline(always)]
            fn try_from(value: $crate::any::Any) -> Result<Self, Self::Error> {
                <$crate::any::TryFromTemp<$generic_type<$param>>>::try_from(value)
                    .map($crate::any::TryFromTemp::<$generic_type<$param>>::into_value)
            }
        }
    };
}
198
/// Implement `IntoArgHolder` as the identity conversion for each listed type and for
/// shared references to it.
///
/// For every `$t`, the owned form holds itself (`Target = $t`) and `&'a $t` holds the
/// reference (`Target = &'a $t`); `into_arg_holder` returns `self` unchanged in both
/// cases. A trailing comma after the last type is accepted, as in `impl_try_from_any!`.
#[macro_export]
macro_rules! impl_into_arg_holder_default {
    ($($t:ty),* $(,)?) => {
        $(
            impl $crate::function_internal::IntoArgHolder for $t {
                type Target = $t;
                fn into_arg_holder(self) -> Self::Target {
                    self
                }
            }
            impl<'a> $crate::function_internal::IntoArgHolder for &'a $t {
                type Target = &'a $t;
                fn into_arg_holder(self) -> Self::Target {
                    self
                }
            }
        )*
    };
}
219
/// Implement `ArgIntoRef` for each listed type and for shared references to it.
///
/// In both cases `Target = $t` and `to_ref` returns a reference to the underlying `$t`:
/// the owned impl returns `self`, and the `&'a $t` impl returns the stored reference
/// (no extra borrow that relies on auto-deref). A trailing comma after the last type is
/// accepted, as in `impl_try_from_any!`.
#[macro_export]
macro_rules! impl_arg_into_ref {
    ($($t:ty),* $(,)?) => {
        $(
            impl $crate::function_internal::ArgIntoRef for $t {
                type Target = $t;
                fn to_ref(&self) -> &Self::Target {
                    self
                }
            }
            impl<'a> $crate::function_internal::ArgIntoRef for &'a $t {
                type Target = $t;
                fn to_ref(&self) -> &Self::Target {
                    // `self` is `&&'a $t`; `'a` outlives the borrow of `self`.
                    *self
                }
            }
        )*
    };
}
240
241// ----------------------------------------------------------------------------
242// Macros for function definitions
243// ----------------------------------------------------------------------------
244
/// Export a typed function as a C symbol that follows the tvm-ffi ABI.
///
/// Expands to `pub unsafe extern "C" fn __tvm_ffi_<name>(handle, args, num_args, result) -> i32`.
/// The generated function views `args` as `num_args` `AnyView`s and calls `$func` through
/// `call_packed_callable`. On success it stores the converted return value into `*result`
/// and returns `0`. On failure it records the error with `Error::set_raised` and returns `-1`.
///
/// # Arguments
/// * `$name` - Identifier appended to `__tvm_ffi_` to form the exported symbol name
/// * `$func` - The function to export
///
/// # Safety
/// The generated function is `unsafe` to call. When `num_args > 0` and `args` is non-null,
/// `args` must point to `num_args` initialized values that stay valid for the duration of
/// the call, and `result` must be valid for writes. A null `args` or a non-positive
/// `num_args` is treated as an empty argument list.
///
/// # Example
/// ```rust
/// use tvm_ffi::*;
///
/// fn add_one(x: i32) -> Result<i32> { Ok(x + 1) }
///
/// tvm_ffi_dll_export_typed_func!(add_one, add_one);
/// ```
#[macro_export]
macro_rules! tvm_ffi_dll_export_typed_func {
    ($name:ident, $func:expr) => {
        $crate::macros::paste::paste! {
            // `#[no_mangle]` is required so the symbol is preserved in a
            // `cdylib` and matches the `__tvm_ffi_<name>` naming convention
            // that `ffi.Module.load_from_file.<format>` looks up via
            // `GetSymbolWithSymbolPrefix`. Without it, the linker strips the
            // function from the output `.so`.
            //
            // Using plain `#[no_mangle]` (rather than `#[unsafe(no_mangle)]`,
            // which would require rustc >= 1.82) keeps the crate buildable
            // on older toolchains. Edition-2024 callers will see a
            // deprecation warning, which is harmless.
            //
            // The path-qualified `$crate::tvm_ffi_sys::…` reference (rather
            // than a bare `tvm_ffi_sys::…`) lets downstream crates use the
            // macro without having to add `tvm-ffi-sys` to their own
            // `[dependencies]`.
            #[no_mangle]
            pub unsafe extern "C" fn [<__tvm_ffi_ $name>](
                _handle: *mut std::ffi::c_void,
                args: *const $crate::tvm_ffi_sys::TVMFFIAny,
                num_args: i32,
                result: *mut $crate::tvm_ffi_sys::TVMFFIAny,
            ) -> i32 {
                // `slice::from_raw_parts` requires a non-null pointer even for an empty
                // slice (callers may pass NULL when there are no arguments), and a negative
                // count would wrap to a huge `usize`. Both cases become an empty list
                // instead of undefined behavior.
                let packed_args: &[$crate::any::AnyView<'_>] = if num_args <= 0 || args.is_null() {
                    &[]
                } else {
                    // SAFETY: `args` is non-null and, per this function's contract, points to
                    // `num_args` valid values. NOTE(review): relies on `AnyView` being
                    // layout-compatible with `TVMFFIAny`, as the original cast assumed.
                    std::slice::from_raw_parts(
                        args as *const $crate::any::AnyView<'_>,
                        num_args as usize,
                    )
                };
                let ret_value = $crate::function_internal::call_packed_callable($func, packed_args);
                match ret_value {
                    Ok(value) => {
                        // SAFETY: the caller guarantees `result` is valid for writes.
                        *result = $crate::any::Any::into_raw_ffi_any(value);
                        0
                    }
                    Err(error) => {
                        $crate::error::Error::set_raised(&error);
                        -1
                    }
                }
            }
        }
    };
}
302
/// Wrap a generic `Function` in a typed closure with a fixed argument list.
///
/// `into_typed_fn!(f, Fn(T0, T1, ...) -> R)` evaluates `f` once, moves it into a
/// `move` closure whose parameters have types `T0, T1, ...`, and returns that closure.
/// Each call of the closure:
/// 1. packs its arguments with `IntoArgHolderTuple::into_arg_holder_tuple`
///    (the zero-argument form passes `()` directly),
/// 2. invokes `f.call_tuple_with_len::<N, _>`, where `N` is the number of arguments,
/// 3. converts the returned value with `try_into`.
///
/// Errors from steps 2 and 3 are propagated with `?`, and the converted value is wrapped
/// in `Ok`, so `R` must be a `Result` type whose error type can absorb both errors.
///
/// # Arguments
/// * `$f` - The function to convert
/// * `$trait` - A trait name such as `Fn`; it is matched but not used in the expansion
/// * `($t0, $t1, ...)` - The argument types
/// * `$ret_ty` - The return type
///
/// # Example
/// ```rust
/// use tvm_ffi::*;
///
/// let func = Function::from_typed(|x: i32, y: i32| -> Result<i32> { Ok(x + y) });
/// let typed_func = into_typed_fn!(func, Fn(i32, &i32) -> Result<i32>);
/// let result = typed_func(10, &20).unwrap(); // Returns 30
/// assert_eq!(result, 30);
/// ```
///
/// Each argument type may be taken either by value or by reference. Passing by reference
/// is recommended for ObjectRef types such as `Tensor`, because the FFI mechanism passes
/// arguments by reference.
///
/// # Supported Argument Counts
/// Between 0 and 8 arguments; other counts match no rule.
#[macro_export]
macro_rules! into_typed_fn {
    // Zero arguments: the empty tuple is passed through as-is.
    ($f:expr, $trait:ident() -> $ret_ty:ty) => {{
        let callee = $f;
        move || -> $ret_ty {
            let raw = callee.call_tuple_with_len::<0, _>(())?;
            Ok(raw.try_into()?)
        }
    }};
    // One argument.
    ($f:expr, $trait:ident($t0:ty) -> $ret_ty:ty) => {{
        let callee = $f;
        move |arg0: $t0| -> $ret_ty {
            use $crate::function_internal::IntoArgHolderTuple;
            let held = (arg0,).into_arg_holder_tuple();
            let raw = callee.call_tuple_with_len::<1, _>(held)?;
            Ok(raw.try_into()?)
        }
    }};
    // Two arguments.
    ($f:expr, $trait:ident($t0:ty, $t1:ty) -> $ret_ty:ty) => {{
        let callee = $f;
        move |arg0: $t0, arg1: $t1| -> $ret_ty {
            use $crate::function_internal::IntoArgHolderTuple;
            let held = (arg0, arg1).into_arg_holder_tuple();
            let raw = callee.call_tuple_with_len::<2, _>(held)?;
            Ok(raw.try_into()?)
        }
    }};
    // Three arguments.
    ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty) -> $ret_ty:ty) => {{
        let callee = $f;
        move |arg0: $t0, arg1: $t1, arg2: $t2| -> $ret_ty {
            use $crate::function_internal::IntoArgHolderTuple;
            let held = (arg0, arg1, arg2).into_arg_holder_tuple();
            let raw = callee.call_tuple_with_len::<3, _>(held)?;
            Ok(raw.try_into()?)
        }
    }};
    // Four arguments.
    ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty) -> $ret_ty:ty) => {{
        let callee = $f;
        move |arg0: $t0, arg1: $t1, arg2: $t2, arg3: $t3| -> $ret_ty {
            use $crate::function_internal::IntoArgHolderTuple;
            let held = (arg0, arg1, arg2, arg3).into_arg_holder_tuple();
            let raw = callee.call_tuple_with_len::<4, _>(held)?;
            Ok(raw.try_into()?)
        }
    }};
    // Five arguments.
    ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty) -> $ret_ty:ty) => {{
        let callee = $f;
        move |arg0: $t0, arg1: $t1, arg2: $t2, arg3: $t3, arg4: $t4| -> $ret_ty {
            use $crate::function_internal::IntoArgHolderTuple;
            let held = (arg0, arg1, arg2, arg3, arg4).into_arg_holder_tuple();
            let raw = callee.call_tuple_with_len::<5, _>(held)?;
            Ok(raw.try_into()?)
        }
    }};
    // Six arguments.
    ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty) -> $ret_ty:ty) => {{
        let callee = $f;
        move |arg0: $t0, arg1: $t1, arg2: $t2, arg3: $t3, arg4: $t4, arg5: $t5| -> $ret_ty {
            use $crate::function_internal::IntoArgHolderTuple;
            let held = (arg0, arg1, arg2, arg3, arg4, arg5).into_arg_holder_tuple();
            let raw = callee.call_tuple_with_len::<6, _>(held)?;
            Ok(raw.try_into()?)
        }
    }};
    // Seven arguments.
    ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty)
        -> $ret_ty:ty) => {{
        let callee = $f;
        move |arg0: $t0,
              arg1: $t1,
              arg2: $t2,
              arg3: $t3,
              arg4: $t4,
              arg5: $t5,
              arg6: $t6|
              -> $ret_ty {
            use $crate::function_internal::IntoArgHolderTuple;
            let held = (arg0, arg1, arg2, arg3, arg4, arg5, arg6).into_arg_holder_tuple();
            let raw = callee.call_tuple_with_len::<7, _>(held)?;
            Ok(raw.try_into()?)
        }
    }};
    // Eight arguments.
    ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty)
        -> $ret_ty:ty) => {{
        let callee = $f;
        move |arg0: $t0,
              arg1: $t1,
              arg2: $t2,
              arg3: $t3,
              arg4: $t4,
              arg5: $t5,
              arg6: $t6,
              arg7: $t7|
              -> $ret_ty {
            use $crate::function_internal::IntoArgHolderTuple;
            let held =
                (arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7).into_arg_holder_tuple();
            let raw = callee.call_tuple_with_len::<8, _>(held)?;
            Ok(raw.try_into()?)
        }
    }};
}