Mix Python/PyTorch with TVM Using BasePyModule

In a typical TVM workflow, you write an IRModule, compile it, and load the compiled artifact into a VirtualMachine to run. This means you cannot test or debug anything until the entire module compiles successfully. If a single op is unsupported, the whole pipeline is blocked.

BasePyModule solves this by letting Python functions, TIR kernels, and Relax functions coexist in one module. TIR and Relax functions are JIT-compiled on instantiation, Python functions run as-is, and tensors move between TVM and PyTorch via zero-copy DLPack. This enables:

  • Incremental development: get a model running with Python fallbacks first, then replace them with TVM ops one by one.

  • Easy debugging: insert print in Python functions to inspect intermediate tensors — no need to compile the whole module first.

  • Verification at any compilation stage: convert Relax IR back to PyTorch to check numerical correctness before and after optimization passes.

  • Hybrid execution: let the compiled VM call back into Python for ops that are hard to express in TIR or Relax.

This tutorial walks through the full workflow step by step.

Preparation

import os

try:
    import torch
    import torch.nn.functional as F
except ImportError:
    torch = None

import tvm
from tvm import relax
from tvm.relax.base_py_module import BasePyModule
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tirx as T

IS_IN_CI = os.getenv("CI", "").lower() == "true"
HAS_TORCH = torch is not None
RUN_EXAMPLE = HAS_TORCH and not IS_IN_CI

Step 1: Your First Hybrid Module

The core idea: decorate a class with @I.ir_module, inherit from BasePyModule, and use three decorators for three kinds of functions:

  • @T.prim_func — low-level TIR kernel (JIT-compiled on instantiation)

  • @R.function — high-level Relax graph (JIT-compiled on instantiation)

  • @I.pyfunc — plain Python (runs as-is, can use any Python library)

call_tir bridges Python and TIR: it converts PyTorch tensors to TVM NDArrays via DLPack (zero-copy), allocates the output buffer, calls the compiled kernel, and converts back.

if RUN_EXAMPLE:

    @I.ir_module
    class MyFirstModule(BasePyModule):
        @T.prim_func
        def add_tir(
            A: T.Buffer((4,), "float32"),
            B: T.Buffer((4,), "float32"),
            C: T.Buffer((4,), "float32"),
        ):
            for i in range(4):
                C[i] = A[i] + B[i]

        @I.pyfunc
        def forward(self, x, y):
            """Takes PyTorch tensors, calls TIR, returns PyTorch tensors."""
            x_tvm = self._convert_pytorch_to_tvm(x)
            y_tvm = self._convert_pytorch_to_tvm(y)
            result = self.call_tir(
                self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,), "float32")
            )
            return self._convert_tvm_to_pytorch(result)

    # TIR functions are JIT-compiled at instantiation
    mod = MyFirstModule(device=tvm.cpu(0))

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    y = torch.tensor([10.0, 20.0, 30.0, 40.0])
    result = mod.forward(x, y)

    print("forward(x, y) =", result)
    assert torch.allclose(result, x + y)

    # show() prints TVMScript including Python functions (shown as ExternFunc)
    mod.show()

    # list_functions() shows what is available in the module
    print("Available functions:", mod.list_functions())

Step 2: Debugging — The Main Selling Point

Traditional ML compilers treat computation graphs as monolithic blobs. You cannot inspect intermediate tensor values without compiling the entire module. With @I.pyfunc, debugging is as simple as adding a print statement. You can also make quick edits and re-run immediately — no recompilation needed.

