CUBIN Launcher Guide#
This guide demonstrates how to load and launch CUDA kernels from CUBIN (CUDA Binary) modules using TVM-FFI. The CUBIN launcher enables you to execute pre-compiled or runtime-compiled CUDA kernels efficiently through the CUDA Runtime API or Driver API.
Overview#
TVM-FFI provides utilities for loading and launching CUDA kernels from CUBIN modules. The implementation supports both CUDA Runtime API (default for CUDA >= 12.8) and CUDA Driver API.
Runtime API (CUDA >= 12.8):
cudaLibraryLoadData() - Load CUBIN from memory buffer
cudaLibraryGetKernel() - Get kernel handle by name
cudaLaunchKernel() - Launch kernel with grid/block dimensions
Driver API:
cuLibraryLoadData() - Load CUBIN from memory buffer
cuLibraryGetKernel() - Get kernel handle by name
cuLaunchKernel() - Launch kernel with grid/block dimensions
Customization:
By default, the implementation uses the Runtime API if compiled with CUDA >= 12.8, falling back to the Driver API for older versions. You can force the usage of the Driver API (or Runtime API) by defining the macro TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API (set to 1 for Driver API, 0 for Runtime API) before including the header.
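For example, to force the Driver API path, define the macro before the header is included:
// Force the Driver API (set to 0 to force the Runtime API instead).
#define TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API 1
#include <tvm/ffi/extra/cuda/cubin_launcher.h>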
Warning
CMAKE_CUDA_RUNTIME_LIBRARY and Driver API
When using CMake, the default behavior (if CMAKE_CUDA_RUNTIME_LIBRARY is not set) is to link against the CUDA Runtime Library (cudart). TVM-FFI’s CMake utility automatically defaults this variable to Shared if it is undefined. This introduces a dependency on the CUDA runtime version, requiring the system’s driver to be compatible with that runtime version.
If you intend to use the Driver API only (e.g. by setting TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API=1) to avoid this runtime dependency:
You must explicitly set CMAKE_CUDA_RUNTIME_LIBRARY to None in your CMake configuration to prevent linking cudart.
You must manually link your target against the CUDA Driver library (usually cuda on Linux/Windows, or CUDA::cuda_driver provided by CMake’s FindCUDAToolkit).
This ensures your application relies solely on the widely compatible CUDA Driver API (libcuda.so.1).
The implementation is in tvm/ffi/extra/cuda/cubin_launcher.h and provides:
tvm::ffi::CubinModule: RAII wrapper for loading CUBIN modules from memory
tvm::ffi::CubinKernel: Handle for launching CUDA kernels with specified parameters
TVM_FFI_EMBED_CUBIN: Macro for embedding CUBIN data at compile time (legacy / object-linking approach)
TVM_FFI_EMBED_CUBIN_FROM_BYTES: Macro for embedding CUBIN data from byte arrays (manual embedding approach)
TVM_FFI_EMBED_CUBIN_GET_KERNEL: Macro for retrieving kernels from embedded CUBIN
The CUBIN launcher supports:
Loading CUBIN from memory (embedded data or runtime-generated)
Multi-GPU execution using CUDA primary contexts
Kernel parameter management and launch configuration
Integration with NVRTC, Triton, and other CUDA compilation tools
Build Integration:
TVM-FFI provides convenient tools for embedding CUBIN data at build time:
CMake utilities (cmake/Utils/EmbedCubin.cmake): Functions for compiling CUDA to CUBIN/FATBIN and embedding it into C++ code or linking it.
Python utility (python -m tvm_ffi.utils.embed_cubin): Command-line tool for embedding CUBIN into object files.
Python API (tvm_ffi.cpp.load_inline()): Runtime embedding via the embed_cubin parameter.
Python Usage#
Basic Workflow#
The typical workflow for launching CUBIN kernels from Python involves:
Generate CUBIN: Compile your CUDA kernel to CUBIN format
Define C++ Wrapper: Write C++ code to load and launch the kernel
Load Module: Use tvm_ffi.cpp.load_inline() with the embed_cubin parameter
Call Kernel: Invoke the kernel function from Python
Example: NVRTC Compilation#
Here’s a complete example using NVRTC to compile CUDA source at runtime.
Step 1: Compile CUDA source to CUBIN using NVRTC
from tvm_ffi.cpp import nvrtc

# Define CUDA kernels
cuda_source = """
extern "C" __global__ void add_one(float* x, float* y, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    y[idx] = x[idx] + 1.0f;
  }
}

extern "C" __global__ void mul_two(float* x, float* y, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    y[idx] = x[idx] * 2.0f;
  }
}
"""
# Compile CUDA source to CUBIN using NVRTC
print("Compiling CUDA kernels to CUBIN using NVRTC...")
cubin_bytes = nvrtc.nvrtc_compile(cuda_source, name="kernels.cu")
print(f"Compiled CUBIN: {len(cubin_bytes)} bytes\n")
Step 2: Define C++ wrapper with embedded CUBIN
# Define C++ code inline to launch the CUDA kernels using embedded CUBIN
sources = """
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>

// Embed CUBIN module with name "nvrtc_cubin"
TVM_FFI_EMBED_CUBIN(nvrtc_cubin);

namespace nvrtc_loader {

void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Get kernel from embedded CUBIN (cached in static variable for efficiency)
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(nvrtc_cubin, "add_one");
  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();
  // Prepare kernel arguments
  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
                  reinterpret_cast<void*>(&n)};
  // Launch configuration
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);
  // Get CUDA stream
  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
  // Launch kernel
  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Get kernel from embedded CUBIN (cached in static variable for efficiency)
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(nvrtc_cubin, "mul_two");
  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();
  // Prepare kernel arguments
  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
                  reinterpret_cast<void*>(&n)};
  // Launch configuration
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);
  // Get CUDA stream
  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
  // Launch kernel
  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

}  // namespace nvrtc_loader

TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, nvrtc_loader::AddOne);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, nvrtc_loader::MulTwo);
"""
print("Compiling C++ sources with tvm_ffi.cpp.load_inline...")
mod = cpp.load_inline(
"nvrtc_loader",
cuda_sources=sources,
embed_cubin={"nvrtc_cubin": cubin_bytes},
)
print("Successfully compiled and loaded C++ sources with embedded CUBIN\n")
Key Points:
The embed_cubin parameter is a dictionary mapping CUBIN names to their binary data
CUBIN names in embed_cubin must match the names used in TVM_FFI_EMBED_CUBIN
Use the cuda_sources parameter (instead of cpp_sources) to automatically link with CUDA libraries
The C++ wrapper handles device management, stream handling, and kernel launching
Example: Using Triton Kernels#
You can compile Triton kernels to CUBIN and launch them through TVM-FFI.
Step 1: Define and compile Triton kernel
import torch
import triton
import triton.language as tl


# Define the kernel dynamically
@triton.jit
def square_kernel(X_ptr, Y_ptr, n, BLOCK: tl.constexpr = 1024):  # noqa
    pid = tl.program_id(0)
    start = pid * BLOCK
    offsets = start + tl.arange(0, BLOCK)
    mask = offsets < n
    x = tl.load(X_ptr + offsets, mask=mask, other=0.0)
    y = x * x
    tl.store(Y_ptr + offsets, y, mask=mask)


# Trigger kernel compilation by doing a dummy call
x_dummy = torch.ones(1024, dtype=torch.float32, device="cuda")
y_dummy = torch.empty(1024, dtype=torch.float32, device="cuda")
square_kernel[1, 1](x_dummy, y_dummy, 1024)

# Extract compiled CUBIN from the device cache
device_caches = square_kernel.device_caches
device_id = next(iter(device_caches.keys()))
cache_tuple = device_caches[device_id]
compiled_kernel = next(iter(cache_tuple[0].values()))

# Get CUBIN bytes
cubin_bytes = compiled_kernel.kernel
Step 2: Define C++ wrapper to launch the Triton kernel
# Define C++ code inline to load and launch the Triton kernel using embedded CUBIN
sources = """
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
// Embed CUBIN module with name "triton_cubin"
TVM_FFI_EMBED_CUBIN(triton_cubin);
namespace triton_loader {
void LaunchSquare(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Get kernel from embedded CUBIN (cached in static variable for efficiency)
static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_cubin, "square_kernel");
TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Sizes must match";
uint32_t n = static_cast<uint32_t>(x.size(0));
void* x_ptr = x.data_ptr();
void* y_ptr = y.data_ptr();
uint64_t dummy_ptr = 0;
// Workaround for Triton extra params: pass dummy addresses for unused parameters
void* args[] = {&x_ptr, &y_ptr, &n, &dummy_ptr, &dummy_ptr};
// Kernel was compiled with .reqntid 128, not 1024
tvm::ffi::dim3 grid((n + 127) / 128);
tvm::ffi::dim3 block(128);
DLDevice device = x.device();
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
cudaError_t result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
}
} // namespace triton_loader
TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_square, triton_loader::LaunchSquare);
"""
print("Compiling C++ sources with tvm_ffi.cpp.load_inline...")
# Find CUDA include path
mod = cpp.load_inline(
"triton_loader",
cuda_sources=sources,
embed_cubin={"triton_cubin": cubin_bytes},
)
print("Successfully compiled and loaded C++ sources with embedded CUBIN\n")
Note
Triton kernels may require extra dummy parameters in the argument list. Check the compiled kernel’s signature to determine the exact parameter count needed.
C++ Usage#
Embedding CUBIN at Compile Time#
The most convenient way to embed CUBIN/FATBIN data in C++ is using the TVM-FFI build utilities. There are three main approaches:
Object Linking (Standard): Use CMake utilities to compile and link the CUBIN data.
Header Inclusion (Portable): Convert CUBIN to a C header file using bin2c.
C++ Embedding (Modern): Use the #embed directive (or compiler extensions).
Method 1: Object Linking (Standard)
This approach uses CMake utilities to compile and link the CUBIN data. It works across all supported compilers and handles the low-level details of object file generation and symbol naming.
// Embed CUBIN module with name "env"
// This creates the necessary symbols and singleton struct for accessing the embedded CUBIN
TVM_FFI_EMBED_CUBIN(env);
Method 2: Header Inclusion (Portable)
You can use tools like bin2c to generate a header file containing the byte array and include it.
TVM_FFI_EMBED_CUBIN_FROM_BYTES(env, imageBytes);
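For illustration, the byte array here is assumed to come from a bin2c-generated header; the header name, array name, and exact bin2c invocation below are assumptions rather than TVM-FFI conventions:
// Assumed to be generated by something like:
//   bin2c --name imageBytes kernel.fatbin > kernel_fatbin.h
#include "kernel_fatbin.h"

TVM_FFI_EMBED_CUBIN_FROM_BYTES(env, imageBytes);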
Method 3: C++ Embedding (Modern)
Using the #embed directive (standardized in C23/C++26 and available as an extension in recent Clang and GCC) allows you to include the binary data directly.
constexpr unsigned char image[]{
    // clang >= 20 or gcc >= 14
    #embed "kernel_fatbin.fatbin"
};
TVM_FFI_EMBED_CUBIN_FROM_BYTES(env, image);
Key Points:
Use static auto kernel to cache the kernel lookup for efficiency
Kernel arguments must be pointers to the actual values (use & to take addresses)
tvm::ffi::dim3 supports 1D, 2D, or 3D configurations: dim3(x), dim3(x, y), dim3(x, y, z)
TVMFFIEnvGetStream retrieves the correct CUDA stream for the device
Always check kernel launch results with TVM_FFI_CHECK_CUDA_ERROR (which checks CUDA Runtime API or Driver API errors depending on configuration)
Loading CUBIN at Runtime#
You can also load CUBIN modules dynamically from memory:
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/string.h>

#include <cstdint>
#include <memory>

namespace cubin_dynamic {

// Global CUBIN module and kernels (loaded dynamically)
static std::unique_ptr<tvm::ffi::CubinModule> g_cubin_module;
static std::unique_ptr<tvm::ffi::CubinKernel> g_add_one_kernel;
static std::unique_ptr<tvm::ffi::CubinKernel> g_mul_two_kernel;

/*!
 * \brief Set CUBIN module from binary data.
 * \param cubin CUBIN binary data as Bytes object.
 */
void SetCubin(const tvm::ffi::Bytes& cubin) {
  // Load CUBIN module from memory
  g_cubin_module = std::make_unique<tvm::ffi::CubinModule>(cubin);
  g_add_one_kernel = std::make_unique<tvm::ffi::CubinKernel>((*g_cubin_module)["add_one_cuda"]);
  g_mul_two_kernel = std::make_unique<tvm::ffi::CubinKernel>((*g_cubin_module)["mul_two_cuda"]);
}

/*!
 * \brief Launch add_one_cuda kernel on input tensor.
 * \param x Input tensor (float32, 1D)
 * \param y Output tensor (float32, 1D, same shape as x)
 */
void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  TVM_FFI_CHECK(g_cubin_module != nullptr, RuntimeError)
      << "CUBIN module not loaded. Call set_cubin first.";
  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();
  // Prepare kernel arguments
  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
                  reinterpret_cast<void*>(&n)};
  // Launch configuration
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);
  // Get CUDA stream
  DLDevice device = x.device();
  tvm::ffi::cuda_api::StreamHandle stream = static_cast<tvm::ffi::cuda_api::StreamHandle>(
      TVMFFIEnvGetStream(device.device_type, device.device_id));
  // Launch kernel
  tvm::ffi::cuda_api::ResultType result = g_add_one_kernel->Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

}  // namespace cubin_dynamic
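// Export the functions so they can be called from Python. The exported names
// ("set_cubin", "add_one") are illustrative and follow the earlier examples.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_cubin, cubin_dynamic::SetCubin);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, cubin_dynamic::AddOne);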
Embedding CUBIN with CMake Utilities#
TVM-FFI provides CMake utility functions that simplify the CUBIN embedding process. This is the recommended approach for CMake-based projects.
Using CMake Utilities:
# Step 1: Compile kernel.cu to FATBIN using the add_tvm_ffi_fatbin utility or the
# CUDA_FATBIN_COMPILATION target property
set(CMAKE_CUDA_ARCHITECTURES 75;80;86;89;90;100;120)
if (CMAKE_VERSION VERSION_LESS "3.27.0")
  add_tvm_ffi_fatbin(kernel_fatbin CUDA src/kernel.cu)
else ()
  add_library(kernel_fatbin OBJECT src/kernel.cu)
  set_target_properties(kernel_fatbin PROPERTIES CUDA_FATBIN_COMPILATION ON)
endif ()

# Step 2: Build the lib_embedded shared library
add_library(lib_embedded SHARED src/lib_embedded.cc)
target_link_libraries(lib_embedded PRIVATE tvm_ffi::header tvm_ffi::shared)
set_target_properties(lib_embedded PROPERTIES POSITION_INDEPENDENT_CODE ON)

# Step 3: Link against the CUDA Driver API or Runtime API based on config
if (CMAKE_TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API)
  add_compile_definitions(TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API=1)
  target_link_libraries(lib_embedded PRIVATE cuda)
else ()
  target_link_libraries(lib_embedded PRIVATE CUDA::cudart)
endif ()

# Step 4: Embed the FATBIN into the shared library just defined, using the
# tvm_ffi_embed_bin_into utility. This creates the symbol __tvm_ffi__cubin_env (local).
tvm_ffi_embed_bin_into(lib_embedded SYMBOL env BIN "$<TARGET_OBJECTS:kernel_fatbin>")

set_target_properties(
  lib_embedded
  PROPERTIES PREFIX ""
             SUFFIX ".so"
             LINKER_LANGUAGE CXX
)
Available CMake Functions:
add_tvm_ffi_cubin(<target> CUDA <source>): Creates an object library that compiles CUDA source to CUBIN format. This is a compatibility wrapper; for CMake >= 3.27, you can use the standard CUDA_CUBIN_COMPILATION property.
add_tvm_ffi_fatbin(<target> CUDA <source>): Creates an object library that compiles CUDA source to FATBIN format. This is a compatibility wrapper; for CMake >= 3.27, you can use the standard CUDA_FATBIN_COMPILATION property.
tvm_ffi_embed_bin_into(<target> SYMBOL <symbol> BIN <bin_file>): Embeds a CUBIN/FATBIN file into an existing object library target. This works by linking the binary data into the target, allowing access via TVM_FFI_EMBED_CUBIN(<name>).
  target: The target to embed into (must be an object library or have object files).
  symbol: Symbol name to use (must match TVM_FFI_EMBED_CUBIN(symbol)).
  BIN: Path to the CUBIN/FATBIN file (e.g., from $<TARGET_OBJECTS:...>).
Note
When including cmake/Utils/EmbedCubin.cmake, if CMAKE_CUDA_RUNTIME_LIBRARY is not set, it defaults to Shared.
This prevents static linking of cudart, which requires an exact driver version match.
If you intend to use the Driver API only (e.g., via TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API=1),
explicitly set CMAKE_CUDA_RUNTIME_LIBRARY to None in your CMake configuration before including this utility to avoid linking against the CUDA runtime library, and link your target against the CUDA Driver API instead.
Embedding CUBIN with Python Utility#
For more advanced use cases or non-CMake build systems, you can use the Python command-line utility to embed CUBIN data into existing object files.
Command-Line Usage:
# Step 1: Compile C++ source to object file
g++ -c -fPIC -std=c++17 -I/path/to/tvm-ffi/include mycode.cc -o mycode.o
# Step 2: Embed CUBIN into the object file
python -m tvm_ffi.utils.embed_cubin \
    --output-obj mycode_with_cubin.o \
    --input-obj mycode.o \
    --cubin kernel.cubin \
    --name my_kernels
# Step 3: Link into final library
g++ -o mylib.so -shared mycode_with_cubin.o -lcudart
Python API:
from pathlib import Path
from tvm_ffi.utils.embed_cubin import embed_cubin
embed_cubin(
    cubin_path=Path("kernel.cubin"),
    input_obj_path=Path("mycode.o"),
    output_obj_path=Path("mycode_with_cubin.o"),
    name="my_kernels",
    verbose=True,  # Optional: print detailed progress
)
The Python utility performs these steps:
Creates an intermediate CUBIN object file using ld -r -b binary
Adds a .note.GNU-stack section for security
Renames symbols to match the TVM-FFI format (__tvm_ffi__cubin_<name>)
Merges with the input object file using relocatable linking
Localizes symbols to prevent conflicts when multiple object files use the same name
Manual CUBIN Embedding#
For reference, here’s how to manually embed CUBIN using objcopy and ld:
Step 1: Compile CUDA kernel to CUBIN
nvcc --cubin -arch=sm_75 kernel.cu -o kernel.cubin
Step 2: Convert CUBIN to object file
ld -r -b binary -o kernel_data.o kernel.cubin
Step 3: Rename symbols with objcopy
objcopy --rename-section .data=.rodata,alloc,load,readonly,data,contents \
        --redefine-sym _binary_kernel_cubin_start=__tvm_ffi__cubin_my_kernels \
        --redefine-sym _binary_kernel_cubin_end=__tvm_ffi__cubin_my_kernels_end \
        kernel_data.o
Step 4: Link with your library
g++ -o mylib.so -shared mycode.cc kernel_data.o -Wl,-z,noexecstack -lcudart
The symbol names must match the name used in TVM_FFI_EMBED_CUBIN.
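For example, the redefined symbols above correspond to the following declaration in mycode.cc:
// Matches __tvm_ffi__cubin_my_kernels produced by the objcopy step above.
TVM_FFI_EMBED_CUBIN(my_kernels);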
When to Use Each Approach:
CMake utilities: Best for CMake-based projects, provides cleanest integration (recommended)
Python utility: Best for custom build systems, Makefile-based projects, or advanced workflows (recommended)
Manual objcopy: Low-level approach, useful for understanding the process or debugging (only for customized use cases)
Advanced Topics#
Multi-GPU Support#
The CUBIN launcher automatically handles multi-GPU execution through CUDA primary contexts. Kernels will execute on the device associated with the input tensors:
void MultiGPUExample(tvm::ffi::TensorView x_gpu0, tvm::ffi::TensorView x_gpu1) {
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "process");
  // Launch on GPU 0 (device determined by x_gpu0.device())
  LaunchOnDevice(kernel, x_gpu0);
  // Launch on GPU 1 (device determined by x_gpu1.device())
  LaunchOnDevice(kernel, x_gpu1);
}
The tvm::ffi::CubinKernel automatically uses the device context from the input tensors.
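LaunchOnDevice above is a user-defined helper rather than a TVM-FFI API. A minimal sketch, assuming the Runtime API configuration from the earlier examples and a kernel that takes a data pointer and a length:
// Hypothetical helper: launches `kernel` on whichever device owns `x`.
// The kernel signature (float* data, int64_t n) is an assumption for illustration.
void LaunchOnDevice(tvm::ffi::CubinKernel& kernel, tvm::ffi::TensorView x) {
  int64_t n = x.size(0);
  void* ptr = x.data_ptr();
  void* args[] = {&ptr, &n};
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);
  // TVMFFIEnvGetStream returns the stream for the tensor's own device,
  // so each call targets the correct GPU.
  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}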
Kernel Launch Configuration#
When writing the C++ wrapper, important considerations include:
Grid/Block Dimensions: Use tvm::ffi::dim3 for 1D, 2D, or 3D configurations
  1D: dim3(x) → (x, 1, 1)
  2D: dim3(x, y) → (x, y, 1)
  3D: dim3(x, y, z) → (x, y, z)
Kernel Arguments: Must be pointers to the actual values
  For device pointers: void* ptr = tensor.data_ptr(); args[] = {&ptr}
  For scalars: int n = 42; args[] = {&n}
Stream Management: Use TVMFFIEnvGetStream to get the correct CUDA stream for synchronization with DLPack tensors
Error Checking: Always use TVM_FFI_CHECK_CUDA_ERROR to validate CUDA Runtime API results
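As a concrete illustration of these points, here is a sketch that packs mixed pointer/scalar arguments and uses a 2D launch configuration; the kernel and its (float* data, int width, int height) signature are hypothetical:
// Hypothetical 2D kernel: __global__ void process2d(float* data, int width, int height)
void Launch2D(tvm::ffi::CubinKernel& kernel, tvm::ffi::TensorView img, cudaStream_t stream) {
  void* data = img.data_ptr();
  int width = static_cast<int>(img.size(1));
  int height = static_cast<int>(img.size(0));
  // Each entry points at the argument value, in kernel-signature order.
  void* args[] = {&data, &width, &height};
  tvm::ffi::dim3 block(16, 16);                                // (16, 16, 1)
  tvm::ffi::dim3 grid((width + 15) / 16, (height + 15) / 16);  // 2D grid
  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}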
Integration with Different Compilers#
The CUBIN launcher works with various CUDA compilation tools:
NVCC: Standard NVIDIA compiler, produces highly optimized CUBIN
NVRTC: Runtime compilation for JIT scenarios (via tvm_ffi.cpp.nvrtc)
Triton: High-level DSL that compiles to CUBIN
Custom compilers: Any tool that generates valid CUDA CUBIN
Complete Examples#
For complete working examples, see the examples/cubin_launcher/ directory:
embedded_cubin/ - Pre-compiled CUBIN embedded at build time
dynamic_cubin/ - CUBIN data passed dynamically at runtime
example_nvrtc_cubin.py - NVRTC runtime compilation
example_triton_cubin.py - Triton kernel compilation
These examples demonstrate:
Compiling CUDA kernels to CUBIN
Embedding CUBIN in C++ modules
Launching kernels with proper error handling
Testing and verification
API Reference#
C++ Classes#
tvm::ffi::CubinModule: RAII wrapper for CUBIN module lifecycle
  tvm::ffi::CubinModule::CubinModule(): Load CUBIN from memory
  tvm::ffi::CubinModule::GetKernel(): Get kernel by name
  tvm::ffi::CubinModule::GetKernelWithMaxDynamicSharedMemory(): Get kernel by name with maximum dynamic shared memory set
  tvm::ffi::CubinModule::operator[](): Convenient kernel access
tvm::ffi::CubinKernel: Handle for launching kernels
  tvm::ffi::CubinKernel::Launch(): Launch kernel with specified parameters
tvm::ffi::dim3: 3D dimension structure
  dim3(): Default (1, 1, 1)
  dim3(unsigned int x): 1D
  dim3(unsigned int x, unsigned int y): 2D
  dim3(unsigned int x, unsigned int y, unsigned int z): 3D
C++ Macros#
TVM_FFI_EMBED_CUBIN: Declare embedded CUBIN module
TVM_FFI_EMBED_CUBIN_FROM_BYTES: Load CUBIN from byte array
TVM_FFI_EMBED_CUBIN_GET_KERNEL: Get kernel from embedded module
TVM_FFI_CHECK_CUDA_ERROR: Check CUDA Runtime API result
Python Functions#
tvm_ffi.cpp.nvrtc.nvrtc_compile(): Compile CUDA source to CUBIN
tvm_ffi.cpp.load_inline(): Load inline module with embedded CUBIN
Python Utilities#
python -m tvm_ffi.utils.embed_cubin: Command-line utility to embed CUBIN into object files
  --output-obj PATH: Output combined object file path
  --input-obj PATH: Input compiled object file containing C++ code with TVM_FFI_EMBED_CUBIN
  --cubin PATH: Input CUBIN file to embed
  --name NAME: Symbol name matching the TVM_FFI_EMBED_CUBIN(name) macro
  --verbose: Print detailed command output (optional)
tvm_ffi.utils.embed_cubin.embed_cubin(): Python API for embedding CUBIN
  cubin_path: Path to input CUBIN file
  input_obj_path: Path to existing object file
  output_obj_path: Path to output combined object file
  name: Symbol name for the embedded CUBIN
  verbose: Enable detailed output (default: False)
CMake Functions#
add_tvm_ffi_cubin(<target> CUDA <source>): Compile CUDA source to CUBIN
add_tvm_ffi_fatbin(<target> CUDA <source>): Compile CUDA source to FATBIN
tvm_ffi_embed_bin_into(<target> SYMBOL <symbol> BIN <bin_file>): Embed CUBIN/FATBIN into object target