tvm_ffi/
device.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19use crate::error::Result;
20use crate::type_traits::AnyCompatible;
21use tvm_ffi_sys::dlpack::DLDevice;
22use tvm_ffi_sys::{TVMFFIAny, TVMFFITypeIndex as TypeIndex};
23use tvm_ffi_sys::{TVMFFIEnvGetStream, TVMFFIEnvSetStream, TVMFFIStreamHandle};
24
25/// Get the current stream for a device
26pub fn current_stream(device: &DLDevice) -> TVMFFIStreamHandle {
27    unsafe { TVMFFIEnvGetStream(device.device_type as i32, device.device_id) }
28}
29/// call f with the stream set to the given stream
30pub fn with_stream<T>(
31    device: &DLDevice,
32    stream: TVMFFIStreamHandle,
33    f: impl FnOnce() -> Result<T>,
34) -> Result<T> {
35    let mut prev_stream: TVMFFIStreamHandle = std::ptr::null_mut();
36    unsafe {
37        crate::check_safe_call!(TVMFFIEnvSetStream(
38            device.device_type as i32,
39            device.device_id,
40            stream,
41            &mut prev_stream as *mut TVMFFIStreamHandle
42        ))?;
43    }
44    let result = f()?;
45    unsafe {
46        crate::check_safe_call!(TVMFFIEnvSetStream(
47            device.device_type as i32,
48            device.device_id,
49            prev_stream,
50            std::ptr::null_mut()
51        ))?;
52    }
53    Ok(result)
54}
55
56/// AnyCompatible for DLDevice
57unsafe impl AnyCompatible for DLDevice {
58    fn type_str() -> String {
59        // make it consistent with c++ representation
60        "Device".to_string()
61    }
62
63    unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) {
64        data.type_index = TypeIndex::kTVMFFIDevice as i32;
65        data.small_str_len = 0;
66        data.data_union.v_uint64 = 0;
67        data.data_union.v_device = *src;
68    }
69
70    unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
71        data.type_index = TypeIndex::kTVMFFIDevice as i32;
72        data.small_str_len = 0;
73        data.data_union.v_int64 = 0;
74        data.data_union.v_device = src;
75    }
76
77    unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
78        return data.type_index == TypeIndex::kTVMFFIDevice as i32;
79    }
80
81    unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
82        data.data_union.v_device
83    }
84
85    unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
86        data.data_union.v_device
87    }
88
89    unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
90        if data.type_index == TypeIndex::kTVMFFIDevice as i32 {
91            Ok(data.data_union.v_device)
92        } else {
93            Err(())
94        }
95    }
96}