Kernel Library Guide#

This guide covers shipping C++/CUDA kernel libraries with TVM-FFI. The resulting libraries are agnostic to Python version and ML framework — a single .so works with PyTorch, JAX, PaddlePaddle, NumPy, and more.

Anatomy of a Kernel Function#

Every TVM-FFI CUDA kernel follows the same sequence:

  1. Validate inputs (device, dtype, shape, contiguity)

  2. Set the device guard to match the tensor’s device

  3. Acquire the stream from the host framework

  4. Dispatch on dtype and launch the kernel

Here is a complete Scale kernel that computes y = x * factor:

void Scale(TensorView output, TensorView input, double factor) {
  // --- 1. Validate inputs ---
  CHECK_INPUT(input);
  CHECK_INPUT(output);
  CHECK_DIM(1, input);
  CHECK_DEVICE(input, output);
  TVM_FFI_CHECK(input.dtype() == output.dtype(), ValueError) << "input/output dtype mismatch";
  TVM_FFI_CHECK(input.numel() == output.numel(), ValueError) << "input/output size mismatch";

  // --- 2. Device guard and stream ---
  ffi::CUDADeviceGuard guard(input.device().device_id);
  cudaStream_t stream = get_cuda_stream(input.device());

  // --- 3. Dispatch on dtype and launch ---
  int64_t n = input.numel();
  int threads = 256;
  int blocks = (n + threads - 1) / threads;

  if (input.dtype() == dl_float32) {
    ScaleKernel<<<blocks, threads, 0, stream>>>(static_cast<float*>(output.data_ptr()),
                                                static_cast<float*>(input.data_ptr()),
                                                static_cast<float>(factor), n);
  } else if (input.dtype() == dl_float16) {
    ScaleKernel<<<blocks, threads, 0, stream>>>(static_cast<half*>(output.data_ptr()),
                                                static_cast<half*>(input.data_ptr()),
                                                static_cast<half>(factor), n);
  } else {
    TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << input.dtype();
  }
}

The CUDA kernel itself is a standard __global__ function (declare or define it above Scale in the source file so the launch site can see it):

template <typename T>
__global__ void ScaleKernel(T* out, const T* in, T factor, int64_t n) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * factor;
  }
}

The following subsections break down each step.

Input Validation#

Kernel functions should validate inputs early and fail with clear error messages. A common pattern is to define reusable CHECK_* macros on top of TVM_FFI_CHECK (see Exception Handling):

// --- Reusable validation macros ---
#define CHECK_CUDA(x) \
  TVM_FFI_CHECK((x).device().device_type == kDLCUDA, ValueError) << #x " must be a CUDA tensor"
#define CHECK_CONTIGUOUS(x) \
  TVM_FFI_CHECK((x).IsContiguous(), ValueError) << #x " must be contiguous"
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
#define CHECK_DIM(d, x) \
  TVM_FFI_CHECK((x).ndim() == (d), ValueError) << #x " must be a " #d "D tensor"
#define CHECK_DEVICE(a, b)                                                          \
  do {                                                                              \
    TVM_FFI_CHECK((a).device().device_type == (b).device().device_type, ValueError) \
        << #a " and " #b " must be on the same device type";                        \
    TVM_FFI_CHECK((a).device().device_id == (b).device().device_id, ValueError)     \
        << #a " and " #b " must be on the same device";                             \
  } while (0)

For user-facing errors (bad arguments, unsupported dtypes, shape mismatches), use TVM_FFI_THROW or TVM_FFI_CHECK with a specific error kind so that callers receive an actionable message:

TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << input.dtype();
TVM_FFI_CHECK(input.numel() > 0, ValueError) << "input must be non-empty";
TVM_FFI_CHECK(input.numel() == output.numel(), ValueError) << "size mismatch";

For internal invariants that indicate bugs in the kernel itself, use TVM_FFI_ICHECK:

TVM_FFI_ICHECK_GE(n, 0) << "element count must be non-negative";

Device Guard and Stream#

Before launching a CUDA kernel, two things must happen:

  1. Set the CUDA device to match the tensor’s device. tvm::ffi::CUDADeviceGuard is an RAII guard that calls cudaSetDevice on construction and restores the original device on destruction.

  2. Acquire the stream from the host framework via TVMFFIEnvGetStream(). When Python code calls a kernel with PyTorch tensors, TVM-FFI automatically captures PyTorch’s current stream for the tensor’s device.

A small helper keeps this concise:

// --- Stream helper ---
inline cudaStream_t get_cuda_stream(DLDevice device) {
  return static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
}

Every kernel function then follows the same two-line pattern:

ffi::CUDADeviceGuard guard(input.device().device_id);
cudaStream_t stream = get_cuda_stream(input.device());

See Tensor and DLPack for details on stream handling and automatic stream context updates.

Dtype Dispatch#

Kernels typically support multiple dtypes. Dispatch on DLDataType at runtime while instantiating templates at compile time:

constexpr DLDataType dl_float32 = DLDataType{kDLFloat, 32, 1};
constexpr DLDataType dl_float16 = DLDataType{kDLFloat, 16, 1};

if (input.dtype() == dl_float32) {
  ScaleKernel<<<blocks, threads, 0, stream>>>(
      static_cast<float*>(output.data_ptr()), ...);
} else if (input.dtype() == dl_float16) {
  ScaleKernel<<<blocks, threads, 0, stream>>>(
      static_cast<half*>(output.data_ptr()), ...);
} else {
  TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << input.dtype();
}

For libraries that support many dtypes, define dispatch macros (see FlashInfer’s tvm_ffi_utils.h for a production example).
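
A minimal sketch of such a macro, built on the dl_float32/dl_float16 constants above. The name DISPATCH_FLOAT_DTYPE and the scalar_t type alias are invented for this example; __VA_ARGS__ absorbs the commas inside the launch body so it can be passed as a single block:

// Illustrative dispatch macro; production macros (e.g. FlashInfer's) cover more dtypes.
#define DISPATCH_FLOAT_DTYPE(dtype, c_type, ...)                     \
  do {                                                               \
    if ((dtype) == dl_float32) {                                     \
      using c_type = float;                                          \
      __VA_ARGS__;                                                   \
    } else if ((dtype) == dl_float16) {                              \
      using c_type = half;                                           \
      __VA_ARGS__;                                                   \
    } else {                                                         \
      TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << (dtype);  \
    }                                                                \
  } while (0)

// Usage: the launch body is written once and instantiated per dtype.
DISPATCH_FLOAT_DTYPE(input.dtype(), scalar_t, {
  ScaleKernel<<<blocks, threads, 0, stream>>>(static_cast<scalar_t*>(output.data_ptr()),
                                              static_cast<scalar_t*>(input.data_ptr()),
                                              static_cast<scalar_t>(factor), n);
});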

Export and Load#

Export and Build#

Export. Use TVM_FFI_DLL_EXPORT_TYPED_FUNC to create a C symbol that follows the TVM-FFI calling convention:

TVM_FFI_DLL_EXPORT_TYPED_FUNC(scale, Scale);

This creates a symbol __tvm_ffi_scale in the shared library.

Build. Compile the kernel into a shared library using GCC/NVCC or CMake (see C++ Tooling for full details):

nvcc -shared -O3 scale_kernel.cu -o build/scale_kernel.so \
    -Xcompiler -fPIC,-fvisibility=hidden \
    $(tvm-ffi-config --cxxflags) \
    $(tvm-ffi-config --ldflags) \
    $(tvm-ffi-config --libs)

Optional arguments. Wrap any argument type with tvm::ffi::Optional to accept None from the Python side:

void MyKernel(TensorView output, TensorView input,
              Optional<TensorView> bias, Optional<double> scale) {
  if (bias.has_value()) {
    // use bias.value().data_ptr()
  }
  double s = scale.value_or(1.0);
}

From Python, pass None for any optional argument you want to omit:

mod.my_kernel(y, x, None, None)         # no bias, default scale
mod.my_kernel(y, x, bias_tensor, 2.0)   # with bias and scale

Load from Python#

Use tvm_ffi.load_module() to load the library and call its functions. PyTorch tensors (and other framework tensors) are automatically converted to TensorView at the ABI boundary:

import torch
import tvm_ffi

# Load the compiled shared library
mod = tvm_ffi.load_module("build/scale_kernel.so")

# Pre-allocate input and output tensors in PyTorch
x = torch.randn(1024, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)

# Call the kernel — PyTorch tensors are auto-converted to TensorView
mod.scale(y, x, 2.0)

assert torch.allclose(y, x * 2.0)

See Quick Start for examples with JAX, PaddlePaddle, NumPy, CuPy, Rust, and pure C++.

Tensor Handling#

TensorView vs Tensor#

TVM-FFI provides two tensor types (see Tensor and DLPack for full details):

TensorView (non-owning)

A lightweight view of an existing tensor. Use this for kernel parameters. It adds no reference count overhead and works with all framework tensors.

Tensor (owning)

A reference-counted tensor that manages its own lifetime. Use this only when you need to allocate and return a tensor from C++.

Important

Prefer TensorView in kernel signatures. It is more lightweight, supports more use cases (including XLA buffers that only provide views), and avoids unnecessary reference counting.
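
In function signatures the guideline looks like this (Scale is the kernel from above; AllocAndReturn is a hypothetical name used only for illustration):

// Typical kernel: all tensors are pre-allocated by the caller and passed as views.
void Scale(TensorView output, TensorView input, double factor);

// The owning type appears only when C++ must allocate and return a new tensor.
ffi::Tensor AllocAndReturn(TensorView input);   // hypothetical example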

Tensor Metadata#

Both TensorView and Tensor expose identical metadata accessors. These are the methods kernel code uses most: validating inputs, computing launch parameters, and accessing data pointers.

Shape and elements. ndim() returns the number of dimensions, shape() returns the full shape as a ShapeView (a lightweight span-like view of int64_t), and size(i) returns the size of dimension i (negative indices count from the end, e.g. size(-1) for the last dimension). numel() returns the total element count, which you use to compute grid dimensions:

int64_t n = input.numel();
int threads = 256;
int blocks = (n + threads - 1) / threads;

Dtype. dtype() returns a DLDataType with three fields: code (e.g. kDLFloat, kDLBfloat), bits (e.g. 16, 32), and lanes (almost always 1). Compare it against predefined constants to dispatch on dtype:

constexpr DLDataType dl_float32 = DLDataType{kDLFloat, 32, 1};
if (input.dtype() == dl_float32) { ... }

Device. device() returns a DLDevice with device_type (e.g. kDLCUDA) and device_id. Use these for validation and to set the device guard:

TVM_FFI_ICHECK_EQ(input.device().device_type, kDLCUDA);
ffi::CUDADeviceGuard guard(input.device().device_id);

Data pointer. data_ptr() returns void*; cast it to the appropriate typed pointer before passing it to a kernel:

auto* out = static_cast<float*>(output.data_ptr());
auto* in  = static_cast<float*>(input.data_ptr());

Strides and contiguity. strides() returns the stride array as a ShapeView, and stride(i) returns the stride of dimension i. IsContiguous() checks whether the tensor is contiguous in memory. Most kernels require contiguous inputs; the CHECK_CONTIGUOUS macro shown above enforces this at the top of each function.
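
For example, a kernel that assumes a row-major 2D layout can check it explicitly with the stride accessors (a sketch; for a 2D tensor this is essentially the same condition that IsContiguous() verifies):

TVM_FFI_CHECK(input.ndim() == 2, ValueError) << "input must be 2D";
TVM_FFI_CHECK(input.stride(1) == 1 && input.stride(0) == input.size(1), ValueError)
    << "input must be row-major contiguous";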

Tip

The API is designed to be familiar to PyTorch developers. dim(), sizes(), size(i), stride(i), and is_contiguous() are all available as aliases of their TVM-FFI counterparts. See Tensor and DLPack for the full API reference.
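
A quick sketch using the aliases (each line matches its TVM-FFI counterpart above):

int64_t rank = input.dim();               // same as input.ndim()
auto    dims = input.sizes();             // same as input.shape()
bool  contig = input.is_contiguous();     // same as input.IsContiguous()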

Tensor Allocation#

Always pre-allocate output tensors on the Python side and pass them into the kernel as TensorView parameters. Allocating tensors inside a kernel function is almost never the right choice:

  • it causes memory fragmentation from repeated small allocations,

  • it breaks CUDA graph capture, which requires deterministic memory addresses, and

  • it bypasses the framework’s allocator (caching pools, device placement, memory planning).

The pre-allocation pattern is straightforward:

# Python: pre-allocate output
y = torch.empty_like(x)
mod.scale(y, x, 2.0)
// C++: kernel writes into pre-allocated output
void Scale(TensorView output, TensorView input, double factor);

If C++-side allocation is truly unavoidable — for example, when the output shape is data-dependent and cannot be determined before the kernel runs — use tvm::ffi::Tensor::FromEnvAlloc() to at least reuse the host framework’s allocator (e.g., torch.empty under PyTorch):

// --- Tensor allocation helper ---
inline ffi::Tensor alloc_tensor(const ffi::Shape& shape, DLDataType dtype, DLDevice device) {
  return ffi::Tensor::FromEnvAlloc(TVMFFIEnvTensorAlloc, shape, dtype, device);
}
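
A function with a data-dependent output can then allocate through this helper and return the owning Tensor. A minimal sketch (MakeFlatCopy is a hypothetical name, and the ffi::Shape initializer-list construction is assumed):

// Hypothetical example: allocate via the host framework's allocator
// (torch.empty under PyTorch) and return the owning Tensor to the caller.
ffi::Tensor MakeFlatCopy(TensorView input) {
  ffi::Tensor out = alloc_tensor(ffi::Shape({input.numel()}), input.dtype(), input.device());
  // ... launch a kernel that writes input's data into out.data_ptr() ...
  return out;
}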

For custom allocators (e.g., cudaMalloc/cudaFree), use tvm::ffi::Tensor::FromNDAlloc(). Note that the kernel library must outlive any tensors allocated this way, since the custom deleter lives in the library. See Tensor and DLPack for details.

Further Reading#