tvm_ffi/
device.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19use crate::error::Result;
20use crate::type_traits::AnyCompatible;
21use tvm_ffi_sys::dlpack::DLDevice;
22use tvm_ffi_sys::{TVMFFIAny, TVMFFITypeIndex as TypeIndex};
23use tvm_ffi_sys::{TVMFFIEnvGetStream, TVMFFIEnvSetStream, TVMFFIStreamHandle};
24
25/// Get the current stream for a device
26pub fn current_stream(device: &DLDevice) -> TVMFFIStreamHandle {
27    unsafe { TVMFFIEnvGetStream(device.device_type as i32, device.device_id) }
28}
29/// Call `f` with the device stream temporarily set to `stream`.
30///
31/// # Safety
32///
33/// `stream` must be a valid stream handle for the given device, or null.
34pub unsafe fn with_stream<T>(
35    device: &DLDevice,
36    stream: TVMFFIStreamHandle,
37    f: impl FnOnce() -> Result<T>,
38) -> Result<T> {
39    let mut prev_stream: TVMFFIStreamHandle = std::ptr::null_mut();
40    unsafe {
41        crate::check_safe_call!(TVMFFIEnvSetStream(
42            device.device_type as i32,
43            device.device_id,
44            stream,
45            &mut prev_stream as *mut TVMFFIStreamHandle
46        ))?;
47    }
48    let result = f()?;
49    unsafe {
50        crate::check_safe_call!(TVMFFIEnvSetStream(
51            device.device_type as i32,
52            device.device_id,
53            prev_stream,
54            std::ptr::null_mut()
55        ))?;
56    }
57    Ok(result)
58}
59
60/// AnyCompatible for DLDevice
61unsafe impl AnyCompatible for DLDevice {
62    fn type_str() -> String {
63        // make it consistent with c++ representation
64        "Device".to_string()
65    }
66
67    unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) {
68        data.type_index = TypeIndex::kTVMFFIDevice as i32;
69        data.small_str_len = 0;
70        data.data_union.v_uint64 = 0;
71        data.data_union.v_device = *src;
72    }
73
74    unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
75        data.type_index = TypeIndex::kTVMFFIDevice as i32;
76        data.small_str_len = 0;
77        data.data_union.v_int64 = 0;
78        data.data_union.v_device = src;
79    }
80
81    unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
82        return data.type_index == TypeIndex::kTVMFFIDevice as i32;
83    }
84
85    unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
86        data.data_union.v_device
87    }
88
89    unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
90        data.data_union.v_device
91    }
92
93    unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
94        if data.type_index == TypeIndex::kTVMFFIDevice as i32 {
95            Ok(data.data_union.v_device)
96        } else {
97            Err(())
98        }
99    }
100}