CUBIN Launcher Guide#
This guide demonstrates how to load and launch CUDA kernels from CUBIN (CUDA Binary) modules using TVM-FFI. The CUBIN launcher enables you to execute pre-compiled or runtime-compiled CUDA kernels efficiently through the CUDA Runtime API.
Overview#
TVM-FFI provides utilities for loading and launching CUDA kernels from CUBIN modules. The implementation is in tvm/ffi/extra/cuda/cubin_launcher.h and provides:
tvm::ffi::CubinModule: RAII wrapper for loading CUBIN modules from memory
tvm::ffi::CubinKernel: Handle for launching CUDA kernels with specified parameters
TVM_FFI_EMBED_CUBIN: Macro for embedding CUBIN data at compile time
TVM_FFI_EMBED_CUBIN_GET_KERNEL: Macro for retrieving kernels from embedded CUBIN
The CUBIN launcher supports:
Loading CUBIN from memory (embedded data or runtime-generated)
Multi-GPU execution using CUDA primary contexts
Kernel parameter management and launch configuration
Integration with NVRTC, Triton, and other CUDA compilation tools
Build Integration:
TVM-FFI provides convenient tools for embedding CUBIN data at build time:
CMake utilities (cmake/Utils/EmbedCubin.cmake): Functions for compiling CUDA to CUBIN and embedding it into C++ code
Python utility (python -m tvm_ffi.utils.embed_cubin): Command-line tool for embedding CUBIN into object files
Python API (tvm_ffi.cpp.load_inline()): Runtime embedding via the embed_cubin parameter
Python Usage#
Basic Workflow#
The typical workflow for launching CUBIN kernels from Python involves:
Generate CUBIN: Compile your CUDA kernel to CUBIN format
Define C++ Wrapper: Write C++ code to load and launch the kernel
Load Module: Use tvm_ffi.cpp.load_inline() with the embed_cubin parameter
Call Kernel: Invoke the kernel function from Python
Example: NVRTC Compilation#
Here’s a complete example using NVRTC to compile CUDA source at runtime.
Step 1: Compile CUDA source to CUBIN using NVRTC
from tvm_ffi import cpp
from tvm_ffi.cpp import nvrtc

# Define CUDA kernels
cuda_source = """
extern "C" __global__ void add_one(float* x, float* y, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
y[idx] = x[idx] + 1.0f;
}
}
extern "C" __global__ void mul_two(float* x, float* y, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
y[idx] = x[idx] * 2.0f;
}
}
"""
# Compile CUDA source to CUBIN using NVRTC
print("Compiling CUDA kernels to CUBIN using NVRTC...")
cubin_bytes = nvrtc.nvrtc_compile(cuda_source, name="kernels.cu")
print(f"Compiled CUBIN: {len(cubin_bytes)} bytes\n")
Step 2: Define C++ wrapper with embedded CUBIN
# Define C++ code inline to launch the CUDA kernels using embedded CUBIN
sources = """
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
// Embed CUBIN module with name "nvrtc_cubin"
TVM_FFI_EMBED_CUBIN(nvrtc_cubin);
namespace nvrtc_loader {
void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Get kernel from embedded CUBIN (cached in static variable for efficiency)
static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(nvrtc_cubin, "add_one");
TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
int64_t n = x.size(0);
void* x_ptr = x.data_ptr();
void* y_ptr = y.data_ptr();
// Prepare kernel arguments
void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
reinterpret_cast<void*>(&n)};
// Launch configuration
tvm::ffi::dim3 grid((n + 255) / 256);
tvm::ffi::dim3 block(256);
// Get CUDA stream
DLDevice device = x.device();
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
// Launch kernel
cudaError_t result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
}
void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Get kernel from embedded CUBIN (cached in static variable for efficiency)
static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(nvrtc_cubin, "mul_two");
TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
int64_t n = x.size(0);
void* x_ptr = x.data_ptr();
void* y_ptr = y.data_ptr();
// Prepare kernel arguments
void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
reinterpret_cast<void*>(&n)};
// Launch configuration
tvm::ffi::dim3 grid((n + 255) / 256);
tvm::ffi::dim3 block(256);
// Get CUDA stream
DLDevice device = x.device();
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
// Launch kernel
cudaError_t result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
}
} // namespace nvrtc_loader
TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, nvrtc_loader::AddOne);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, nvrtc_loader::MulTwo);
"""
print("Compiling C++ sources with tvm_ffi.cpp.load_inline...")
mod = cpp.load_inline(
"nvrtc_loader",
cuda_sources=sources,
embed_cubin={"nvrtc_cubin": cubin_bytes},
)
print("Successfully compiled and loaded C++ sources with embedded CUBIN\n")
Key Points:
The embed_cubin parameter is a dictionary mapping CUBIN names to their binary data
CUBIN names in embed_cubin must match the names used in TVM_FFI_EMBED_CUBIN
Use the cuda_sources parameter (instead of cpp_sources) to automatically link with the CUDA libraries
The C++ wrapper handles device management, stream handling, and kernel launching
Example: Using Triton Kernels#
You can compile Triton kernels to CUBIN and launch them through TVM-FFI.
Step 1: Define and compile Triton kernel
import torch
import triton
import triton.language as tl
from tvm_ffi import cpp

# Define the kernel dynamically
@triton.jit
def square_kernel(X_ptr, Y_ptr, n, BLOCK: tl.constexpr = 1024): # noqa
pid = tl.program_id(0)
start = pid * BLOCK
offsets = start + tl.arange(0, BLOCK)
mask = offsets < n
x = tl.load(X_ptr + offsets, mask=mask, other=0.0)
y = x * x
tl.store(Y_ptr + offsets, y, mask=mask)
# Trigger kernel compilation by doing a dummy call
x_dummy = torch.ones(1024, dtype=torch.float32, device="cuda")
y_dummy = torch.empty(1024, dtype=torch.float32, device="cuda")
square_kernel[1, 1](x_dummy, y_dummy, 1024)
# Extract compiled CUBIN from the device cache
device_caches = square_kernel.device_caches
device_id = next(iter(device_caches.keys()))
cache_tuple = device_caches[device_id]
compiled_kernel = next(iter(cache_tuple[0].values()))
# Get CUBIN bytes
cubin_bytes = compiled_kernel.kernel
Step 2: Define C++ wrapper to launch the Triton kernel
# Define C++ code inline to load and launch the Triton kernel using embedded CUBIN
sources = """
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
// Embed CUBIN module with name "triton_cubin"
TVM_FFI_EMBED_CUBIN(triton_cubin);
namespace triton_loader {
void LaunchSquare(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Get kernel from embedded CUBIN (cached in static variable for efficiency)
static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_cubin, "square_kernel");
TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Sizes must match";
uint32_t n = static_cast<uint32_t>(x.size(0));
void* x_ptr = x.data_ptr();
void* y_ptr = y.data_ptr();
uint64_t dummy_ptr = 0;
// Workaround for Triton extra params: pass dummy addresses for unused parameters
void* args[] = {&x_ptr, &y_ptr, &n, &dummy_ptr, &dummy_ptr};
// Kernel was compiled with .reqntid 128, not 1024
tvm::ffi::dim3 grid((n + 127) / 128);
tvm::ffi::dim3 block(128);
DLDevice device = x.device();
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
cudaError_t result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
}
} // namespace triton_loader
TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_square, triton_loader::LaunchSquare);
"""
print("Compiling C++ sources with tvm_ffi.cpp.load_inline...")
mod = cpp.load_inline(
"triton_loader",
cuda_sources=sources,
embed_cubin={"triton_cubin": cubin_bytes},
)
print("Successfully compiled and loaded C++ sources with embedded CUBIN\n")
Note
Triton kernels may require extra dummy parameters in the argument list. Check the compiled kernel’s signature to determine the exact parameter count needed.
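Step 3: Call the Triton kernel from Python
As with the NVRTC example, the exported wrapper can be invoked directly with CUDA tensors. A short usage sketch, assuming the PyTorch setup from Step 1:
# Run the embedded Triton kernel and verify the result
x = torch.rand(4096, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
mod.launch_square(x, y)
assert torch.allclose(y, x * x)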
C++ Usage#
Embedding CUBIN at Compile Time#
The recommended approach in C++ is to embed CUBIN data directly into your shared library:
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
// Embed CUBIN module with name "env"
// This creates the necessary symbols and singleton struct for accessing the embedded CUBIN
TVM_FFI_EMBED_CUBIN(env);
namespace cubin_embedded {
/*!
* \brief Launch add_one_cuda kernel on input tensor.
* \param x Input tensor (float32, 1D)
* \param y Output tensor (float32, 1D, same shape as x)
*/
void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Get kernel from embedded CUBIN (cached in static variable for efficiency)
static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "add_one_cuda");
TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
int64_t n = x.size(0);
void* x_ptr = x.data_ptr();
void* y_ptr = y.data_ptr();
// Prepare kernel arguments
void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
reinterpret_cast<void*>(&n)};
// Launch configuration
tvm::ffi::dim3 grid((n + 255) / 256);
tvm::ffi::dim3 block(256);
// Get CUDA stream
DLDevice device = x.device();
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
// Launch kernel
cudaError_t result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
}
/*!
* \brief Launch mul_two_cuda kernel on input tensor.
* \param x Input tensor (float32, 1D)
* \param y Output tensor (float32, 1D, same shape as x)
*/
void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Get kernel from embedded CUBIN (cached in static variable for efficiency)
static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "mul_two_cuda");
TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
int64_t n = x.size(0);
void* x_ptr = x.data_ptr();
void* y_ptr = y.data_ptr();
// Prepare kernel arguments
void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
reinterpret_cast<void*>(&n)};
// Launch configuration
tvm::ffi::dim3 grid((n + 255) / 256);
tvm::ffi::dim3 block(256);
// Get CUDA stream
DLDevice device = x.device();
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
// Launch kernel
cudaError_t result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
}
// Export TVM-FFI functions
TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, cubin_embedded::AddOne);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, cubin_embedded::MulTwo);
} // namespace cubin_embedded
Key Points:
Use static auto kernel to cache the kernel lookup for efficiency
Kernel arguments must be pointers to the actual values (use & to take addresses)
tvm::ffi::dim3 supports 1D, 2D, or 3D configurations: dim3(x), dim3(x, y), dim3(x, y, z)
TVMFFIEnvGetStream retrieves the correct CUDA stream for the device
Always check kernel launch results with TVM_FFI_CHECK_CUDA_ERROR (which checks CUDA Runtime API errors)
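Once the shared library is built (for example with the CMake utilities described below), the exported functions can be called from Python. A brief sketch, assuming the library was built as build/lib_embedded.so (the exact path depends on your build setup) and that PyTorch CUDA tensors are used:
import torch
import tvm_ffi

# Load the shared library that contains the embedded CUBIN
mod = tvm_ffi.load_module("build/lib_embedded.so")

x = torch.ones(1024, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
mod.add_one(x, y)  # y = x + 1
mod.mul_two(x, y)  # y = x * 2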
Loading CUBIN at Runtime#
You can also load CUBIN modules dynamically from memory:
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/string.h>
#include <cstdint>
#include <memory>
namespace cubin_dynamic {
// Global CUBIN module and kernels (loaded dynamically)
static std::unique_ptr<tvm::ffi::CubinModule> g_cubin_module;
static std::unique_ptr<tvm::ffi::CubinKernel> g_add_one_kernel;
static std::unique_ptr<tvm::ffi::CubinKernel> g_mul_two_kernel;
/*!
* \brief Set CUBIN module from binary data.
* \param cubin CUBIN binary data as Bytes object.
*/
void SetCubin(const tvm::ffi::Bytes& cubin) {
// Load CUBIN module from memory
g_cubin_module = std::make_unique<tvm::ffi::CubinModule>(cubin);
g_add_one_kernel = std::make_unique<tvm::ffi::CubinKernel>((*g_cubin_module)["add_one_cuda"]);
g_mul_two_kernel = std::make_unique<tvm::ffi::CubinKernel>((*g_cubin_module)["mul_two_cuda"]);
}
/*!
* \brief Launch add_one_cuda kernel on input tensor.
* \param x Input tensor (float32, 1D)
* \param y Output tensor (float32, 1D, same shape as x)
*/
void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
TVM_FFI_CHECK(g_cubin_module != nullptr, RuntimeError)
<< "CUBIN module not loaded. Call set_cubin first.";
TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
int64_t n = x.size(0);
void* x_ptr = x.data_ptr();
void* y_ptr = y.data_ptr();
// Prepare kernel arguments
void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
reinterpret_cast<void*>(&n)};
// Launch configuration
tvm::ffi::dim3 grid((n + 255) / 256);
tvm::ffi::dim3 block(256);
// Get CUDA stream
DLDevice device = x.device();
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
// Launch kernel
cudaError_t result = g_add_one_kernel->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
}
/*!
* \brief Launch mul_two_cuda kernel on input tensor.
* \param x Input tensor (float32, 1D)
* \param y Output tensor (float32, 1D, same shape as x)
*/
void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
TVM_FFI_CHECK(g_cubin_module != nullptr, RuntimeError)
<< "CUBIN module not loaded. Call set_cubin first.";
TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
int64_t n = x.size(0);
void* x_ptr = x.data_ptr();
void* y_ptr = y.data_ptr();
// Prepare kernel arguments
void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
reinterpret_cast<void*>(&n)};
// Launch configuration
tvm::ffi::dim3 grid((n + 255) / 256);
tvm::ffi::dim3 block(256);
// Get CUDA stream
DLDevice device = x.device();
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
// Launch kernel
cudaError_t result = g_mul_two_kernel->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
}
// Export TVM-FFI functions
TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_cubin, cubin_dynamic::SetCubin);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, cubin_dynamic::AddOne);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, cubin_dynamic::MulTwo);
} // namespace cubin_dynamic
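On the Python side, the dynamic variant is driven by first passing the CUBIN bytes to set_cubin and then invoking the kernels. A rough sketch, assuming the library above was built as build/lib_dynamic.so (name and path depend on your build) and that kernel.cubin containing add_one_cuda was produced with nvcc:
import torch
import tvm_ffi

mod = tvm_ffi.load_module("build/lib_dynamic.so")

# Hand the CUBIN bytes to the library at runtime
with open("kernel.cubin", "rb") as f:
    mod.set_cubin(f.read())

x = torch.ones(256, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
mod.add_one(x, y)  # y = x + 1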
Embedding CUBIN with CMake Utilities#
TVM-FFI provides CMake utility functions that simplify the CUBIN embedding process. This is the recommended approach for CMake-based projects.
Using CMake Utilities:
# Step 1: Compile CUDA source to CUBIN (ARCH native auto-detects the GPU architecture)
tvm_ffi_generate_cubin(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin SOURCE src/kernel.cu ARCH native
)
# Step 2: Embed the CUBIN into the object file using the tvm_ffi_embed_cubin utility.
# This creates the symbols __tvm_ffi__cubin_env and __tvm_ffi__cubin_env_end (local).
tvm_ffi_embed_cubin(
  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/lib_embedded_with_cubin.o
  SOURCE src/lib_embedded.cc
  CUBIN ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin
  NAME env
)
# Step 3: Build lib_embedded shared library (with embedded CUBIN)
add_library(lib_embedded SHARED ${CMAKE_CURRENT_BINARY_DIR}/lib_embedded_with_cubin.o)
target_link_libraries(lib_embedded PRIVATE tvm_ffi_header tvm_ffi_shared CUDA::cudart)
set_target_properties(
lib_embedded
PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/"
PREFIX ""
SUFFIX ".so"
LINKER_LANGUAGE CXX
)
Available CMake Functions:
tvm_ffi_generate_cubin(): Compiles CUDA source to CUBIN using nvcc
  OUTPUT: Path to output CUBIN file
  SOURCE: Path to CUDA source file
  ARCH: Target GPU architecture (default: native for auto-detection)
  OPTIONS: Additional nvcc compiler options (optional)
  DEPENDS: Additional dependencies (optional)
tvm_ffi_embed_cubin(): Compiles C++ source and embeds CUBIN data
  OUTPUT: Path to output combined object file
  SOURCE: Path to C++ source file with the TVM_FFI_EMBED_CUBIN macro
  CUBIN: Path to CUBIN file to embed
  NAME: Symbol name used in the TVM_FFI_EMBED_CUBIN(name) macro
  DEPENDS: Additional dependencies (optional)
The utilities automatically handle:
Compiling C++ source to intermediate object file
Creating CUBIN symbols with proper naming
Merging object files using ld -r
Adding a .note.GNU-stack section for security
Localizing symbols to prevent conflicts
Embedding CUBIN with Python Utility#
For more advanced use cases or non-CMake build systems, you can use the Python command-line utility to embed CUBIN data into existing object files.
Command-Line Usage:
# Step 1: Compile C++ source to object file
g++ -c -fPIC -std=c++17 -I/path/to/tvm-ffi/include mycode.cc -o mycode.o
# Step 2: Embed CUBIN into the object file
python -m tvm_ffi.utils.embed_cubin \
--output-obj mycode_with_cubin.o \
--input-obj mycode.o \
--cubin kernel.cubin \
--name my_kernels
# Step 3: Link into final library
g++ -o mylib.so -shared mycode_with_cubin.o -lcudart
Python API:
from pathlib import Path
from tvm_ffi.utils.embed_cubin import embed_cubin
embed_cubin(
cubin_path=Path("kernel.cubin"),
input_obj_path=Path("mycode.o"),
output_obj_path=Path("mycode_with_cubin.o"),
name="my_kernels",
verbose=True # Optional: print detailed progress
)
The Python utility performs these steps:
Creates an intermediate CUBIN object file using ld -r -b binary
Adds a .note.GNU-stack section for security
Renames symbols to match the TVM-FFI format (__tvm_ffi__cubin_<name>)
Merges with the input object file using relocatable linking
Localizes symbols to prevent conflicts when multiple object files use the same name
Manual CUBIN Embedding#
For reference, here’s how to manually embed CUBIN using objcopy and ld:
Step 1: Compile CUDA kernel to CUBIN
nvcc --cubin -arch=sm_75 kernel.cu -o kernel.cubin
Step 2: Convert CUBIN to object file
ld -r -b binary -o kernel_data.o kernel.cubin
Step 3: Rename symbols with objcopy
objcopy --rename-section .data=.rodata,alloc,load,readonly,data,contents \
--redefine-sym _binary_kernel_cubin_start=__tvm_ffi__cubin_my_kernels \
--redefine-sym _binary_kernel_cubin_end=__tvm_ffi__cubin_my_kernels_end \
kernel_data.o
Step 4: Link with your library
g++ -o mylib.so -shared mycode.cc kernel_data.o -Wl,-z,noexecstack -lcudart
The symbol names must match the name used in TVM_FFI_EMBED_CUBIN.
When to Use Each Approach:
CMake utilities: Best for CMake-based projects, provides cleanest integration (recommended)
Python utility: Best for custom build systems, Makefile-based projects, or advanced workflows (recommended)
Manual objcopy: Low-level approach, useful for understanding the process or debugging (only for customized use cases)
Advanced Topics#
Multi-GPU Support#
The CUBIN launcher automatically handles multi-GPU execution through CUDA primary contexts. Kernels execute on the device associated with the input tensors; in the sketch below, LaunchOnDevice stands for a user-defined helper that prepares the kernel arguments and calls kernel.Launch on the stream of the given tensor:
void MultiGPUExample(tvm::ffi::TensorView x_gpu0, tvm::ffi::TensorView x_gpu1) {
static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "process");
// Launch on GPU 0 (device determined by x_gpu0.device())
LaunchOnDevice(kernel, x_gpu0);
// Launch on GPU 1 (device determined by x_gpu1.device())
LaunchOnDevice(kernel, x_gpu1);
}
The tvm::ffi::CubinKernel automatically uses the device context from the input tensors.
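From Python, the same property means the exported wrapper can simply be called with tensors placed on different devices; the stream and primary context are resolved per call. A small sketch, assuming two GPUs and the add_one module (mod) from the earlier examples:
import torch

x0 = torch.ones(1024, dtype=torch.float32, device="cuda:0")
y0 = torch.empty_like(x0)
x1 = torch.ones(1024, dtype=torch.float32, device="cuda:1")
y1 = torch.empty_like(x1)

mod.add_one(x0, y0)  # launched on GPU 0's stream and context
mod.add_one(x1, y1)  # launched on GPU 1's stream and context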
Kernel Launch Configuration#
When writing the C++ wrapper, important considerations include:
Grid/Block Dimensions: Use tvm::ffi::dim3 for 1D, 2D, or 3D configurations
  1D: dim3(x) → (x, 1, 1)
  2D: dim3(x, y) → (x, y, 1)
  3D: dim3(x, y, z) → (x, y, z)
Kernel Arguments: Must be pointers to the actual values
  For device pointers: void* ptr = tensor.data_ptr(); args[] = {&ptr}
  For scalars: int n = 42; args[] = {&n}
Stream Management: Use TVMFFIEnvGetStream to get the correct CUDA stream for synchronization with DLPack tensors
Error Checking: Always use TVM_FFI_CHECK_CUDA_ERROR to validate CUDA Runtime API results
Integration with Different Compilers#
The CUBIN launcher works with various CUDA compilation tools:
NVCC: Standard NVIDIA compiler, produces highly optimized CUBIN
NVRTC: Runtime compilation for JIT scenarios (via tvm_ffi.cpp.nvrtc)
Triton: High-level DSL that compiles to CUBIN
Custom compilers: Any tool that generates valid CUDA CUBIN
Complete Examples#
For complete working examples, see the examples/cubin_launcher/ directory:
embedded_cubin/ - Pre-compiled CUBIN embedded at build time
dynamic_cubin/ - CUBIN data passed dynamically at runtime
example_nvrtc_cubin.py - NVRTC runtime compilation
example_triton_cubin.py - Triton kernel compilation
These examples demonstrate:
Compiling CUDA kernels to CUBIN
Embedding CUBIN in C++ modules
Launching kernels with proper error handling
Testing and verification
API Reference#
C++ Classes#
tvm::ffi::CubinModule: RAII wrapper for the CUBIN module lifecycle
  tvm::ffi::CubinModule::CubinModule(): Load CUBIN from memory
  tvm::ffi::CubinModule::GetKernel(): Get kernel by name
  tvm::ffi::CubinModule::GetKernelWithMaxDynamicSharedMemory(): Get kernel by name with maximum dynamic shared memory set
  tvm::ffi::CubinModule::operator[](): Convenient kernel access
tvm::ffi::CubinKernel: Handle for launching kernels
  tvm::ffi::CubinKernel::Launch(): Launch kernel with specified parameters
tvm::ffi::dim3: 3D dimension structure
  dim3(): Default (1, 1, 1)
  dim3(unsigned int x): 1D
  dim3(unsigned int x, unsigned int y): 2D
  dim3(unsigned int x, unsigned int y, unsigned int z): 3D
C++ Macros#
TVM_FFI_EMBED_CUBIN: Declare an embedded CUBIN module
TVM_FFI_EMBED_CUBIN_GET_KERNEL: Get a kernel from an embedded module
TVM_FFI_CHECK_CUDA_ERROR: Check a CUDA Runtime API result
Python Functions#
tvm_ffi.cpp.nvrtc.nvrtc_compile(): Compile CUDA source to CUBIN
tvm_ffi.cpp.load_inline(): Load an inline module with embedded CUBIN
Python Utilities#
python -m tvm_ffi.utils.embed_cubin: Command-line utility to embed CUBIN into object files
  --output-obj PATH: Output combined object file path
  --input-obj PATH: Input object file containing C++ code with TVM_FFI_EMBED_CUBIN
  --cubin PATH: Input CUBIN file to embed
  --name NAME: Symbol name matching the TVM_FFI_EMBED_CUBIN(name) macro
  --verbose: Print detailed command output (optional)
tvm_ffi.utils.embed_cubin.embed_cubin(): Python API for embedding CUBIN
  cubin_path: Path to input CUBIN file
  input_obj_path: Path to existing object file
  output_obj_path: Path to output combined object file
  name: Symbol name for the embedded CUBIN
  verbose: Enable detailed output (default: False)
CMake Functions#
tvm_ffi_generate_cubin(): Compile CUDA source to CUBIN
  OUTPUT: Path to output CUBIN file
  SOURCE: Path to CUDA source file (.cu)
  ARCH: Target architecture (default: native)
  OPTIONS: Additional nvcc compiler flags (optional)
  DEPENDS: Additional dependencies (optional)
tvm_ffi_embed_cubin(): Compile C++ source and embed CUBIN data
  OUTPUT: Path to output combined object file
  SOURCE: Path to C++ source file
  CUBIN: Path to CUBIN file to embed
  NAME: Symbol name matching TVM_FFI_EMBED_CUBIN(name) in the source
  DEPENDS: Additional dependencies (optional)