Quick Start#
Note
All the code in this tutorial can be found under examples/quickstart in the repository.
This guide walks through shipping a minimal add_one function that computes y = x + 1 in C++ and CUDA.
TVM-FFI’s Open ABI and FFI make it possible to ship one library for multiple frameworks and languages.
We can build a single shared library that works across:
ML frameworks, e.g. PyTorch, JAX, NumPy, CuPy;
Languages, e.g. C++, Python, Rust; and
Python ABI versions, e.g. ship one wheel to support all Python versions, including free-threaded ones.
Prerequisites
Python: 3.9 or newer
Compiler: C++17-capable toolchain (GCC/Clang/MSVC)
Optional ML frameworks for testing: NumPy, PyTorch, JAX, CuPy
CUDA: Any modern version (if you want to try the CUDA part)
TVM-FFI installed via
pip install --upgrade apache-tvm-ffi
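To quickly confirm the installation from Python (a minimal sanity check; the printed paths will differ on your machine):
# Sanity check: the Python package imports and the `tvm-ffi-config` CLI used below is on PATH.
import shutil
import tvm_ffi

print(tvm_ffi.__file__)                # where the package is installed
print(shutil.which("tvm-ffi-config"))  # compiler-flag helper used later in this guide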
Write a Simple add_one#
Source Code#
Suppose we implement a C++ function AddOne that performs elementwise y = x + 1 for a 1-D float32 vector. The source code (C++, CUDA) is:
// File: compile/add_one_cpu.cc
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/function.h>

namespace tvm_ffi_example_cpu {

/*! \brief Perform vector add one: y = x + 1 (1-D float32) */
void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  int64_t n = x.size(0);
  float* x_data = static_cast<float*>(x.data_ptr());
  float* y_data = static_cast<float*>(y.data_ptr());
  for (int64_t i = 0; i < n; ++i) {
    y_data[i] = x_data[i] + 1;
  }
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, tvm_ffi_example_cpu::AddOne);

}  // namespace tvm_ffi_example_cpu
// File: compile/add_one_cuda.cu
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>

namespace tvm_ffi_example_cuda {

__global__ void AddOneKernel(float* x, float* y, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    y[idx] = x[idx] + 1;
  }
}

void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  int64_t n = x.size(0);
  float* x_data = static_cast<float*>(x.data_ptr());
  float* y_data = static_cast<float*>(y.data_ptr());
  int64_t threads = 256;
  int64_t blocks = (n + threads - 1) / threads;
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(x.device().device_type, x.device().device_id));
  AddOneKernel<<<blocks, threads, 0, stream>>>(x_data, y_data, n);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example_cuda::AddOne);

}  // namespace tvm_ffi_example_cuda
The macro TVM_FFI_DLL_EXPORT_TYPED_FUNC exports the C++ function AddOne as a TVM FFI compatible symbol named __tvm_ffi_add_one_cpu (or __tvm_ffi_add_one_cuda) in the resulting library.
The class tvm::ffi::TensorView allows zero-copy interop with tensors from different ML frameworks: NumPy, CuPy, PyTorch, JAX, or any array type that supports the standard DLPack protocol.
Finally, TVMFFIEnvGetStream() can be used in the CUDA code to launch a kernel on the caller’s stream.
Compile with TVM-FFI#
Raw command. We can use the following minimal commands to compile the source code:
g++ -shared -O3 compile/add_one_cpu.cc \
-fPIC -fvisibility=hidden \
$(tvm-ffi-config --cxxflags) \
$(tvm-ffi-config --ldflags) \
$(tvm-ffi-config --libs) \
-o $BUILD_DIR/add_one_cpu.so
nvcc -shared -O3 compile/add_one_cuda.cu \
-Xcompiler -fPIC,-fvisibility=hidden \
$(tvm-ffi-config --cxxflags) \
$(tvm-ffi-config --ldflags) \
$(tvm-ffi-config --libs) \
-o $BUILD_DIR/add_one_cuda.so
This step produces the shared libraries add_one_cpu.so and add_one_cuda.so, which can be used across languages and frameworks.
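As a quick sanity check that the export macro did its job, the prefixed symbol described above can be looked up with ctypes (a minimal sketch; substitute your own $BUILD_DIR for build/):
# Sanity check: the `__tvm_ffi_`-prefixed symbol should be present in the built library.
# `build/add_one_cpu.so` is an assumed output path; use your own $BUILD_DIR.
# If CDLL itself fails with "cannot open shared object file", see Troubleshooting below.
import ctypes

lib = ctypes.CDLL("build/add_one_cpu.so")
print(getattr(lib, "__tvm_ffi_add_one_cpu"))  # raises AttributeError if the export is missing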
Hint
For a single-file C++/CUDA project, a convenience helper, tvm_ffi.cpp.load_inline(), is provided to minimize the boilerplate of compiling, linking, and loading.
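A rough sketch of how this might look is shown below. The keyword arguments (name, cpp_sources, functions) are assumptions modeled after torch.utils.cpp_extension.load_inline; consult the tvm_ffi.cpp.load_inline API reference for the exact signature.
# Hedged sketch: compile, link, and load a one-file C++ kernel in a single call.
# Argument names below are assumptions; check the API reference for the real signature.
import numpy as np
import tvm_ffi.cpp

cpp_source = r"""
#include <tvm/ffi/container/tensor.h>

void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  for (int64_t i = 0; i < x.size(0); ++i) {
    static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
  }
}
"""

mod = tvm_ffi.cpp.load_inline(
    name="add_one_inline",      # hypothetical module name
    cpp_sources=cpp_source,
    functions=["add_one_cpu"],  # function(s) to export from the generated library
)

x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
y = np.empty_like(x)
mod.add_one_cpu(x, y)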
CMake. CMake is the preferred approach for building across platforms. TVM-FFI natively integrates with CMake via find_package, as demonstrated below:
# Locate TVM-FFI's CMake package directory via `tvm_ffi.config --cmakedir` (sets `tvm_ffi_ROOT`)
find_package(Python COMPONENTS Interpreter REQUIRED)
execute_process(COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT)
find_package(tvm_ffi CONFIG REQUIRED)
# Link C++ target to `tvm_ffi_header` and `tvm_ffi_shared`
add_library(add_one_cpu SHARED compile/add_one_cpu.cc)
target_link_libraries(add_one_cpu PRIVATE tvm_ffi_header)
target_link_libraries(add_one_cpu PRIVATE tvm_ffi_shared)
enable_language(CUDA)
# Locate TVM-FFI's CMake package directory via `tvm_ffi.config --cmakedir` (sets `tvm_ffi_ROOT`)
find_package(Python COMPONENTS Interpreter REQUIRED)
execute_process(COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT)
find_package(tvm_ffi CONFIG REQUIRED)
# Link CUDA target to `tvm_ffi_header` and `tvm_ffi_shared`
add_library(add_one_cuda SHARED compile/add_one_cuda.cu)
target_link_libraries(add_one_cuda PRIVATE tvm_ffi_header)
target_link_libraries(add_one_cuda PRIVATE tvm_ffi_shared)
Artifact. The resulting add_one_cpu.so and add_one_cuda.so are minimal libraries that are agnostic to:
Python version/ABI: they are neither compiled nor linked against Python and depend only on TVM-FFI’s stable C ABI;
Languages: C++, Python, Rust, or any other language that can interoperate with a C ABI;
ML frameworks: PyTorch, JAX, NumPy, CuPy, or anything that supports the standard DLPack protocol.
Ship Across ML Frameworks#
TVM-FFI’s Python package provides tvm_ffi.load_module(), which loads either add_one_cpu.so or add_one_cuda.so into a tvm_ffi.Module.
import tvm_ffi
mod: tvm_ffi.Module = tvm_ffi.load_module("add_one_cpu.so")
func: tvm_ffi.Function = mod.add_one_cpu
Accessing mod.add_one_cpu retrieves a callable tvm_ffi.Function that accepts tensors from host frameworks directly. The call is zero-copy, requires no boilerplate, and incurs very low latency.
We can then use these functions in the following ways:
# File: load/load_pytorch.py
# Step 1. Load `build/add_one_cuda.so`
import tvm_ffi
mod = tvm_ffi.load_module("build/add_one_cuda.so")
# Step 2. Run `mod.add_one_cuda` with PyTorch
import torch
x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
mod.add_one_cuda(x, y)
print(y)
JAX support is provided via nvidia/jax-tvm-ffi, which can be installed with:
pip install jax-tvm-ffi
After installation, add_one_cuda can be registered as a target for JAX’s ffi_call.
# Step 1. Load `build/add_one_cuda.so`
import tvm_ffi
mod = tvm_ffi.load_module("build/add_one_cuda.so")
# Step 2. Register `mod.add_one_cuda` into JAX
import jax_tvm_ffi
jax_tvm_ffi.register_ffi_target("add_one", mod.add_one_cuda, platform="gpu")
# Step 3. Run `mod.add_one_cuda` with JAX
import jax
import jax.numpy as jnp
jax_device, *_ = jax.devices("gpu")
x = jnp.array([1, 2, 3, 4, 5], dtype=jnp.float32, device=jax_device)
y = jax.ffi.ffi_call(
"add_one", # name of the registered function
jax.ShapeDtypeStruct(x.shape, x.dtype), # shape and dtype of the output
vmap_method="broadcast_all",
)(x)
print(y)
# File: load/load_numpy.py
import tvm_ffi
mod = tvm_ffi.load_module("build/add_one_cpu.so")
import numpy as np
x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
y = np.empty_like(x)
mod.add_one_cpu(x, y)
print(y)
# File: load/load_cupy.py
import tvm_ffi
mod = tvm_ffi.load_module("build/add_one_cuda.so")
import cupy as cp
x = cp.array([1, 2, 3, 4, 5], dtype=cp.float32)
y = cp.empty_like(x)
mod.add_one_cuda(x, y)
print(y)
Ship Across Languages#
TVM-FFI’s core loading mechanism is ABI stable and works across language boundaries. A single library can be loaded in every language TVM-FFI supports, without having to recompile different libraries targeting different ABIs or languages.
Python#
As shown in the previous section, tvm_ffi.load_module() loads the language- and framework-independent add_one_cpu.so or add_one_cuda.so, whose functions can then be called with arrays from any Python framework that implements the standard DLPack protocol.
C++#
TVM-FFI’s C++ API tvm::ffi::Module::LoadFromFile() loads add_one_cpu.so or add_one_cuda.so and can be used directly in C/C++ with no Python dependency.
// File: load/load_cpp.cc
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/module.h>

#include <cstdlib>
#include <initializer_list>
#include <iostream>

namespace {

namespace ffi = tvm::ffi;

/************* Main logic *************/

/*!
 * \brief Main logic of library loading and function calling.
 * \param x The input tensor.
 * \param y The output tensor.
 */
void Run(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Load shared library `build/add_one_cpu.so`
  ffi::Module mod = ffi::Module::LoadFromFile("build/add_one_cpu.so");
  // Look up the `add_one_cpu` function
  ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value();
  // Call the function
  add_one_cpu(x, y);
}

/************* Auxiliary logic *************/

/*!
 * \brief Allocate a 1-D float32 `tvm::ffi::Tensor` on CPU from a braced initializer list.
 * \param data The input data.
 * \return The allocated Tensor.
 */
ffi::Tensor Alloc1DTensor(std::initializer_list<float> data) {
  struct CPUAllocator {
    void AllocData(DLTensor* tensor) {
      tensor->data = std::malloc(tensor->shape[0] * sizeof(float));
    }
    void FreeData(DLTensor* tensor) { std::free(tensor->data); }
  };
  DLDataType f32 = DLDataType({kDLFloat, 32, 1});
  DLDevice cpu = DLDevice({kDLCPU, 0});
  int64_t n = static_cast<int64_t>(data.size());
  ffi::Tensor x = ffi::Tensor::FromNDAlloc(CPUAllocator(), {n}, f32, cpu);
  float* x_data = static_cast<float*>(x.data_ptr());
  for (float v : data) {
    *x_data++ = v;
  }
  return x;
}

}  // namespace

int main() {
  ffi::Tensor x = Alloc1DTensor({1, 2, 3, 4, 5});
  ffi::Tensor y = Alloc1DTensor({0, 0, 0, 0, 0});
  Run(x, y);
  std::cout << "[ ";
  const float* y_data = static_cast<const float*>(y.data_ptr());
  for (int i = 0; i < 5; ++i) {
    std::cout << y_data[i] << " ";
  }
  std::cout << "]" << std::endl;
  return 0;
}
Compile and run it with:
g++ -fvisibility=hidden -O3 \
load/load_cpp.cc \
$(tvm-ffi-config --cxxflags) \
$(tvm-ffi-config --ldflags) \
$(tvm-ffi-config --libs) \
-Wl,-rpath,$(tvm-ffi-config --libdir) \
-o build/load_cpp
build/load_cpp
Note
Don’t like loading shared libraries? Static linking is also supported. In that case, we can use tvm::ffi::Function::FromExternC() to create a tvm::ffi::Function from the exported symbol, or directly use tvm::ffi::Function::InvokeExternC() to invoke the function. This can be useful on iOS, or when the exported module is generated by another DSL compiler that targets the same ABI.
// Linked with `add_one_cpu.o` or `add_one_cuda.o`
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/function.h>

// Declare a reference to the exported symbol.
extern "C" int __tvm_ffi_add_one_cpu(void*, const TVMFFIAny*, int32_t, TVMFFIAny*);

namespace ffi = tvm::ffi;

int bundle_add_one(ffi::TensorView x, ffi::TensorView y) {
  void* closure_handle = nullptr;
  ffi::Function::InvokeExternC(closure_handle, __tvm_ffi_add_one_cpu, x, y);
  return 0;
}
Rust#
TVM-FFI’s Rust API tvm_ffi::Module::load_from_file loads add_one_cpu.so or add_one_cuda.so and then retrieves the function add_one_cpu or add_one_cuda from it. The procedure is identical to the C++ and Python flows:
fn run_add_one(x: &Tensor, y: &Tensor) -> Result<()> {
    let module = tvm_ffi::Module::load_from_file("add_one_cpu.so")?;
    let func = module.get_function("add_one_cpu")?;
    let typed_fn = into_typed_fn!(func, Fn(&Tensor, &Tensor) -> Result<()>);
    typed_fn(x, y)?;
    Ok(())
}
Hint
We can also use the Rust API to target the TVM FFI ABI. This means we can use Rust to write the function implementation and export to Python/C++ in the same fashion.
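For example, assuming a Rust crate is built as a cdylib (the file name librust_add_one.so below is hypothetical) and exports add_one_cpu through the TVM FFI ABI, the Python side is unchanged:
# Hypothetical: `librust_add_one.so` is a Rust cdylib that exports `add_one_cpu`
# via the TVM FFI ABI; loading and calling it is identical to the C++ case.
import numpy as np
import tvm_ffi

mod = tvm_ffi.load_module("librust_add_one.so")
x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
y = np.empty_like(x)
mod.add_one_cpu(x, y)
print(y)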
Troubleshooting#
OSError: cannot open shared object file: add an rpath (Linux/macOS) or ensure the DLL is on PATH (Windows). Example run-path: -Wl,-rpath,`tvm-ffi-config --libdir`.
undefined symbol: __tvm_ffi_add_one_cpu: ensure you used TVM_FFI_DLL_EXPORT_TYPED_FUNC and compiled with default symbol visibility (-fvisibility=hidden is fine; the macro ensures the export).
CUDA error: invalid device function: rebuild with the correct -arch=sm_XX for your GPU, or include multiple -gencode entries.