if RUN_EXAMPLE:

    @I.ir_module
    class DebugModule(BasePyModule):
        @T.prim_func
        def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
            n = T.int32()
            A = T.match_buffer(var_A, (n, 4), "float32")
            B = T.match_buffer(var_B, (4, 3), "float32")
            C = T.match_buffer(var_C, (n, 3), "float32")
            for i, j, k in T.grid(n, 3, 4):
                with T.sblock("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

        @I.pyfunc
        def forward(self, x, weights):
            # Inspect input
            print(f"  [DEBUG] input shape: {x.shape}, mean: {x.mean():.4f}")

            # Run TIR matmul
            x_tvm = self._convert_pytorch_to_tvm(x)
            w_tvm = self._convert_pytorch_to_tvm(weights)
            out = self.call_tir(
                self.matmul_tir,
                [x_tvm, w_tvm],
                out_sinfo=R.Tensor((x.shape[0], 3), "float32"),
            )
            logits = self._convert_tvm_to_pytorch(out)

            # Inspect intermediate value — impossible with a compiled-only workflow
            print(f"  [DEBUG] logits shape: {logits.shape}, "
                  f"min: {logits.min():.4f}, max: {logits.max():.4f}")

            result = F.softmax(logits, dim=-1)

            # Verify output
            print(f"  [DEBUG] probs sum: {result.sum(dim=-1)}")
            return result

    mod = DebugModule(device=tvm.cpu(0))

    print("Running with debug prints:")
    probs = mod.forward(torch.randn(2, 4), torch.randn(4, 3))
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2), atol=1e-5)

This is the key benefit: “debugging is as simple as inserting a print statement. Users can also make quick, manual edits to Python functions and immediately observe the results.” No compilation cycle, no VM loading — just Python.

Step 3: A Realistic Pipeline — Python, TIR, and Packed Functions

Real models combine many kinds of operations. This step builds a mini inference pipeline using three different calling conventions:

  • call_tir: call a compiled TIR kernel

  • call_dps_packed: call a TVM packed function (e.g., a third-party library binding)

  • Direct Python: call any PyTorch function

call_dps_packed is useful for calling functions registered via tvm.register_global_func — for example, CUBLAS or cuDNN bindings that TVM wraps as packed functions.

