1use crate::error::Result;
20use crate::type_traits::AnyCompatible;
21use tvm_ffi_sys::dlpack::DLDevice;
22use tvm_ffi_sys::{TVMFFIAny, TVMFFITypeIndex as TypeIndex};
23use tvm_ffi_sys::{TVMFFIEnvGetStream, TVMFFIEnvSetStream, TVMFFIStreamHandle};
24
25pub fn current_stream(device: &DLDevice) -> TVMFFIStreamHandle {
27 unsafe { TVMFFIEnvGetStream(device.device_type as i32, device.device_id) }
28}
29pub unsafe fn with_stream<T>(
35 device: &DLDevice,
36 stream: TVMFFIStreamHandle,
37 f: impl FnOnce() -> Result<T>,
38) -> Result<T> {
39 let mut prev_stream: TVMFFIStreamHandle = std::ptr::null_mut();
40 unsafe {
41 crate::check_safe_call!(TVMFFIEnvSetStream(
42 device.device_type as i32,
43 device.device_id,
44 stream,
45 &mut prev_stream as *mut TVMFFIStreamHandle
46 ))?;
47 }
48 let result = f()?;
49 unsafe {
50 crate::check_safe_call!(TVMFFIEnvSetStream(
51 device.device_type as i32,
52 device.device_id,
53 prev_stream,
54 std::ptr::null_mut()
55 ))?;
56 }
57 Ok(result)
58}
59
60unsafe impl AnyCompatible for DLDevice {
62 fn type_str() -> String {
63 "Device".to_string()
65 }
66
67 unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) {
68 data.type_index = TypeIndex::kTVMFFIDevice as i32;
69 data.small_str_len = 0;
70 data.data_union.v_uint64 = 0;
71 data.data_union.v_device = *src;
72 }
73
74 unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
75 data.type_index = TypeIndex::kTVMFFIDevice as i32;
76 data.small_str_len = 0;
77 data.data_union.v_int64 = 0;
78 data.data_union.v_device = src;
79 }
80
81 unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
82 return data.type_index == TypeIndex::kTVMFFIDevice as i32;
83 }
84
85 unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
86 data.data_union.v_device
87 }
88
89 unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
90 data.data_union.v_device
91 }
92
93 unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
94 if data.type_index == TypeIndex::kTVMFFIDevice as i32 {
95 Ok(data.data_union.v_device)
96 } else {
97 Err(())
98 }
99 }
100}