Quick Start#

Note

All the code in this tutorial can be found under examples/quickstart in the repository.

This guide walks through shipping a minimal add_one function that computes y = x + 1 in C++ and CUDA. TVM-FFI’s Open ABI and FFI make it possible to ship one library for multiple frameworks and languages. We can build a single shared library that works across:

  • ML frameworks, e.g. PyTorch, JAX, NumPy, CuPy, etc.,

  • Languages, e.g. C++, Python, Rust, etc., and

  • Python ABI versions, e.g. ship one wheel to support all Python versions, including free-threaded ones.

Prerequisites

  • Python: 3.9 or newer

  • Compiler: C++17-capable toolchain (GCC/Clang/MSVC)

  • Optional ML frameworks for testing: NumPy, PyTorch, JAX, CuPy

  • CUDA: Any modern version (if you want to try the CUDA part)

  • TVM-FFI installed via

    pip install --upgrade apache-tvm-ffi
    

Write a Simple add_one#

Source Code#

Suppose we implement a C++ function AddOne that performs elementwise y = x + 1 on a 1-D float32 vector. The CPU source code is:

// File: compile/add_one_cpu.cc
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/function.h>

namespace tvm_ffi_example_cpu {

/*! \brief Perform vector add one: y = x + 1 (1-D float32) */
void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  int64_t n = x.size(0);
  float* x_data = static_cast<float*>(x.data_ptr());
  float* y_data = static_cast<float*>(y.data_ptr());
  for (int64_t i = 0; i < n; ++i) {
    y_data[i] = x_data[i] + 1;
  }
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, tvm_ffi_example_cpu::AddOne);
}  // namespace tvm_ffi_example_cpu
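
The CUDA variant follows the same pattern. The sketch below is not the repository's exact file: the c_env_api.h include path and the TensorView device() accessor are assumptions, so consult examples/quickstart for the canonical version.

// Sketch of compile/add_one_cuda.cu (assumptions noted in comments)
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/c_env_api.h>  // assumed location of TVMFFIEnvGetStream
#include <tvm/ffi/function.h>

namespace tvm_ffi_example_cuda {

__global__ void AddOneKernel(const float* x, float* y, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    y[i] = x[i] + 1;
  }
}

/*! \brief Perform vector add one on GPU: y = x + 1 (1-D float32) */
void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  int64_t n = x.size(0);
  // Launch on the caller's stream; the device() accessor name is an assumption.
  DLDevice dev = x.device();
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(dev.device_type, dev.device_id));
  int threads = 256;
  int blocks = static_cast<int>((n + threads - 1) / threads);
  AddOneKernel<<<blocks, threads, 0, stream>>>(
      static_cast<const float*>(x.data_ptr()), static_cast<float*>(y.data_ptr()), n);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example_cuda::AddOne);
}  // namespace tvm_ffi_example_cuda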

The macro TVM_FFI_DLL_EXPORT_TYPED_FUNC exports the C++ function AddOne as a TVM-FFI-compatible symbol named __tvm_ffi_add_one_cpu (or __tvm_ffi_add_one_cuda for the CUDA variant) in the resulting library.

The class tvm::ffi::TensorView allows zero-copy interop with tensors from different ML frameworks:

  • NumPy, CuPy,

  • PyTorch, JAX, or

  • any array type that supports the standard DLPack protocol.

Finally, TVMFFIEnvGetStream() can be used in CUDA code to launch a kernel on the caller’s stream, as in the CUDA sketch above.

Compile with TVM-FFI#

Raw command. We can use the following minimal commands to compile the source code:

g++ -shared -O3 compile/add_one_cpu.cc  \
    -fPIC -fvisibility=hidden           \
    $(tvm-ffi-config --cxxflags)        \
    $(tvm-ffi-config --ldflags)         \
    $(tvm-ffi-config --libs)            \
    -o $BUILD_DIR/add_one_cpu.so
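
The CUDA source is compiled analogously; a minimal sketch using nvcc (the flag spelling is our assumption; add the right -arch=sm_XX for your GPU):

nvcc -shared -O3 compile/add_one_cuda.cu    \
    -Xcompiler -fPIC,-fvisibility=hidden    \
    $(tvm-ffi-config --cxxflags)            \
    $(tvm-ffi-config --ldflags)             \
    $(tvm-ffi-config --libs)                \
    -o $BUILD_DIR/add_one_cuda.so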

This step produces shared libraries add_one_cpu.so and add_one_cuda.so that can be used across languages and frameworks.

Hint

For a single-file C++/CUDA project, a convenient method tvm_ffi.cpp.load_inline() is provided to minimize boilerplate code in compilation, linking, and loading.
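
A hypothetical sketch of its use (the keyword arguments name, cpp_sources, and functions are assumptions modeled after torch.utils.cpp_extension.load_inline; check the tvm_ffi.cpp API reference):

import numpy as np
import tvm_ffi.cpp

# Hypothetical sketch: keyword names are assumptions, not the confirmed API.
mod = tvm_ffi.cpp.load_inline(
    name="add_one",                                      # module name
    cpp_sources=open("compile/add_one_cpu.cc").read(),   # C++ source text
    functions=["add_one_cpu"],                           # exported function names
)
x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
y = np.empty_like(x)
mod.add_one_cpu(x, y)  # y now holds x + 1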

CMake. CMake is the preferred approach for building across platforms. TVM-FFI natively integrates with CMake via find_package as demonstrated below:

# Use `python -m tvm_ffi.config --cmakedir` to locate the CMake package for `find_package`
find_package(Python COMPONENTS Interpreter REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir
  OUTPUT_VARIABLE tvm_ffi_ROOT
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
find_package(tvm_ffi CONFIG REQUIRED)

# Link C++ target to `tvm_ffi_header` and `tvm_ffi_shared`
add_library(add_one_cpu SHARED compile/add_one_cpu.cc)
target_link_libraries(add_one_cpu PRIVATE tvm_ffi_header)
target_link_libraries(add_one_cpu PRIVATE tvm_ffi_shared)
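
The CUDA variant can be built with the same linkage; a sketch assuming the CUDA toolchain is available (target and file names are ours):

# Build the CUDA variant the same way (sketch)
enable_language(CUDA)
add_library(add_one_cuda SHARED compile/add_one_cuda.cu)
target_link_libraries(add_one_cuda PRIVATE tvm_ffi_header tvm_ffi_shared)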

Artifact. The resulting add_one_cpu.so and add_one_cuda.so are minimal libraries that are agnostic to:

  • Python version/ABI. They are neither compiled nor linked against Python and depend only on TVM-FFI’s stable C ABI;

  • Languages, including C++, Python, Rust, or any other language that can interop with the C ABI;

  • ML frameworks, such as PyTorch, JAX, NumPy, CuPy, or anything that supports the standard DLPack protocol.

Ship Across ML Frameworks#

TVM-FFI’s Python package provides tvm_ffi.load_module(), which loads either add_one_cpu.so or add_one_cuda.so into a tvm_ffi.Module.

import tvm_ffi
mod  : tvm_ffi.Module   = tvm_ffi.load_module("add_one_cpu.so")
func : tvm_ffi.Function = mod.add_one_cpu

mod.add_one_cpu retrieves a callable tvm_ffi.Function that accepts tensors from host frameworks directly. This process is zero-copy, requires no boilerplate code, and has very low latency.

We can then use these functions in the following ways:

# File: load/load_pytorch.py
# Step 1. Load `build/add_one_cuda.so`
import tvm_ffi
mod = tvm_ffi.load_module("build/add_one_cuda.so")

# Step 2. Run `mod.add_one_cuda` with PyTorch
import torch
x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
mod.add_one_cuda(x, y)
print(y)

JAX support is provided via nvidia/jax-tvm-ffi, which can be installed via

pip install jax-tvm-ffi

After installation, add_one_cuda can be registered as a target for JAX’s ffi_call.

# Step 1. Load `build/add_one_cuda.so`
import tvm_ffi
mod = tvm_ffi.load_module("build/add_one_cuda.so")

# Step 2. Register `mod.add_one_cuda` into JAX
import jax_tvm_ffi
jax_tvm_ffi.register_ffi_target("add_one", mod.add_one_cuda, platform="gpu")

# Step 3. Run `mod.add_one_cuda` with JAX
import jax
import jax.numpy as jnp
jax_device, *_ = jax.devices("gpu")
x = jnp.array([1, 2, 3, 4, 5], dtype=jnp.float32, device=jax_device)
y = jax.ffi.ffi_call(
    "add_one",  # name of the registered function
    jax.ShapeDtypeStruct(x.shape, x.dtype),  # shape and dtype of the output
    vmap_method="broadcast_all",
)(x)
print(y)

# File: load/load_numpy.py
import tvm_ffi
mod = tvm_ffi.load_module("build/add_one_cpu.so")

import numpy as np
x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
y = np.empty_like(x)
mod.add_one_cpu(x, y)
print(y)

# File: load/load_cupy.py
import tvm_ffi
mod = tvm_ffi.load_module("build/add_one_cuda.so")

import cupy as cp
x = cp.array([1, 2, 3, 4, 5], dtype=cp.float32)
y = cp.empty_like(x)
mod.add_one_cuda(x, y)
print(y)

Ship Across Languages#

TVM-FFI’s core loading mechanism is ABI stable and works across language boundaries. A single library can be loaded in every language TVM-FFI supports, without having to recompile different libraries targeting different ABIs or languages.

Python#

As shown in the previous section, tvm_ffi.load_module() loads the language- and framework-independent add_one_cpu.so or add_one_cuda.so, and the loaded functions accept tensors from any Python array framework that implements the standard DLPack protocol.

C++#

TVM-FFI’s C++ API tvm::ffi::Module::LoadFromFile() loads add_one_cpu.so or add_one_cuda.so and can be used directly in C/C++ with no Python dependency.

// File: load/load_cpp.cc
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/module.h>

#include <cstdlib>
#include <initializer_list>
#include <iostream>

namespace {
namespace ffi = tvm::ffi;

/************* Main logic *************/

/*!
 * \brief Main logic of library loading and function calling.
 * \param x The input tensor.
 * \param y The output tensor.
 */
void Run(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Load shared library `build/add_one_cpu.so`
  ffi::Module mod = ffi::Module::LoadFromFile("build/add_one_cpu.so");
  // Look up `add_one_cpu` function
  ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value();
  // Call the function
  add_one_cpu(x, y);
}

/************* Auxiliary logic *************/

/*!
 * \brief Allocate a 1-D float32 `tvm::ffi::Tensor` on CPU from a braced initializer list.
 * \param data The input data.
 * \return The allocated Tensor.
 */
ffi::Tensor Alloc1DTensor(std::initializer_list<float> data) {
  struct CPUAllocator {
    void AllocData(DLTensor* tensor) {
      tensor->data = std::malloc(tensor->shape[0] * sizeof(float));
    }
    void FreeData(DLTensor* tensor) { std::free(tensor->data); }
  };
  DLDataType f32 = DLDataType({kDLFloat, 32, 1});
  DLDevice cpu = DLDevice({kDLCPU, 0});
  int64_t n = static_cast<int64_t>(data.size());
  ffi::Tensor x = ffi::Tensor::FromNDAlloc(CPUAllocator(), {n}, f32, cpu);
  float* x_data = static_cast<float*>(x.data_ptr());
  for (float v : data) {
    *x_data++ = v;
  }
  return x;
}

}  // namespace

int main() {
  ffi::Tensor x = Alloc1DTensor({1, 2, 3, 4, 5});
  ffi::Tensor y = Alloc1DTensor({0, 0, 0, 0, 0});
  Run(x, y);
  std::cout << "[ ";
  const float* y_data = static_cast<const float*>(y.data_ptr());
  for (int i = 0; i < 5; ++i) {
    std::cout << y_data[i] << " ";
  }
  std::cout << "]" << std::endl;
  return 0;
}

Compile and run it with:

g++ -fvisibility=hidden -O3                 \
    load/load_cpp.cc                        \
    $(tvm-ffi-config --cxxflags)            \
    $(tvm-ffi-config --ldflags)             \
    $(tvm-ffi-config --libs)                \
    -Wl,-rpath,$(tvm-ffi-config --libdir)   \
    -o build/load_cpp

build/load_cpp
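
If the library loads and the call succeeds, this prints [ 2 3 4 5 6 ].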

Note

Don’t like loading shared libraries? Static linking is also supported.

In such cases, we can use tvm::ffi::Function::FromExternC() to create a tvm::ffi::Function from the exported symbol, or directly use tvm::ffi::Function::InvokeExternC() to invoke the function.

This feature can be useful on iOS, or when the exported module is generated by another DSL compiler that targets the same ABI.

// Linked with `add_one_cpu.o` or `add_one_cuda.o`
#include <tvm/ffi/function.h>
#include <tvm/ffi/container/tensor.h>

// declare reference to the exported symbol
extern "C" int __tvm_ffi_add_one_cpu(void*, const TVMFFIAny*, int32_t, TVMFFIAny*);

namespace ffi = tvm::ffi;

int bundle_add_one(ffi::TensorView x, ffi::TensorView y) {
  // No closure state is needed for a plain exported symbol.
  void* closure_handle = nullptr;
  ffi::Function::InvokeExternC(closure_handle, __tvm_ffi_add_one_cpu, x, y);
  return 0;
}

Rust#

TVM-FFI’s Rust API tvm_ffi::Module::load_from_file loads add_one_cpu.so or add_one_cuda.so and then retrieves the function add_one_cpu or add_one_cuda from it. The procedure is identical to the C++ and Python flows:

// Assumes `use tvm_ffi::{Result, Tensor};` and the `into_typed_fn!` macro
// in scope (import paths are assumptions; see the crate docs).
fn run_add_one(x: &Tensor, y: &Tensor) -> Result<()> {
    let module = tvm_ffi::Module::load_from_file("add_one_cpu.so")?;
    let func = module.get_function("add_one_cpu")?;
    let typed_fn = into_typed_fn!(func, Fn(&Tensor, &Tensor) -> Result<()>);
    typed_fn(x, y)?;
    Ok(())
}

Hint

We can also use the Rust API to target the TVM FFI ABI. This means we can use Rust to write the function implementation and export to Python/C++ in the same fashion.

Troubleshooting#

  • OSError: cannot open shared object file: Add an rpath (Linux/macOS) or ensure the DLL is on PATH (Windows). Example run-path: -Wl,-rpath,$(tvm-ffi-config --libdir).

  • undefined symbol: __tvm_ffi_add_one_cpu: Ensure you used TVM_FFI_DLL_EXPORT_TYPED_FUNC; the macro explicitly marks the symbol for export, so building with -fvisibility=hidden is fine.

  • CUDA error: invalid device function: Rebuild with the correct -arch=sm_XX for your GPU, or include multiple -gencode entries.