Note
You can click here to run the Jupyter notebook locally.
Mix Python/PyTorch with TVM Using BasePyModule
In a typical TVM workflow, you write an IRModule, compile it, and load the compiled artifact
into a VirtualMachine to run. This means you cannot test or debug anything until the entire
module compiles successfully. If a single op is unsupported, the whole pipeline is blocked.
BasePyModule solves this by letting Python functions, TIR kernels, and Relax functions coexist
in one module. TIR and Relax functions are JIT-compiled on instantiation, Python functions run
as-is, and tensors move between TVM and PyTorch via zero-copy DLPack. This enables:
Incremental development: get a model running with Python fallbacks first, then replace them with TVM ops one by one.
Easy debugging: insert
printin Python functions to inspect intermediate tensors — no need to compile the whole module first.Verification at any compilation stage: convert Relax IR back to PyTorch to check numerical correctness before and after optimization passes.
Hybrid execution: let the compiled VM call back into Python for ops that are hard to express in TIR or Relax.
This tutorial walks through the full workflow step by step.
Preparation
import os
try:
import torch
import torch.nn.functional as F
except ImportError:
torch = None
import tvm
from tvm import relax
from tvm.relax.base_py_module import BasePyModule
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tirx as T
IS_IN_CI = os.getenv("CI", "").lower() == "true"
HAS_TORCH = torch is not None
RUN_EXAMPLE = HAS_TORCH and not IS_IN_CI
Step 1: Your First Hybrid Module
The core idea: decorate a class with @I.ir_module, inherit from BasePyModule, and use
three decorators for three kinds of functions:
@T.prim_func— low-level TIR kernel (JIT-compiled on instantiation)@R.function— high-level Relax graph (JIT-compiled on instantiation)@I.pyfunc— plain Python (runs as-is, can use any Python library)
call_tir bridges Python and TIR: it converts PyTorch tensors to TVM NDArrays via DLPack
(zero-copy), allocates the output buffer, calls the compiled kernel, and converts back.
if RUN_EXAMPLE:
@I.ir_module
class MyFirstModule(BasePyModule):
@T.prim_func
def add_tir(
A: T.Buffer((4,), "float32"),
B: T.Buffer((4,), "float32"),
C: T.Buffer((4,), "float32"),
):
for i in range(4):
C[i] = A[i] + B[i]
@I.pyfunc
def forward(self, x, y):
"""Takes PyTorch tensors, calls TIR, returns PyTorch tensors."""
x_tvm = self._convert_pytorch_to_tvm(x)
y_tvm = self._convert_pytorch_to_tvm(y)
result = self.call_tir(
self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,), "float32")
)
return self._convert_tvm_to_pytorch(result)
# TIR functions are JIT-compiled at instantiation
mod = MyFirstModule(device=tvm.cpu(0))
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([10.0, 20.0, 30.0, 40.0])
result = mod.forward(x, y)
print("forward(x, y) =", result)
assert torch.allclose(result, x + y)
# show() prints TVMScript including Python functions (shown as ExternFunc)
mod.show()
# list_functions() shows what is available in the module
print("Available functions:", mod.list_functions())
Step 2: Debugging — The Main Selling Point
Traditional ML compilers treat computation graphs as monolithic blobs. You cannot inspect
intermediate tensor values without compiling the entire module. With @I.pyfunc, debugging
is as simple as adding a print statement. You can also make quick edits and re-run
immediately — no recompilation needed.
if RUN_EXAMPLE:
@I.ir_module
class DebugModule(BasePyModule):
@T.prim_func
def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
n = T.int32()
A = T.match_buffer(var_A, (n, 4), "float32")
B = T.match_buffer(var_B, (4, 3), "float32")
C = T.match_buffer(var_C, (n, 3), "float32")
for i, j, k in T.grid(n, 3, 4):
with T.sblock("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
@I.pyfunc
def forward(self, x, weights):
# Inspect input
print(f" [DEBUG] input shape: {x.shape}, mean: {x.mean():.4f}")
# Run TIR matmul
x_tvm = self._convert_pytorch_to_tvm(x)
w_tvm = self._convert_pytorch_to_tvm(weights)
out = self.call_tir(
self.matmul_tir,
[x_tvm, w_tvm],
out_sinfo=R.Tensor((x.shape[0], 3), "float32"),
)
logits = self._convert_tvm_to_pytorch(out)
# Inspect intermediate value — impossible with a compiled-only workflow
print(f" [DEBUG] logits shape: {logits.shape}, "
f"min: {logits.min():.4f}, max: {logits.max():.4f}")
result = F.softmax(logits, dim=-1)
# Verify output
print(f" [DEBUG] probs sum: {result.sum(dim=-1)}")
return result
mod = DebugModule(device=tvm.cpu(0))
print("Running with debug prints:")
probs = mod.forward(torch.randn(2, 4), torch.randn(4, 3))
assert torch.allclose(probs.sum(dim=-1), torch.ones(2), atol=1e-5)
This is the key benefit: “debugging is as simple as inserting a print statement. Users can also make quick, manual edits to Python functions and immediately observe the results.” No compilation cycle, no VM loading — just Python.
Step 3: A Realistic Pipeline — Python, TIR, and Packed Functions
Real models combine many kinds of operations. This step builds a mini inference pipeline using three different calling conventions:
call_tir: call a compiled TIR kernelcall_dps_packed: call a TVM packed function (e.g., a third-party library binding)Direct Python: call any PyTorch function
call_dps_packed is useful for calling functions registered via tvm.register_global_func
— for example, CUBLAS or cuDNN bindings that TVM wraps as packed functions.
if RUN_EXAMPLE:
# Register a packed function (simulating an external library binding)
@tvm.register_global_func("my_bias_add", override=True)
def my_bias_add(x, bias, out):
"""Packed function: adds bias to each row of x."""
import numpy as np
x_np = x.numpy()
b_np = bias.numpy()
out_np = x_np + b_np
out[:] = out_np
@I.ir_module
class PipelineModule(BasePyModule):
@T.prim_func
def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
A = T.match_buffer(var_A, (2, 4), "float32")
B = T.match_buffer(var_B, (4, 3), "float32")
C = T.match_buffer(var_C, (2, 3), "float32")
for i, j, k in T.grid(2, 3, 4):
with T.sblock("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
@I.pyfunc
def forward(self, x, weights, bias):
# 1. TIR matmul
x_tvm = self._convert_pytorch_to_tvm(x)
w_tvm = self._convert_pytorch_to_tvm(weights)
h = self.call_tir(
self.matmul_tir, [x_tvm, w_tvm],
out_sinfo=R.Tensor((2, 3), "float32"),
)
h_pt = self._convert_tvm_to_pytorch(h)
# 2. Packed function for bias add (simulating an external library)
h_biased = self.call_dps_packed(
"my_bias_add", [h_pt, bias],
out_sinfo=R.Tensor((2, 3), "float32"),
)
# 3. Python/PyTorch activation
return F.relu(h_biased)
mod = PipelineModule(device=tvm.cpu(0))
x = torch.randn(2, 4)
w = torch.randn(4, 3)
b = torch.randn(3)
result = mod.forward(x, w, b)
expected = F.relu(x @ w + b)
print("Pipeline result:", result)
print("Expected: ", expected)
assert torch.allclose(result, expected, atol=1e-4)
Step 4: Relax-to-Python Converter — Verify at Any Compilation Stage
Both Relax functions and Python functions describe computational graphs. The
RelaxToPyFuncConverter converts Relax IR into equivalent PyTorch code by mapping
Relax operators to their PyTorch counterparts (e.g., R.nn.relu → F.relu).
A key feature: this conversion can happen at any stage of compilation. You can convert early (right after import) or late (after optimization passes have transformed the IR), and compare the output against a PyTorch reference to catch bugs.
if RUN_EXAMPLE:
from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter
# A simple Relax module: matmul + bias + relu (a dense layer)
@I.ir_module
class DenseLayer:
@T.prim_func
def bias_add_tir(var_x: T.handle, var_b: T.handle, var_out: T.handle):
x = T.match_buffer(var_x, (2, 4), "float32")
b = T.match_buffer(var_b, (4,), "float32")
out = T.match_buffer(var_out, (2, 4), "float32")
for i, j in T.grid(2, 4):
out[i, j] = x[i, j] + b[j]
@R.function
def main(
x: R.Tensor((2, 4), "float32"),
w: R.Tensor((4, 4), "float32"),
b: R.Tensor((4,), "float32"),
) -> R.Tensor((2, 4), "float32"):
h = R.matmul(x, w)
cls = DenseLayer
h_bias = R.call_tir(
cls.bias_add_tir, (h, b),
out_sinfo=R.Tensor((2, 4), "float32"),
)
return R.nn.relu(h_bias)
# --- Stage 1: Convert BEFORE optimization ---
converter = RelaxToPyFuncConverter(DenseLayer)
converted_early = converter.convert(["main"])
x = torch.randn(2, 4)
w = torch.randn(4, 4)
b = torch.randn(4)
py_result_early = converted_early.pyfuncs["main"](x, w, b)
expected = F.relu(x @ w + b)
print("Before optimization:")
print(" Converted result:", py_result_early)
print(" PyTorch expected:", expected)
assert torch.allclose(py_result_early, expected, atol=1e-5)
# --- Stage 2: Apply a pass, then convert AFTER optimization ---
# Run CanonicalizeBindings to clean up the IR, then convert again
# to verify the pass did not break numerical correctness.
optimized_mod = relax.transform.CanonicalizeBindings()(DenseLayer)
converter_late = RelaxToPyFuncConverter(optimized_mod)
converted_late = converter_late.convert(["main"])
py_result_late = converted_late.pyfuncs["main"](x, w, b)
print("\nAfter CanonicalizeBindings pass:")
print(" Converted result:", py_result_late)
print(" Still matches: ",
torch.allclose(py_result_late, expected, atol=1e-5))
assert torch.allclose(py_result_late, expected, atol=1e-5)
Step 5: R.call_py_func — Python Callbacks in Compiled IR
R.call_py_func embeds a Python function call directly inside Relax IR. When the module
is compiled and run in the VM, everything else is optimized native code, but the VM calls
back into Python for the specified ops.
BasePyModule supports cross-level calls in both directions: Relax functions can invoke
Python functions, and Python functions can invoke TIR/Relax functions. Data flows between
them via DLPack with minimal overhead.
Use case: your model has a custom op (e.g., a special normalization or a sampling step) that is complex to implement in TIR. Compile everything else, and let that one op stay in Python.
if RUN_EXAMPLE:
@I.ir_module
class HybridVMModule(BasePyModule):
@I.pyfunc
def silu(self, x):
"""SiLU/Swish activation — using Python as fallback."""
return torch.sigmoid(x) * x
@I.pyfunc
def layer_norm(self, x):
"""LayerNorm — another Python fallback."""
return F.layer_norm(x, x.shape[-1:])
@R.function
def main(
x: R.Tensor((4, 8), "float32"),
) -> R.Tensor((4, 8), "float32"):
# The VM calls back into Python for these two ops
h = R.call_py_func(
"layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32")
)
out = R.call_py_func(
"silu", (h,), out_sinfo=R.Tensor((4, 8), "float32")
)
return out
mod = HybridVMModule(device=tvm.cpu(0))
x = torch.randn(4, 8)
# call_py_func is also callable from Python directly
result = mod.call_py_func("layer_norm", [x])
result = mod.call_py_func("silu", [result])
ln = F.layer_norm(x, x.shape[-1:])
expected = torch.sigmoid(ln) * ln
print("call_py_func result:", result)
assert torch.allclose(torch.tensor(result.numpy()), expected, atol=1e-5)
Step 6: Cross-Level Calls and Symbolic Shapes
BasePyModule is designed for cross-level interoperability: Python functions can call
TIR and Relax functions, and Relax functions can call Python functions. We have already seen:
Python → TIR via
call_tir(Steps 1–3)Python → packed function via
call_dps_packed(Step 3)Relax → Python via
R.call_py_func(Step 5)
The missing piece: Python calling a compiled Relax function directly. When a module
contains @R.function, it is JIT-compiled into a Relax VM. You can call it from Python
just like any other method — the module auto-converts PyTorch tensors to TVM and back.
This step also shows symbolic shapes: TIR and Relax functions can declare dynamic
dimensions (e.g., "n"). BasePyModule infers concrete shapes from the actual input
tensors at call time, so the same module handles different sizes without recompilation.
if RUN_EXAMPLE:
@I.ir_module
class DynamicModule(BasePyModule):
@T.prim_func
def scale_tir(var_x: T.handle, var_out: T.handle):
n = T.int64()
x = T.match_buffer(var_x, (n,), "float32")
out = T.match_buffer(var_out, (n,), "float32")
for i in T.serial(n):
out[i] = x[i] * T.float32(2.0)
@R.function
def add_relax(
x: R.Tensor(("n",), "float32"),
y: R.Tensor(("n",), "float32"),
) -> R.Tensor(("n",), "float32"):
return R.add(x, y)
mod = DynamicModule(device=tvm.cpu(0), target="llvm")
# Inspect what the module contains
print("Functions:", mod.list_functions())
# Python → Relax: call the compiled Relax function directly with PyTorch tensors
a5 = torch.randn(5)
b5 = torch.randn(5)
out5 = mod.add_relax(a5, b5)
print("add_relax(len=5):", out5)
# Same module, different size — symbolic shapes handle this automatically
a10 = torch.randn(10)
b10 = torch.randn(10)
out10 = mod.add_relax(a10, b10)
print("add_relax(len=10):", out10)
# Python → TIR with symbolic output shape
n = T.int64()
x7 = torch.randn(7)
scaled = mod.call_tir(
"scale_tir", [x7], relax.TensorStructInfo((n,), "float32")
)
print("scale_tir(len=7):", scaled)
assert torch.allclose(torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5)
Summary
Cross-level call summary:
Python → TIR:
call_tir()(Steps 1, 2, 3, 6)Python → packed function:
call_dps_packed()(Step 3)Python → Relax: call
@R.functionas a method (Step 6)Relax → Python:
R.call_py_func()in compiled VM (Step 5)
The workflow in practice:
Import a model → some ops unsupported → use
@I.pyfuncas Python fallbacksGet it running end-to-end with
BasePyModuleDebug by inserting
printin pyfuncs — inspect intermediate tensors instantlyUse
RelaxToPyFuncConverterto verify correctness after each optimization passGradually replace Python fallbacks with TIR/Relax implementations
Use
R.call_py_funcfor ops that must stay in Python even after compilation