if RUN_EXAMPLE:

    # Register a packed function (simulating an external library binding)
    @tvm.register_global_func("my_bias_add", override=True)
    def my_bias_add(x, bias, out):
        """Packed function: adds bias to each row of x."""
        import numpy as np

        x_np = x.numpy()
        b_np = bias.numpy()
        out_np = x_np + b_np
        out[:] = out_np

    @I.ir_module
    class PipelineModule(BasePyModule):
        @T.prim_func
        def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
            A = T.match_buffer(var_A, (2, 4), "float32")
            B = T.match_buffer(var_B, (4, 3), "float32")
            C = T.match_buffer(var_C, (2, 3), "float32")
            for i, j, k in T.grid(2, 3, 4):
                with T.sblock("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

        @I.pyfunc
        def forward(self, x, weights, bias):
            # 1. TIR matmul
            x_tvm = self._convert_pytorch_to_tvm(x)
            w_tvm = self._convert_pytorch_to_tvm(weights)
            h = self.call_tir(
                self.matmul_tir, [x_tvm, w_tvm],
                out_sinfo=R.Tensor((2, 3), "float32"),
            )
            h_pt = self._convert_tvm_to_pytorch(h)

            # 2. Packed function for bias add (simulating an external library)
            h_biased = self.call_dps_packed(
                "my_bias_add", [h_pt, bias],
                out_sinfo=R.Tensor((2, 3), "float32"),
            )

            # 3. Python/PyTorch activation
            return F.relu(h_biased)

    mod = PipelineModule(device=tvm.cpu(0))

    x = torch.randn(2, 4)
    w = torch.randn(4, 3)
    b = torch.randn(3)
    result = mod.forward(x, w, b)

    expected = F.relu(x @ w + b)
    print("Pipeline result:", result)
    print("Expected:       ", expected)
    assert torch.allclose(result, expected, atol=1e-4)

Step 4: Relax-to-Python Converter — Verify at Any Compilation Stage

Both Relax functions and Python functions describe computational graphs. The RelaxToPyFuncConverter converts Relax IR into equivalent PyTorch code by mapping Relax operators to their PyTorch counterparts (e.g., R.nn.reluF.relu).

A key feature: this conversion can happen at any stage of compilation. You can convert early (right after import) or late (after optimization passes have transformed the IR), and compare the output against a PyTorch reference to catch bugs.

if RUN_EXAMPLE:
    from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter

    # A simple Relax module: matmul + bias + relu (a dense layer)
    @I.ir_module
    class DenseLayer:
        @T.prim_func
        def bias_add_tir(var_x: T.handle, var_b: T.handle, var_out: T.handle):
            x = T.match_buffer(var_x, (2, 4), "float32")
            b = T.match_buffer(var_b, (4,), "float32")
            out = T.match_buffer(var_out, (2, 4), "float32")
            for i, j in T.grid(2, 4):
                out[i, j] = x[i, j] + b[j]

        @R.function
        def main(
            x: R.Tensor((2, 4), "float32"),
            w: R.Tensor((4, 4), "float32"),
            b: R.Tensor((4,), "float32"),
        ) -> R.Tensor((2, 4), "float32"):
            h = R.matmul(x, w)
            cls = DenseLayer
            h_bias = R.call_tir(
                cls.bias_add_tir, (h, b),
                out_sinfo=R.Tensor((2, 4), "float32"),
            )
            return R.nn.relu(h_bias)

    # --- Stage 1: Convert BEFORE optimization ---
    converter = RelaxToPyFuncConverter(DenseLayer)
    converted_early = converter.convert(["main"])

    x = torch.randn(2, 4)
    w = torch.randn(4, 4)
    b = torch.randn(4)

    py_result_early = converted_early.pyfuncs["main"](x, w, b)
    expected = F.relu(x @ w + b)

    print("Before optimization:")
    print("  Converted result:", py_result_early)
    print("  PyTorch expected:", expected)
    assert torch.allclose(py_result_early, expected, atol=1e-5)

    # --- Stage 2: Apply a pass, then convert AFTER optimization ---
    # Run CanonicalizeBindings to clean up the IR, then convert again
    # to verify the pass did not break numerical correctness.
    optimized_mod = relax.transform.CanonicalizeBindings()(DenseLayer)

    converter_late = RelaxToPyFuncConverter(optimized_mod)
    converted_late = converter_late.convert(["main"])

    py_result_late = converted_late.pyfuncs["main"](x, w, b)

    print("\nAfter CanonicalizeBindings pass:")
    print("  Converted result:", py_result_late)
    print("  Still matches:   ",
          torch.allclose(py_result_late, expected, atol=1e-5))
    assert torch.allclose(py_result_late, expected, atol=1e-5)

Step 5: R.call_py_func — Python Callbacks in Compiled IR

R.call_py_func embeds a Python function call directly inside Relax IR. When the module is compiled and run in the VM, everything else is optimized native code, but the VM calls back into Python for the specified ops.

BasePyModule supports cross-level calls in both directions: Relax functions can invoke Python functions, and Python functions can invoke TIR/Relax functions. Data flows between them via DLPack with minimal overhead.

Use case: your model has a custom op (e.g., a special normalization or a sampling step) that is complex to implement in TIR. Compile everything else, and let that one op stay in Python.

if RUN_EXAMPLE:

    @I.ir_module
    class HybridVMModule(BasePyModule):
        @I.pyfunc
        def silu(self, x):
            """SiLU/Swish activation — using Python as fallback."""
            return torch.sigmoid(x) * x

        @I.pyfunc
        def layer_norm(self, x):
            """LayerNorm — another Python fallback."""
            return F.layer_norm(x, x.shape[-1:])

        @R.function
        def main(
            x: R.Tensor((4, 8), "float32"),
        ) -> R.Tensor((4, 8), "float32"):
            # The VM calls back into Python for these two ops
            h = R.call_py_func(
                "layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32")
            )
            out = R.call_py_func(
                "silu", (h,), out_sinfo=R.Tensor((4, 8), "float32")
            )
            return out

    mod = HybridVMModule(device=tvm.cpu(0))
    x = torch.randn(4, 8)

    # call_py_func is also callable from Python directly
    result = mod.call_py_func("layer_norm", [x])
    result = mod.call_py_func("silu", [result])

    ln = F.layer_norm(x, x.shape[-1:])
    expected = torch.sigmoid(ln) * ln
    print("call_py_func result:", result)
    assert torch.allclose(torch.tensor(result.numpy()), expected, atol=1e-5)

