CUBIN Launcher Guide#

This guide demonstrates how to load and launch CUDA kernels from CUBIN (CUDA Binary) modules using TVM-FFI. The CUBIN launcher enables you to execute pre-compiled or runtime-compiled CUDA kernels efficiently through the CUDA Runtime API.

Overview#

TVM-FFI provides utilities for loading and launching CUDA kernels from CUBIN modules. The implementation lives in tvm/ffi/extra/cuda/cubin_launcher.h.

The CUBIN launcher supports:

  • Loading CUBIN from memory (embedded data or runtime-generated)

  • Multi-GPU execution using CUDA primary contexts

  • Kernel parameter management and launch configuration

  • Integration with NVRTC, Triton, and other CUDA compilation tools

Build Integration:

TVM-FFI provides convenient tools for embedding CUBIN data at build time:

  • CMake utilities (cmake/Utils/EmbedCubin.cmake): Functions for compiling CUDA to CUBIN and embedding it into C++ code

  • Python utility (python -m tvm_ffi.utils.embed_cubin): Command-line tool for embedding CUBIN into object files

  • Python API (tvm_ffi.cpp.load_inline()): Runtime embedding via embed_cubin parameter

Python Usage#

Basic Workflow#

The typical workflow for launching CUBIN kernels from Python involves:

  1. Generate CUBIN: Compile your CUDA kernel to CUBIN format

  2. Define C++ Wrapper: Write C++ code to load and launch the kernel

  3. Load Module: Use tvm_ffi.cpp.load_inline() with embed_cubin parameter

  4. Call Kernel: Invoke the kernel function from Python

Example: NVRTC Compilation#

Here’s a complete example using NVRTC to compile CUDA source at runtime.

Step 1: Compile CUDA source to CUBIN using NVRTC

from tvm_ffi import cpp
from tvm_ffi.cpp import nvrtc

# Define CUDA kernels
cuda_source = """
extern "C" __global__ void add_one(float* x, float* y, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        y[idx] = x[idx] + 1.0f;
    }
}

extern "C" __global__ void mul_two(float* x, float* y, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        y[idx] = x[idx] * 2.0f;
    }
}
"""

# Compile CUDA source to CUBIN using NVRTC
print("Compiling CUDA kernels to CUBIN using NVRTC...")
cubin_bytes = nvrtc.nvrtc_compile(cuda_source, name="kernels.cu")
print(f"Compiled CUBIN: {len(cubin_bytes)} bytes\n")

Step 2: Define C++ wrapper with embedded CUBIN

# Define C++ code inline to launch the CUDA kernels using embedded CUBIN
sources = """
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>

// Embed CUBIN module with name "nvrtc_cubin"
TVM_FFI_EMBED_CUBIN(nvrtc_cubin);

namespace nvrtc_loader {

void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Get kernel from embedded CUBIN (cached in static variable for efficiency)
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(nvrtc_cubin, "add_one");

  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";

  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();

  // Prepare kernel arguments
  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
                  reinterpret_cast<void*>(&n)};

  // Launch configuration
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);

  // Get CUDA stream
  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

  // Launch kernel
  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Get kernel from embedded CUBIN (cached in static variable for efficiency)
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(nvrtc_cubin, "mul_two");

  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";

  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();

  // Prepare kernel arguments
  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
                  reinterpret_cast<void*>(&n)};

  // Launch configuration
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);

  // Get CUDA stream
  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

  // Launch kernel
  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

}  // namespace nvrtc_loader

TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, nvrtc_loader::AddOne);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, nvrtc_loader::MulTwo);
"""

print("Compiling C++ sources with tvm_ffi.cpp.load_inline...")
mod = cpp.load_inline(
    "nvrtc_loader",
    cuda_sources=sources,
    embed_cubin={"nvrtc_cubin": cubin_bytes},
)
print("Successfully compiled and loaded C++ sources with embedded CUBIN\n")

Key Points:

  • The embed_cubin parameter is a dictionary mapping CUBIN names to their binary data

  • CUBIN names in embed_cubin must match names in TVM_FFI_EMBED_CUBIN

  • Use cuda_sources parameter (instead of cpp_sources) to automatically link with CUDA libraries

  • The C++ wrapper handles device management, stream handling, and kernel launching

Example: Using Triton Kernels#

You can compile Triton kernels to CUBIN and launch them through TVM-FFI.

Step 1: Define and compile Triton kernel

import torch
import triton
import triton.language as tl

from tvm_ffi import cpp

# Define the kernel dynamically
@triton.jit
def square_kernel(X_ptr, Y_ptr, n, BLOCK: tl.constexpr = 1024):  # noqa
    pid = tl.program_id(0)
    start = pid * BLOCK
    offsets = start + tl.arange(0, BLOCK)
    mask = offsets < n
    x = tl.load(X_ptr + offsets, mask=mask, other=0.0)
    y = x * x
    tl.store(Y_ptr + offsets, y, mask=mask)

# Trigger kernel compilation by doing a dummy call
x_dummy = torch.ones(1024, dtype=torch.float32, device="cuda")
y_dummy = torch.empty(1024, dtype=torch.float32, device="cuda")
square_kernel[1, 1](x_dummy, y_dummy, 1024)

# Extract compiled CUBIN from the device cache
device_caches = square_kernel.device_caches
device_id = next(iter(device_caches.keys()))
cache_tuple = device_caches[device_id]
compiled_kernel = next(iter(cache_tuple[0].values()))

# Get CUBIN bytes
cubin_bytes = compiled_kernel.kernel

Step 2: Define C++ wrapper to launch the Triton kernel

# Define C++ code inline to load and launch the Triton kernel using embedded CUBIN
sources = """
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>

// Embed CUBIN module with name "triton_cubin"
TVM_FFI_EMBED_CUBIN(triton_cubin);

namespace triton_loader {

void LaunchSquare(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Get kernel from embedded CUBIN (cached in static variable for efficiency)
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_cubin, "square_kernel");

  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Sizes must match";

  uint32_t n = static_cast<uint32_t>(x.size(0));
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();
  uint64_t dummy_ptr = 0;

  // Workaround for Triton extra params: pass dummy addresses for unused parameters
  void* args[] = {&x_ptr, &y_ptr, &n, &dummy_ptr, &dummy_ptr};

  // Kernel was compiled with .reqntid 128, not 1024
  tvm::ffi::dim3 grid((n + 127) / 128);
  tvm::ffi::dim3 block(128);

  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

}  // namespace triton_loader

TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_square, triton_loader::LaunchSquare);
"""

print("Compiling C++ sources with tvm_ffi.cpp.load_inline...")
mod = cpp.load_inline(
    "triton_loader",
    cuda_sources=sources,
    embed_cubin={"triton_cubin": cubin_bytes},
)
print("Successfully compiled and loaded C++ sources with embedded CUBIN\n")

Note

Triton kernels may require extra dummy parameters in the argument list. Check the compiled kernel’s signature to determine the exact parameter count needed.
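
One way to determine the exact argument list is to inspect the generated PTX entry signature. A rough sketch (the asm dictionary is available on recent Triton releases; attribute names may vary across versions):

# Print the PTX and inspect the .visible .entry ... ( .param ... ) declaration
print(compiled_kernel.asm["ptx"])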

C++ Usage#

Embedding CUBIN at Compile Time#

The recommended approach in C++ is to embed CUBIN data directly into your shared library:

#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>

// Embed CUBIN module with name "env"
// This creates the necessary symbols and singleton struct for accessing the embedded CUBIN
TVM_FFI_EMBED_CUBIN(env);

namespace cubin_embedded {

/*!
 * \brief Launch add_one_cuda kernel on input tensor.
 * \param x Input tensor (float32, 1D)
 * \param y Output tensor (float32, 1D, same shape as x)
 */
void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Get kernel from embedded CUBIN (cached in static variable for efficiency)
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "add_one_cuda");

  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";

  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();

  // Prepare kernel arguments
  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
                  reinterpret_cast<void*>(&n)};

  // Launch configuration
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);

  // Get CUDA stream
  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

  // Launch kernel
  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

/*!
 * \brief Launch mul_two_cuda kernel on input tensor.
 * \param x Input tensor (float32, 1D)
 * \param y Output tensor (float32, 1D, same shape as x)
 */
void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Get kernel from embedded CUBIN (cached in static variable for efficiency)
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "mul_two_cuda");

  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";

  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();

  // Prepare kernel arguments
  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
                  reinterpret_cast<void*>(&n)};

  // Launch configuration
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);

  // Get CUDA stream
  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

  // Launch kernel
  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

// Export TVM-FFI functions
TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, cubin_embedded::AddOne);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, cubin_embedded::MulTwo);

}  // namespace cubin_embedded

Key Points:

  • Use static auto kernel to cache the kernel lookup for efficiency

  • Kernel arguments must be pointers to the actual values (use & for addresses)

  • tvm::ffi::dim3 supports 1D, 2D, or 3D configurations: dim3(x), dim3(x, y), dim3(x, y, z)

  • TVMFFIEnvGetStream retrieves the correct CUDA stream for the device

  • Always check kernel launch results with TVM_FFI_CHECK_CUDA_ERROR (which checks CUDA Runtime API errors)

Loading CUBIN at Runtime#

You can also load CUBIN modules dynamically from memory:

#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/string.h>

#include <cstdint>
#include <memory>

namespace cubin_dynamic {

// Global CUBIN module and kernels (loaded dynamically)
static std::unique_ptr<tvm::ffi::CubinModule> g_cubin_module;
static std::unique_ptr<tvm::ffi::CubinKernel> g_add_one_kernel;
static std::unique_ptr<tvm::ffi::CubinKernel> g_mul_two_kernel;

/*!
 * \brief Set CUBIN module from binary data.
 * \param cubin CUBIN binary data as Bytes object.
 */
void SetCubin(const tvm::ffi::Bytes& cubin) {
  // Load CUBIN module from memory
  g_cubin_module = std::make_unique<tvm::ffi::CubinModule>(cubin);
  g_add_one_kernel = std::make_unique<tvm::ffi::CubinKernel>((*g_cubin_module)["add_one_cuda"]);
  g_mul_two_kernel = std::make_unique<tvm::ffi::CubinKernel>((*g_cubin_module)["mul_two_cuda"]);
}

/*!
 * \brief Launch add_one_cuda kernel on input tensor.
 * \param x Input tensor (float32, 1D)
 * \param y Output tensor (float32, 1D, same shape as x)
 */
void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  TVM_FFI_CHECK(g_cubin_module != nullptr, RuntimeError)
      << "CUBIN module not loaded. Call set_cubin first.";

  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";

  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();

  // Prepare kernel arguments
  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
                  reinterpret_cast<void*>(&n)};

  // Launch configuration
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);

  // Get CUDA stream
  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

  // Launch kernel
  cudaError_t result = g_add_one_kernel->Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

/*!
 * \brief Launch mul_two_cuda kernel on input tensor.
 * \param x Input tensor (float32, 1D)
 * \param y Output tensor (float32, 1D, same shape as x)
 */
void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  TVM_FFI_CHECK(g_cubin_module != nullptr, RuntimeError)
      << "CUBIN module not loaded. Call set_cubin first.";

  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";

  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();

  // Prepare kernel arguments
  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
                  reinterpret_cast<void*>(&n)};

  // Launch configuration
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);

  // Get CUDA stream
  DLDevice device = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

  // Launch kernel
  cudaError_t result = g_mul_two_kernel->Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}

// Export TVM-FFI functions
TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_cubin, cubin_dynamic::SetCubin);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, cubin_dynamic::AddOne);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, cubin_dynamic::MulTwo);

}  // namespace cubin_dynamic
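
On the Python side, the CUBIN bytes are supplied at runtime before any kernel is launched. A minimal sketch, assuming the code above was built into a shared library (named lib_dynamic.so here for illustration) and that the CUBIN was produced ahead of time with nvcc --cubin or NVRTC:

import torch
import tvm_ffi

mod = tvm_ffi.load_module("build/lib_dynamic.so")

# Provide the CUBIN bytes before launching any kernel
with open("kernel.cubin", "rb") as f:
    mod.set_cubin(f.read())

x = torch.ones(1024, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
mod.add_one(x, y)
torch.cuda.synchronize()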

Embedding CUBIN with CMake Utilities#

TVM-FFI provides CMake utility functions that simplify the CUBIN embedding process. This is the recommended approach for CMake-based projects.

Using CMake Utilities:

# Step 1: Compile CUDA source to CUBIN (ARCH native auto-detects the GPU architecture)
tvm_ffi_generate_cubin(
  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin SOURCE src/kernel.cu ARCH native
)

# Step 2: Embed CUBIN into the object file using the tvm_ffi_embed_cubin utility.
# This creates the symbols __tvm_ffi__cubin_env and __tvm_ffi__cubin_env_end (local).
tvm_ffi_embed_cubin(
  OUTPUT
  ${CMAKE_CURRENT_BINARY_DIR}/lib_embedded_with_cubin.o
  SOURCE
  src/lib_embedded.cc
  CUBIN
  ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin
  NAME
  env
)

# Step 3: Build lib_embedded shared library (with embedded CUBIN)
add_library(lib_embedded SHARED ${CMAKE_CURRENT_BINARY_DIR}/lib_embedded_with_cubin.o)
target_link_libraries(lib_embedded PRIVATE tvm_ffi_header tvm_ffi_shared CUDA::cudart)
set_target_properties(
  lib_embedded
  PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/"
             PREFIX ""
             SUFFIX ".so"
             LINKER_LANGUAGE CXX
)

Available CMake Functions:

  • tvm_ffi_generate_cubin(): Compiles CUDA source to CUBIN using nvcc

    • OUTPUT: Path to output CUBIN file

    • SOURCE: Path to CUDA source file

    • ARCH: Target GPU architecture (default: native for auto-detection)

    • OPTIONS: Additional nvcc compiler options (optional)

    • DEPENDS: Additional dependencies (optional)

  • tvm_ffi_embed_cubin(): Compiles C++ source and embeds CUBIN data

    • OUTPUT: Path to output combined object file

    • SOURCE: Path to C++ source file with TVM_FFI_EMBED_CUBIN macro

    • CUBIN: Path to CUBIN file to embed

    • NAME: Symbol name used in TVM_FFI_EMBED_CUBIN(name) macro

    • DEPENDS: Additional dependencies (optional)

The utilities automatically handle:

  • Compiling C++ source to intermediate object file

  • Creating CUBIN symbols with proper naming

  • Merging object files using ld -r

  • Adding .note.GNU-stack section for security

  • Localizing symbols to prevent conflicts
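
Once built, the embedded library is consumed from Python like any other TVM-FFI module. A short sketch, assuming the lib_embedded.so target produced by the example above:

import torch
import tvm_ffi

mod = tvm_ffi.load_module("build/lib_embedded.so")

x = torch.ones(1024, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
mod.add_one(x, y)   # runs the embedded add_one_cuda kernel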

Embedding CUBIN with Python Utility#

For more advanced use cases or non-CMake build systems, you can use the Python command-line utility to embed CUBIN data into existing object files.

Command-Line Usage:

# Step 1: Compile C++ source to object file
g++ -c -fPIC -std=c++17 -I/path/to/tvm-ffi/include mycode.cc -o mycode.o

# Step 2: Embed CUBIN into the object file
python -m tvm_ffi.utils.embed_cubin \
    --output-obj mycode_with_cubin.o \
    --input-obj mycode.o \
    --cubin kernel.cubin \
    --name my_kernels

# Step 3: Link into final library
g++ -o mylib.so -shared mycode_with_cubin.o -lcudart

Python API:

from pathlib import Path
from tvm_ffi.utils.embed_cubin import embed_cubin

embed_cubin(
    cubin_path=Path("kernel.cubin"),
    input_obj_path=Path("mycode.o"),
    output_obj_path=Path("mycode_with_cubin.o"),
    name="my_kernels",
    verbose=True  # Optional: print detailed progress
)

The Python utility performs these steps:

  1. Creates intermediate CUBIN object file using ld -r -b binary

  2. Adds .note.GNU-stack section for security

  3. Renames symbols to match TVM-FFI format (__tvm_ffi__cubin_<name>)

  4. Merges with input object file using relocatable linking

  5. Localizes symbols to prevent conflicts when multiple object files use the same name

Manual CUBIN Embedding#

For reference, here’s how to manually embed CUBIN using objcopy and ld:

Step 1: Compile CUDA kernel to CUBIN

nvcc --cubin -arch=sm_75 kernel.cu -o kernel.cubin

Step 2: Convert CUBIN to object file

ld -r -b binary -o kernel_data.o kernel.cubin

Step 3: Rename symbols with objcopy

objcopy --rename-section .data=.rodata,alloc,load,readonly,data,contents \
        --redefine-sym _binary_kernel_cubin_start=__tvm_ffi__cubin_my_kernels \
        --redefine-sym _binary_kernel_cubin_end=__tvm_ffi__cubin_my_kernels_end \
        kernel_data.o

Step 4: Link with your library

g++ -o mylib.so -shared mycode.cc kernel_data.o -Wl,-z,noexecstack -lcudart

The symbol names must match the name used in TVM_FFI_EMBED_CUBIN.

When to Use Each Approach:

  • CMake utilities: Recommended for CMake-based projects; they provide the cleanest integration

  • Python utility: Recommended for custom build systems, Makefile-based projects, or other advanced workflows

  • Manual objcopy: Low-level approach; useful for understanding the process, debugging, or highly customized setups

Advanced Topics#

Multi-GPU Support#

The CUBIN launcher automatically handles multi-GPU execution through CUDA primary contexts. Kernels will execute on the device associated with the input tensors:

void MultiGPUExample(tvm::ffi::TensorView x_gpu0, tvm::ffi::TensorView x_gpu1) {
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "process");

  // LaunchOnDevice is a user-defined helper (not part of TVM-FFI) that prepares the
  // kernel arguments and calls kernel.Launch() on the stream of the tensor's device.

  // Launch on GPU 0 (device determined by x_gpu0.device())
  LaunchOnDevice(kernel, x_gpu0);

  // Launch on GPU 1 (device determined by x_gpu1.device())
  LaunchOnDevice(kernel, x_gpu1);
}

The tvm::ffi::CubinKernel automatically uses the device context from the input tensors.

Kernel Launch Configuration#

When writing the C++ wrapper, important considerations include:

  • Grid/Block Dimensions: Use tvm::ffi::dim3 for 1D, 2D, or 3D configurations

    • 1D: dim3(x) = (x, 1, 1)

    • 2D: dim3(x, y) = (x, y, 1)

    • 3D: dim3(x, y, z) = (x, y, z)

  • Kernel Arguments: Must be pointers to actual values

    • For device pointers: void* ptr = tensor.data_ptr(); args[] = {&ptr}

    • For scalars: int n = 42; args[] = {&n}

  • Stream Management: Use TVMFFIEnvGetStream to get the correct CUDA stream for synchronization with DLPack tensors

  • Error Checking: Always use TVM_FFI_CHECK_CUDA_ERROR to validate CUDA Runtime API results

Dynamic Shared Memory#

To use dynamic shared memory, specify the size in the tvm::ffi::CubinKernel::Launch() call:

// Allocate 1KB of dynamic shared memory
uint32_t shared_mem_bytes = 1024;
cudaError_t result = kernel.Launch(args, grid, block, stream, shared_mem_bytes);

Integration with Different Compilers#

The CUBIN launcher works with various CUDA compilation tools:

  • NVCC: Standard NVIDIA compiler, produces highly optimized CUBIN

  • NVRTC: Runtime compilation for JIT scenarios (via tvm_ffi.cpp.nvrtc)

  • Triton: High-level DSL that compiles to CUBIN

  • Custom compilers: Any tool that generates valid CUDA CUBIN

Complete Examples#

For complete working examples, see the examples/cubin_launcher/ directory:

  • embedded_cubin/ - Pre-compiled CUBIN embedded at build time

  • dynamic_cubin/ - CUBIN data passed dynamically at runtime

  • example_nvrtc_cubin.py - NVRTC runtime compilation

  • example_triton_cubin.py - Triton kernel compilation

These examples demonstrate:

  • Compiling CUDA kernels to CUBIN

  • Embedding CUBIN in C++ modules

  • Launching kernels with proper error handling

  • Testing and verification

API Reference#

C++ Classes#

C++ Macros#

Python Functions#

Python Utilities#

  • python -m tvm_ffi.utils.embed_cubin: Command-line utility to embed CUBIN into object files

    • --output-obj PATH: Output combined object file path

    • --input-obj PATH: Input object file containing C++ code with TVM_FFI_EMBED_CUBIN

    • --cubin PATH: Input CUBIN file to embed

    • --name NAME: Symbol name matching TVM_FFI_EMBED_CUBIN(name) macro

    • --verbose: Print detailed command output (optional)

  • tvm_ffi.utils.embed_cubin.embed_cubin(): Python API for embedding CUBIN

    • cubin_path: Path to input CUBIN file

    • input_obj_path: Path to existing object file

    • output_obj_path: Path to output combined object file

    • name: Symbol name for the embedded CUBIN

    • verbose: Enable detailed output (default: False)

CMake Functions#

  • tvm_ffi_generate_cubin(): Compile CUDA source to CUBIN

    • OUTPUT: Path to output CUBIN file

    • SOURCE: Path to CUDA source file (.cu)

    • ARCH: Target architecture (default: native)

    • OPTIONS: Additional nvcc compiler flags (optional)

    • DEPENDS: Additional dependencies (optional)

  • tvm_ffi_embed_cubin(): Compile C++ source and embed CUBIN data

    • OUTPUT: Path to output combined object file

    • SOURCE: Path to C++ source file

    • CUBIN: Path to CUBIN file to embed

    • NAME: Symbol name matching TVM_FFI_EMBED_CUBIN(name) in source

    • DEPENDS: Additional dependencies (optional)