Source code for tvm_ffi.stream

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Stream context."""

from ctypes import c_void_p
from typing import Any, Union

from . import core
from ._tensor import device


class StreamContext:
    """Represent a stream context in the FFI system.

    A ``StreamContext`` sets up the FFI environment stream for the body of a
    Python ``with`` statement. Entering the ``with`` scope installs the given
    stream for the given device and keeps whatever the setter returned. Exiting
    the scope reinstalls that kept value.

    Parameters
    ----------
    device : Device
        The device to which the stream belongs.

    stream : Union[int, c_void_p]
        The stream handle.

    See Also
    --------
    :py:func:`tvm_ffi.use_raw_stream`, :py:func:`tvm_ffi.use_torch_stream`

    """

    def __init__(self, device: core.Device, stream: Union[int, c_void_p]) -> None:
        """Store the DLPack device type, device index and the stream handle."""
        self.device_type = device.dlpack_device_type()
        self.device_id = device.index
        self.stream = stream

    def _install(self, handle: Any) -> Any:
        """Make ``handle`` the current FFI stream for this context's device.

        Returns whatever ``core._env_set_current_stream`` returns; callers in
        this class treat that as the previously current stream.
        """
        return core._env_set_current_stream(self.device_type, self.device_id, handle)

    def __enter__(self) -> "StreamContext":
        """Install ``self.stream`` and remember the stream it displaced."""
        self.prev_stream = self._install(self.stream)
        return self

    def __exit__(self, *args: Any) -> None:
        """Reinstall the remembered stream."""
        # The setter's return value is written back into prev_stream, exactly
        # as on entry; it is not used afterwards by this class.
        self.prev_stream = self._install(self.prev_stream)
try:
    import torch

    class TorchStreamContext:
        """Context manager that keeps Torch and FFI stream contexts in sync.

        On entry, the optional wrapped Torch context (for example
        ``torch.cuda.stream(s)`` or ``torch.cuda.graph(g)``) is entered first.
        Then an FFI :py:class:`StreamContext` is created for the stream that
        ``torch.cuda.current_stream()`` reports, and it is entered too. On exit
        the two contexts are unwound in reverse (LIFO) order. The Torch context
        is exited even if restoring the FFI stream fails.
        """

        def __init__(self, context: Any) -> None:
            """Initialize with an optional Torch stream/graph context wrapper.

            Parameters
            ----------
            context : Any
                A Torch context manager to enter before syncing the FFI
                stream, or ``None`` to use Torch's current stream as-is.
            """
            self.torch_context = context

        def __enter__(self) -> "TorchStreamContext":
            """Enter the Torch context (if any), then the matching FFI stream context."""
            if self.torch_context is not None:
                self.torch_context.__enter__()
            try:
                # Query the stream only after the Torch context is active so the
                # FFI side picks up the stream that context just made current.
                current_stream = torch.cuda.current_stream()
                self.ffi_context = StreamContext(
                    device(str(current_stream.device)), current_stream.cuda_stream
                )
                self.ffi_context.__enter__()
            except BaseException as err:
                # Don't leave the Torch context half-entered if the FFI side fails.
                if self.torch_context is not None:
                    self.torch_context.__exit__(type(err), err, err.__traceback__)
                raise
            return self

        def __exit__(self, *args: Any) -> None:
            """Exit the FFI stream context, then the Torch context (reverse of entry)."""
            try:
                self.ffi_context.__exit__(*args)
            finally:
                # Always unwind the Torch context, even if restoring the FFI
                # stream raised.
                if self.torch_context is not None:
                    self.torch_context.__exit__(*args)

    def use_torch_stream(context: Any = None) -> "TorchStreamContext":
        """Create an FFI stream context that follows a Torch stream or CUDA graph.

        The wrapped ``context`` is entered first. The FFI environment stream is
        then set to Torch's current stream. If ``context`` is ``None``, Torch's
        current stream is used directly.

        Parameters
        ----------
        context : Any = None
            The wrapped torch stream or cuda graph context.

        Returns
        -------
        context : tvm_ffi.TorchStreamContext
            The ffi stream context wrapping the torch stream context.

        Examples
        --------
        .. code-block:: python

            s = torch.cuda.Stream()
            with tvm_ffi.use_torch_stream(torch.cuda.stream(s)):
                ...

            g = torch.cuda.CUDAGraph()
            with tvm_ffi.use_torch_stream(torch.cuda.graph(g)):
                ...

        Note
        ----
        When working with a raw ``cudaStream_t`` handle, use
        :py:func:`tvm_ffi.use_raw_stream` instead.

        """
        return TorchStreamContext(context)

except ImportError:

    def use_torch_stream(context: Any = None) -> "TorchStreamContext":
        """Raise an informative error when Torch is unavailable."""
        raise ImportError("Cannot import torch")
def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]) -> StreamContext:
    """Create an FFI stream context with the given device and stream handle.

    Parameters
    ----------
    device : tvm_ffi.Device
        The device to which the stream belongs.

    stream : Union[int, c_void_p]
        The raw stream handle.

    Returns
    -------
    context : tvm_ffi.StreamContext
        The FFI stream context.

    Raises
    ------
    ValueError
        If ``stream`` is neither an ``int`` nor a ``ctypes.c_void_p``.

    Note
    ----
    When working with a Torch stream or CUDA graph, use
    :py:func:`tvm_ffi.use_torch_stream` instead.

    """
    # Only raw handles are accepted here; framework stream objects such as
    # torch.cuda.Stream must go through use_torch_stream.
    if not isinstance(stream, (int, c_void_p)):
        raise ValueError(
            "use_raw_stream only accepts int or c_void_p as stream input, "
            "try use_torch_stream when using torch.cuda.Stream or torch.cuda.graph"
        )
    return StreamContext(device, stream)