Step 6: Cross-Level Calls and Symbolic Shapes

BasePyModule is designed for cross-level interoperability: Python functions can call TIR and Relax functions, and Relax functions can call Python functions. We have already seen:

  • Python → TIR via call_tir (Steps 1–3)

  • Python → packed function via call_dps_packed (Step 3)

  • Relax → Python via R.call_py_func (Step 5)

The missing piece: Python calling a compiled Relax function directly. When a module contains @R.function, it is JIT-compiled into a Relax VM. You can call it from Python just like any other method — the module auto-converts PyTorch tensors to TVM and back.

This step also shows symbolic shapes: TIR and Relax functions can declare dynamic dimensions (e.g., "n"). BasePyModule infers concrete shapes from the actual input tensors at call time, so the same module handles different sizes without recompilation.

if RUN_EXAMPLE:

    @I.ir_module
    class DynamicModule(BasePyModule):
        @T.prim_func
        def scale_tir(var_x: T.handle, var_out: T.handle):
            n = T.int64()
            x = T.match_buffer(var_x, (n,), "float32")
            out = T.match_buffer(var_out, (n,), "float32")
            for i in T.serial(n):
                out[i] = x[i] * T.float32(2.0)

        @R.function
        def add_relax(
            x: R.Tensor(("n",), "float32"),
            y: R.Tensor(("n",), "float32"),
        ) -> R.Tensor(("n",), "float32"):
            return R.add(x, y)

    mod = DynamicModule(device=tvm.cpu(0), target="llvm")

    # Inspect what the module contains
    print("Functions:", mod.list_functions())

    # Python → Relax: call the compiled Relax function directly with PyTorch tensors
    a5 = torch.randn(5)
    b5 = torch.randn(5)
    out5 = mod.add_relax(a5, b5)
    print("add_relax(len=5):", out5)

    # Same module, different size — symbolic shapes handle this automatically
    a10 = torch.randn(10)
    b10 = torch.randn(10)
    out10 = mod.add_relax(a10, b10)
    print("add_relax(len=10):", out10)

    # Python → TIR with symbolic output shape
    n = T.int64()
    x7 = torch.randn(7)
    scaled = mod.call_tir(
        "scale_tir", [x7], relax.TensorStructInfo((n,), "float32")
    )
    print("scale_tir(len=7):", scaled)
    assert torch.allclose(torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5)

Summary

Cross-level call summary:

  • Python → TIR: call_tir() (Steps 1, 2, 3, 6)

  • Python → packed function: call_dps_packed() (Step 3)

  • Python → Relax: call @R.function as a method (Step 6)

  • Relax → Python: R.call_py_func() in compiled VM (Step 5)

The workflow in practice:

  1. Import a model → some ops unsupported → use @I.pyfunc as Python fallbacks

  2. Get it running end-to-end with BasePyModule

  3. Debug by inserting print in pyfuncs — inspect intermediate tensors instantly

  4. Use RelaxToPyFuncConverter to verify correctness after each optimization pass

  5. Gradually replace Python fallbacks with TIR/Relax implementations

  6. Use R.call_py_func for ops that must stay in Python even after compilation

Gallery generated by Sphinx-Gallery