# Source code for tvm_ffi.cpp.nvrtc
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""NVRTC (NVIDIA Runtime Compilation) utilities for compiling CUDA source to CUBIN."""
from __future__ import annotations
from typing import Sequence
# [docs]
def nvrtc_compile( # noqa: PLR0912, PLR0915
source: str,
*,
name: str = "kernel.cu",
arch: str | None = None,
extra_opts: Sequence[str] | None = None,
) -> bytes:
"""Compile CUDA source code to CUBIN using NVRTC.
This function uses the NVIDIA Runtime Compilation (NVRTC) library to compile
CUDA C++ source code into a CUBIN binary that can be loaded and executed
using the CUDA Driver API.
Parameters
----------
source : str
The CUDA C++ source code to compile.
name : str, optional
The name to use for the source file (for error messages). Default: "kernel.cu"
arch : str, optional
The target GPU architecture (e.g., "sm_75", "sm_80", "sm_89"). If not specified,
attempts to auto-detect from the current GPU.
extra_opts : Sequence[str], optional
Additional compilation options to pass to NVRTC (e.g., ["-I/path/to/include", "-DDEFINE=1"]).
Returns
-------
bytes
The compiled CUBIN binary data.
Raises
------
RuntimeError
If NVRTC compilation fails or CUDA bindings are not available.
Example
-------
.. code-block:: python
from tvm_ffi.cpp import nvrtc
cuda_source = '''
extern "C" __global__ void add_one(float* x, float* y, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
y[idx] = x[idx] + 1.0f;
}
}
'''
cubin_bytes = nvrtc.nvrtc_compile(cuda_source)
# Use cubin_bytes with tvm_ffi.cpp.load_inline and embed_cubin parameter
"""
try:
from cuda.bindings import driver, nvrtc # type: ignore[import-not-found] # noqa: PLC0415
except ImportError as e:
raise RuntimeError(
"CUDA bindings not available. Install with: pip install cuda-python"
) from e
# Auto-detect architecture if not specified
if arch is None:
try:
# Initialize CUDA driver API
(result,) = driver.cuInit(0)
if result != driver.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to initialize CUDA driver: {result}")
# Get current device
result, device = driver.cuCtxGetDevice()
if result != driver.CUresult.CUDA_SUCCESS:
# Try to get device 0 if no context exists
device = 0
# Get compute capability
result, major = driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device
)
if result != driver.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to get compute capability major: {result}")
result, minor = driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device
)
if result != driver.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to get compute capability minor: {result}")
arch = f"sm_{major}{minor}"
except Exception as e:
# Fallback to a reasonable default
raise RuntimeError(
f"Failed to auto-detect GPU architecture: {e}. "
"Please specify 'arch' parameter explicitly."
) from e
# Create program
result, prog = nvrtc.nvrtcCreateProgram(str.encode(source), str.encode(name), 0, None, None)
if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError(f"Failed to create NVRTC program: {result}")
# Compile options
opts = [
b"--gpu-architecture=" + arch.encode(),
b"-default-device",
]
# Add extra options if provided
if extra_opts:
opts.extend([opt.encode() if isinstance(opt, str) else opt for opt in extra_opts])
# Compile
(result,) = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)
if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
# Get compilation log
result_log, log_size = nvrtc.nvrtcGetProgramLogSize(prog)
if result_log == nvrtc.nvrtcResult.NVRTC_SUCCESS and log_size > 0:
log_buf = b" " * log_size
(result_log,) = nvrtc.nvrtcGetProgramLog(prog, log_buf)
if result_log == nvrtc.nvrtcResult.NVRTC_SUCCESS:
error_msg = f"NVRTC compilation failed:\n{log_buf.decode('utf-8')}"
else:
error_msg = f"NVRTC compilation failed (couldn't get log): {result}"
else:
error_msg = f"NVRTC compilation failed: {result}"
nvrtc.nvrtcDestroyProgram(prog)
raise RuntimeError(error_msg)
# Get CUBIN
result, cubin_size = nvrtc.nvrtcGetCUBINSize(prog)
if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
nvrtc.nvrtcDestroyProgram(prog)
raise RuntimeError(f"Failed to get CUBIN size from NVRTC: {result}")
cubin_buf = b" " * cubin_size
(result,) = nvrtc.nvrtcGetCUBIN(prog, cubin_buf)
if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
nvrtc.nvrtcDestroyProgram(prog)
raise RuntimeError(f"Failed to get CUBIN from NVRTC: {result}")
# Clean up
nvrtc.nvrtcDestroyProgram(prog)
return cubin_buf