Export and Load Relax Executables
This tutorial walks through exporting a compiled Relax module to a shared object, loading it back into the TVM runtime, and running the result either interactively or from a standalone script. It demonstrates how to turn Relax (or imported PyTorch / ONNX) programs into deployable artifacts using the tvm.relax APIs.
Note
This tutorial uses PyTorch as the source format, but the export/load workflow is the same for ONNX models. For ONNX, use from_onnx(model, keep_params_in_input=True) instead of from_exported_program(), then follow the same steps for building, exporting, and loading.
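For reference, a minimal sketch of the ONNX path is shown below; the file name mnist_mlp.onnx is a placeholder, and the snippet assumes the onnx package is installed:

import onnx
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

onnx_model = onnx.load("mnist_mlp.onnx")  # placeholder model file
mod = from_onnx(onnx_model, keep_params_in_input=True)
mod, params = relax.frontend.detach_params(mod)
# From here on, the build/export/load steps below apply unchanged.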
Introduction
TVM builds Relax programs into tvm.runtime.Executable objects. These contain VM bytecode, compiled kernels, and constants. By exporting the executable with export_library(), you obtain a shared library (for example, .so on Linux) that can be shipped to another machine, uploaded via RPC, or loaded back later with the TVM runtime. This tutorial shows the exact steps end-to-end and explains what files are produced along the way.
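To see the shape of the API before we build the real model, here is a small self-contained preview of that round trip on a toy identity function written in TVMScript; the module and the file name preview.so are illustrative only:

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class PreviewModule:
    @R.function
    def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        return x


executable = relax.build(PreviewModule, target="llvm")  # compile the Relax IRModule
executable.export_library("preview.so")                 # package into a shared library
rt_mod = tvm.runtime.load_module("preview.so")          # reload the artifact
vm = relax.VirtualMachine(rt_mod, tvm.cpu(0))           # drive it with the Relax VM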
import os
from pathlib import Path
try:
    import torch
    from torch.export import export
except ImportError:  # pragma: no cover
    torch = None  # type: ignore
Prepare a Torch MLP and Convert to Relax
We start with a small PyTorch MLP so the example remains lightweight. The model is exported to a torch.export.ExportedProgram and then translated into a Relax IRModule.
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
# Check dependencies first
IS_IN_CI = os.getenv("CI", "").lower() == "true"
HAS_TORCH = torch is not None
RUN_EXAMPLE = HAS_TORCH and not IS_IN_CI
if HAS_TORCH:

    class TorchMLP(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Flatten(),
                torch.nn.Linear(28 * 28, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 10),
            )

        def forward(self, data: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
            return self.net(data)

else:  # pragma: no cover
    TorchMLP = None  # type: ignore[misc, assignment]

if not RUN_EXAMPLE:
    print("Skip model conversion because PyTorch is unavailable or we are in CI.")

if RUN_EXAMPLE:
    torch_model = TorchMLP().eval()
    example_args = (torch.randn(1, 1, 28, 28, dtype=torch.float32),)
    with torch.no_grad():
        exported_program = export(torch_model, example_args)
        mod = from_exported_program(exported_program, keep_params_as_input=True)

    # Separate model parameters so they can be bound later (or stored on disk).
    mod, params = relax.frontend.detach_params(mod)
    print("Imported Relax module:")
    mod.show()
Skip model conversion because PyTorch is unavailable or we are in CI.
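If you want to see exactly what was detached, here is a small sketch that lists the parameter shapes stored in params["main"], reusing the objects defined above:

if RUN_EXAMPLE:
    # params maps each Relax function name to its detached weight list.
    for i, p in enumerate(params["main"]):
        print(f"param {i}: shape={p.shape}, dtype={p.dtype}")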
Build and Export with export_library
We build for llvm to generate CPU code and then export the resulting executable. Passing workspace_dir keeps the intermediate packaging files, which is useful to inspect what was produced.
TARGET = tvm.target.Target("llvm")
ARTIFACT_DIR = Path("relax_export_artifacts")
ARTIFACT_DIR.mkdir(exist_ok=True)
if RUN_EXAMPLE:
    # Apply the default Relax compilation pipeline before building.
    pipeline = relax.get_pipeline()
    with TARGET:
        built_mod = pipeline(mod)
    # Build without params - we'll pass them at runtime
    executable = relax.build(built_mod, target=TARGET)

    library_path = ARTIFACT_DIR / "mlp_cpu.so"
    executable.export_library(str(library_path), workspace_dir=str(ARTIFACT_DIR))
    print(f"Exported runtime library to: {library_path}")

    # The workspace directory now contains the shared object and supporting files.
    produced_files = sorted(p.name for p in ARTIFACT_DIR.iterdir())
    print("Artifacts saved:")
    for name in produced_files:
        print(f" - {name}")

# Generated files:
# - ``mlp_cpu.so``: The main deployable shared library containing VM bytecode,
#   compiled kernels, and constants. Note: Since parameters are passed at runtime,
#   you will also need to save a separate parameters file (see next section).
# - Intermediate object files (``devc.o``, ``lib0.o``, etc.) are kept in the
#   workspace for inspection but are not required for deployment.
#
# Note: Additional files like ``*.params``, ``*.metadata.json``, or ``*.imports``
# may appear in specific configurations but are typically embedded into the
# shared library or only generated when needed.
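As a quick, optional sanity check (a sketch; not part of deployment), you can confirm the shared object landed on disk and inspect its size:

if RUN_EXAMPLE:
    size_kb = library_path.stat().st_size / 1024
    print(f"{library_path.name}: {size_kb:.1f} KiB on disk")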
Load the Exported Library and Run It
Once the shared object is produced, we can reload it back into the TVM runtime on any machine with a compatible instruction set. The Relax VM consumes the runtime module directly.
if RUN_EXAMPLE:
    loaded_rt_mod = tvm.runtime.load_module(str(library_path))
    dev = tvm.cpu(0)
    vm = relax.VirtualMachine(loaded_rt_mod, dev)

    # Prepare input data
    input_tensor = torch.randn(1, 1, 28, 28, dtype=torch.float32)
    vm_input = tvm.runtime.tensor(input_tensor.numpy(), dev)

    # Prepare parameters (allocate on target device)
    vm_params = [tvm.runtime.tensor(p, dev) for p in params["main"]]

    # Run inference: pass input data followed by all parameters
    tvm_output = vm["main"](vm_input, *vm_params)

    # TVM returns Array objects for tuple outputs, access via indexing.
    # For models imported from PyTorch, outputs are typically tuples (even for single outputs).
    # For ONNX models, outputs may be a single Tensor directly.
    result_tensor = tvm_output[0] if isinstance(tvm_output, (tuple, list)) else tvm_output
    print("VM output shape:", result_tensor.shape)
    print("VM output type:", type(tvm_output), "->", type(result_tensor))

    # You can still inspect the executable after reloading.
    print("Executable stats:\n", loaded_rt_mod["stats"]())
Save Parameters for Deployment
Since parameters are passed at runtime (not embedded in the .so), we must save them separately for deployment. This is a required step to use the model on other machines or in standalone scripts.
import numpy as np

if RUN_EXAMPLE:
    # Save parameters to disk
    params_path = ARTIFACT_DIR / "model_params.npz"
    param_arrays = {f"p_{i}": p.numpy() for i, p in enumerate(params["main"])}
    np.savez(str(params_path), **param_arrays)
    print(f"Saved parameters to: {params_path}")
# Note: Alternatively, you can embed parameters directly into the ``.so`` to
# create a single-file deployment. Use ``keep_params_as_input=False`` when
# importing from PyTorch:
#
# .. code-block:: python
#
#     mod = from_exported_program(exported_program, keep_params_as_input=False)
#     # Parameters are now embedded as constants in the module
#     executable = relax.build(built_mod, target=TARGET)
#     # Runtime: vm["main"](input)  # No need to pass params!
#
# This creates a single-file deployment (only the ``.so`` is needed), but you
# lose the flexibility to swap parameters without recompiling. For most
# production workflows, separating code and parameters (as shown above) is
# preferred for flexibility.
Loading and Running the Exported Model
To use the exported model on another machine or in a standalone script, you need to load both the .so library and the parameters file. Below is a complete example of how to reload and run the model; save it as run_mlp.py.
To make it executable from the command line:
chmod +x run_mlp.py
./run_mlp.py  # Run it like a regular program
Complete script:
#!/usr/bin/env python3
import numpy as np
import tvm
from tvm import relax
# Step 1: Load the compiled library
lib = tvm.runtime.load_module("relax_export_artifacts/mlp_cpu.so")
# Step 2: Create Virtual Machine
device = tvm.cpu(0)
vm = relax.VirtualMachine(lib, device)
# Step 3: Load parameters from the .npz file
params_npz = np.load("relax_export_artifacts/model_params.npz")
params = [tvm.runtime.tensor(params_npz[f"p_{i}"], device)
          for i in range(len(params_npz))]
# Step 4: Prepare input data
data = np.random.randn(1, 1, 28, 28).astype("float32")
input_tensor = tvm.runtime.tensor(data, device)
# Step 5: Run inference (pass input followed by all parameters)
output = vm["main"](input_tensor, *params)
# Step 6: Extract result (output may be tuple or single Tensor)
# PyTorch models typically return tuples, ONNX models may return a single Tensor
result = output[0] if isinstance(output, (tuple, list)) else output
print("Prediction shape:", result.shape)
print("Predicted class:", np.argmax(result.numpy()))
Running on GPU: To run on GPU instead of CPU, make the following changes:

Compile for GPU (in the build step earlier in the tutorial):

TARGET = tvm.target.Target("cuda")  # Change from "llvm" to "cuda"

Use the GPU device in the script:

device = tvm.cuda(0)  # Use CUDA device instead of CPU
vm = relax.VirtualMachine(lib, device)

# Load parameters to GPU
params = [tvm.runtime.tensor(params_npz[f"p_{i}"], device)  # Note: device parameter
          for i in range(len(params_npz))]

# Prepare input on GPU
input_tensor = tvm.runtime.tensor(data, device)  # Note: device parameter

The rest of the script remains the same. All tensors (parameters and inputs) must be allocated on the same device (GPU) as the compiled model.
Deployment Checklist: When moving to another host (via RPC or SCP), you must copy both files:
- mlp_cpu.so (or mlp_cuda.so for GPU) - the compiled model code
- model_params.npz - the model parameters (serialized as NumPy arrays)
The remote machine needs both files in the same directory. The script above assumes they are in relax_export_artifacts/ relative to the script location. Adjust the paths as needed for your deployment. For GPU deployment, ensure the target machine has compatible CUDA drivers and the model was compiled for the same GPU architecture.
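A small sketch (assuming the directory layout above) that verifies both artifacts are present before launching the script:

from pathlib import Path

deploy_dir = Path("relax_export_artifacts")
for required in ("mlp_cpu.so", "model_params.npz"):
    artifact = deploy_dir / required
    print(f"{artifact}: {'found' if artifact.exists() else 'MISSING'}")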
Deploying to Remote Devices
To deploy the exported model to a remote ARM Linux device (e.g., Raspberry Pi), you can use TVM's RPC mechanism to cross-compile, upload, and run the model remotely. This workflow is useful when:
- The target device has limited resources for compilation
- You want to fine-tune performance by running on the actual hardware
- You need to deploy to embedded devices
See cross_compilation_and_rpc for a comprehensive guide on:
- Setting up TVM runtime on the remote device
- Starting an RPC server on the device
- Cross-compiling for ARM targets (e.g., llvm -mtriple=aarch64-linux-gnu)
- Uploading exported libraries via RPC
- Running inference remotely
Quick example for ARM deployment workflow:
import tvm
import tvm.rpc as rpc
from tvm import relax
# Step 1: Cross-compile for ARM target (on local machine)
TARGET = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu")
executable = relax.build(built_mod, target=TARGET)
executable.export_library("mlp_arm.so")
# Step 2: Connect to remote device RPC server
remote = rpc.connect("192.168.1.100", 9090) # Device IP and RPC port
# Step 3: Upload the compiled library and parameters
remote.upload("mlp_arm.so")
remote.upload("model_params.npz")
# Step 4: Load and run on remote device
lib = remote.load_module("mlp_arm.so")
vm = relax.VirtualMachine(lib, remote.cpu())
# ... prepare input and params, then run inference
The key difference is using an ARM target triple during compilation and uploading files via RPC instead of copying them directly.
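Note that export_library links the generated objects with the host C/C++ toolchain by default, which may not be able to link aarch64 objects. Below is a sketch of passing a cross compiler explicitly, assuming the aarch64-linux-gnu-g++ toolchain is installed and that the executable's export_library accepts fcompile as tvm.runtime.Module.export_library does:

from tvm.contrib import cc

# Assumption: aarch64-linux-gnu-g++ is available on the local build machine.
executable.export_library(
    "mlp_arm.so",
    fcompile=cc.cross_compiler("aarch64-linux-gnu-g++"),
)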
FAQ
- Can I run the ``.so`` as a standalone executable (like ``./mlp_cpu.so``)?
  No. The .so file is a shared library, not a standalone executable binary. You cannot run it directly from the terminal. It must be loaded through a TVM runtime program (as shown in the "Loading and Running" section above). The .so bundles VM bytecode and compiled kernels, but still requires the TVM runtime to execute.
- Which devices can run the exported library?
  The target must match the ISA you compiled for (llvm in this example). As long as the target triple, runtime ABI, and available devices line up, you can move the artifact between machines. For heterogeneous builds (CPU plus GPU), ship the extra device libraries as well.
- What about the ``.params`` and ``metadata.json`` files?
  These auxiliary files are only generated in specific configurations. In this tutorial, since we pass parameters at runtime, they are not generated. When they do appear, they may be kept alongside the .so for inspection, but the essential content is typically embedded in the shared object itself, so deploying the .so alone is usually sufficient.