1use crate::error::Result;
20use crate::type_traits::AnyCompatible;
21use tvm_ffi_sys::dlpack::DLDevice;
22use tvm_ffi_sys::{TVMFFIAny, TVMFFITypeIndex as TypeIndex};
23use tvm_ffi_sys::{TVMFFIEnvGetStream, TVMFFIEnvSetStream, TVMFFIStreamHandle};
24
25pub fn current_stream(device: &DLDevice) -> TVMFFIStreamHandle {
27 unsafe { TVMFFIEnvGetStream(device.device_type as i32, device.device_id) }
28}
29pub fn with_stream<T>(
31 device: &DLDevice,
32 stream: TVMFFIStreamHandle,
33 f: impl FnOnce() -> Result<T>,
34) -> Result<T> {
35 let mut prev_stream: TVMFFIStreamHandle = std::ptr::null_mut();
36 unsafe {
37 crate::check_safe_call!(TVMFFIEnvSetStream(
38 device.device_type as i32,
39 device.device_id,
40 stream,
41 &mut prev_stream as *mut TVMFFIStreamHandle
42 ))?;
43 }
44 let result = f()?;
45 unsafe {
46 crate::check_safe_call!(TVMFFIEnvSetStream(
47 device.device_type as i32,
48 device.device_id,
49 prev_stream,
50 std::ptr::null_mut()
51 ))?;
52 }
53 Ok(result)
54}
55
56unsafe impl AnyCompatible for DLDevice {
58 fn type_str() -> String {
59 "Device".to_string()
61 }
62
63 unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) {
64 data.type_index = TypeIndex::kTVMFFIDevice as i32;
65 data.small_str_len = 0;
66 data.data_union.v_uint64 = 0;
67 data.data_union.v_device = *src;
68 }
69
70 unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
71 data.type_index = TypeIndex::kTVMFFIDevice as i32;
72 data.small_str_len = 0;
73 data.data_union.v_int64 = 0;
74 data.data_union.v_device = src;
75 }
76
77 unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
78 return data.type_index == TypeIndex::kTVMFFIDevice as i32;
79 }
80
81 unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
82 data.data_union.v_device
83 }
84
85 unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
86 data.data_union.v_device
87 }
88
89 unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
90 if data.type_index == TypeIndex::kTVMFFIDevice as i32 {
91 Ok(data.data_union.v_device)
92 } else {
93 Err(())
94 }
95 }
96}