Importing Models from ML Frameworks
Apache TVM supports importing models from popular ML frameworks including PyTorch, ONNX, and TensorFlow Lite. This tutorial walks through each import path with a minimal working example and explains the key parameters. The PyTorch section additionally demonstrates how to handle unsupported operators via a custom converter map.
For end-to-end optimization and deployment after importing, see End-to-End Optimize Model.
Note
The ONNX section requires the onnx package. The TFLite section requires
tensorflow and tflite. Sections whose dependencies are missing are skipped
automatically.
Importing from PyTorch (Recommended)
TVM’s PyTorch frontend is the most feature-complete. The recommended entry point is
from_exported_program(), which works with PyTorch’s
torch.export API.
We start by defining a small CNN model for demonstration. No pretrained weights are needed — we only care about the graph structure.
import numpy as np
import torch
from torch import nn
from torch.export import export
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = torch.relu(self.bn(self.conv(x)))
        x = self.pool(x).flatten(1)
        x = self.fc(x)
        return x

torch_model = SimpleCNN().eval()
example_args = (torch.randn(1, 3, 32, 32),)
Basic import
The standard workflow is: torch.export.export() → from_exported_program() →
detach_params().
with torch.no_grad():
    exported_program = export(torch_model, example_args)
    mod = from_exported_program(
        exported_program,
        keep_params_as_input=True,
        unwrap_unit_return_tuple=True,
    )

mod, params = relax.frontend.detach_params(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 3, 32, 32), dtype="float32"), p_conv_weight: R.Tensor((16, 3, 3, 3), dtype="float32"), p_conv_bias: R.Tensor((16,), dtype="float32"), p_bn_weight: R.Tensor((16,), dtype="float32"), p_bn_bias: R.Tensor((16,), dtype="float32"), p_fc_weight: R.Tensor((10, 16), dtype="float32"), p_fc_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(x, p_conv_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(p_conv_bias, R.shape([1, 16, 1, 1]))
            lv2: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv, lv1)
            lv3: R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32"), R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")) = R.nn.batch_norm(lv2, p_bn_weight, p_bn_bias, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv4: R.Tensor((1, 16, 32, 32), dtype="float32") = lv3[0]
            lv5: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv4)
            lv6: R.Tensor((1, 16, 1, 1), dtype="float32") = R.mean(lv5, axis=[-1, -2], keepdims=True)
            lv7: R.Tensor((1, 16), dtype="float32") = R.reshape(lv6, R.shape([1, 16]))
            lv8: R.Tensor((16, 10), dtype="float32") = R.permute_dims(p_fc_weight, axes=[1, 0])
            lv9: R.Tensor((1, 10), dtype="float32") = R.matmul(lv7, lv8, out_dtype="float32")
            lv10: R.Tensor((1, 10), dtype="float32") = R.add(p_fc_bias, lv9)
            gv: R.Tensor((1, 10), dtype="float32") = lv10
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
Key parameters
from_exported_program accepts several parameters that control how the model is
translated:
- keep_params_as_input (bool, default False): When True, model weights become function parameters, separated via relax.frontend.detach_params(). When False, weights are embedded as constants inside the IRModule. Use True when you want to manage weights independently (e.g., for weight sharing or quantization).
- unwrap_unit_return_tuple (bool, default False): PyTorch export always wraps the return value in a tuple. Set True to unwrap single-element return tuples for a cleaner Relax function signature.
- run_ep_decomposition (bool, default True): Runs PyTorch's built-in operator decomposition before translation. This breaks high-level ops (e.g., batch_norm) into lower-level primitives, which generally improves TVM's coverage and optimization opportunities. Set False if you want to preserve the original op granularity.
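To build intuition for what decomposition buys, here is a hand-worked sketch in plain Python (not TVM code) of how an inference-mode batch_norm collapses into a per-channel scale and shift, the kind of low-level primitive a decomposed graph exposes to later fusion passes. The numeric values are arbitrary illustrations:

```python
import math

# Inference-mode batch_norm on one channel:
#   y = gamma * (x - mean) / sqrt(var + eps) + beta
gamma, beta = 1.5, 0.2
mean, var, eps = 0.3, 4.0, 1e-5

def batch_norm(x):
    return gamma * (x - mean) / math.sqrt(var + eps) + beta

# Decomposed form: fold the frozen statistics into a single
# multiply-add, which downstream passes can fuse with neighbors.
scale = gamma / math.sqrt(var + eps)
shift = beta - mean * scale

def scale_shift(x):
    return x * scale + shift

for x in (-1.0, 0.0, 2.5):
    assert math.isclose(batch_norm(x), scale_shift(x), rel_tol=1e-12)
print("decomposed scale/shift matches batch_norm")
```

This is the same algebra that lets a compiler fold an inference batch_norm into the preceding convolution's weights.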
Handling unsupported operators with custom_convert_map
When TVM encounters a PyTorch operator it does not recognize, it raises an error indicating the unsupported operator name. You can extend the frontend by providing a custom converter map — a dictionary mapping operator names to your own conversion functions.
A custom converter function receives two arguments:
- node (torch.fx.Node): The FX graph node being converted, carrying operator info and references to input nodes.
- importer (ExportedProgramImporter): The importer instance, giving access to:
  - importer.env: Dict mapping FX nodes to their converted Relax expressions.
  - importer.block_builder: The Relax BlockBuilder for emitting operations.
  - importer.retrieve_args(node): Helper to look up converted args.
The function must return a relax.Var — the Relax expression for this node’s output.
Here is an example that maps an operator to relax.op.sigmoid:
from tvm.relax.frontend.torch.exported_program_translator import ExportedProgramImporter
def convert_sigmoid(node: torch.fx.Node, importer: ExportedProgramImporter) -> relax.Var:
    """Custom converter: map an op to relax.op.sigmoid."""
    args = importer.retrieve_args(node)
    return importer.block_builder.emit(relax.op.sigmoid(args[0]))
To use the custom converter, pass it via the custom_convert_map parameter. The key
is the ATen operator name in "op_name.variant" format (e.g., "sigmoid.default"):
mod = from_exported_program(
    exported_program,
    custom_convert_map={"sigmoid.default": convert_sigmoid},
)
Note
To find the correct operator name, check the error message TVM raises when encountering
the unsupported op — it includes the exact ATen name. You can also inspect the exported
program’s graph via print(exported_program.graph_module.graph) to see all operator
names.
Alternative PyTorch import methods
Besides from_exported_program, TVM also provides:
- from_fx(): Works with torch.fx.GraphModule from torch.fx.symbolic_trace(). Requires explicit input_info (shapes and dtypes). Use this when torch.export fails on certain Python control flow patterns.
- relax_dynamo(): A torch.compile backend that compiles and executes the model through TVM in one step. Useful for integrating TVM into an existing PyTorch training or inference loop.
- dynamo_capture_subgraphs(): Captures subgraphs from a PyTorch model into an IRModule via torch.compile. Each subgraph becomes a separate function in the IRModule.
For most use cases, from_exported_program is the recommended path.
Verifying the imported model
After importing, it is good practice to verify that TVM produces the same output as the
original framework. We compile with the minimal "zero" pipeline (no tuning) and
compare. The same approach applies to models imported via the ONNX and TFLite frontends
shown below.
mod_compiled = relax.get_pipeline("zero")(mod)
exec_module = tvm.compile(mod_compiled, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(exec_module, dev)
# Run inference
input_data = np.random.rand(1, 3, 32, 32).astype("float32")
tvm_input = tvm.runtime.tensor(input_data, dev)
tvm_params = [tvm.runtime.tensor(p, dev) for p in params["main"]]
tvm_out = vm["main"](tvm_input, *tvm_params).numpy()
# Compare with PyTorch
with torch.no_grad():
    pt_out = torch_model(torch.from_numpy(input_data)).numpy()
np.testing.assert_allclose(tvm_out, pt_out, rtol=1e-5, atol=1e-5)
print("PyTorch vs TVM outputs match!")
PyTorch vs TVM outputs match!
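A note on the tolerances used above: np.testing.assert_allclose treats two elements as equal when |actual - desired| <= atol + rtol * |desired|. The pure-Python check below mirrors that documented rule, so you can reason about how tight a comparison with rtol=1e-5, atol=1e-5 actually is:

```python
def allclose_elem(actual, desired, rtol=1e-5, atol=1e-5):
    # Same elementwise rule numpy documents for assert_allclose.
    return abs(actual - desired) <= atol + rtol * abs(desired)

# A discrepancy of 1e-6 is well inside atol + rtol*|desired| = 2e-5...
assert allclose_elem(1.000001, 1.0)
# ...while 1e-3 is far outside it.
assert not allclose_elem(1.001, 1.0)
print("tolerance rule demonstrated")
```

If a float32 model fails at these tolerances, loosening rtol slightly (e.g., 1e-4) is often reasonable before suspecting a real conversion bug, since op reordering can change rounding.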
Importing from ONNX
TVM can import ONNX models via from_onnx(). The
function accepts an onnx.ModelProto object, so you need to load the model with
onnx.load() first.
Here we export the same CNN model to ONNX format and then import it into TVM.
try:
    import onnx
    import onnxscript  # noqa: F401  # required by torch.onnx.export

    HAS_ONNX = True
except ImportError:
    onnx = None  # type: ignore[assignment]
    HAS_ONNX = False

if HAS_ONNX:
    from tvm.relax.frontend.onnx import from_onnx

    # Export the PyTorch model to ONNX
    dummy_input = torch.randn(1, 3, 32, 32)
    onnx_path = "simple_cnn.onnx"
    torch.onnx.export(torch_model, dummy_input, onnx_path, input_names=["input"])

    # Load and import into TVM
    onnx_model = onnx.load(onnx_path)
    mod_onnx = from_onnx(onnx_model, keep_params_in_input=True)
    mod_onnx, params_onnx = relax.frontend.detach_params(mod_onnx)
    mod_onnx.show()
If you already have an .onnx file on disk, the workflow is even simpler: load it with onnx.load("model.onnx") and pass the resulting ModelProto straight to from_onnx(), exactly as above.
Key parameters
- shape_dict (dict, optional): Maps input names to shapes. Auto-inferred from the model if not provided. Useful when the ONNX model has dynamic dimensions that you want to fix to concrete sizes:

  mod = from_onnx(onnx_model, shape_dict={"input": [1, 3, 224, 224]})

- dtype_dict (str or dict, default "float32"): Input dtypes. A single string applies to all inputs, or use a dict to set per-input dtypes:

  mod = from_onnx(onnx_model, dtype_dict={"input": "float16"})

- keep_params_in_input (bool, default False): Same semantics as the PyTorch frontend's keep_params_as_input: whether model weights are function parameters or embedded constants.
- opset (int, optional): Override the opset version auto-detected from the model. Each ONNX op may have different semantics across opset versions; TVM's converter selects the appropriate implementation automatically. You rarely need to set this unless the model metadata is incorrect.
Importing from TensorFlow Lite
TVM can import TFLite FlatBuffer models via
from_tflite(). The function expects a TFLite
Model object parsed from the FlatBuffer bytes via GetRootAsModel.
Note
The tflite Python package has changed its module layout across versions.
Older versions use tflite.Model.Model.GetRootAsModel, while newer versions use
tflite.Model.GetRootAsModel. The code below handles both.
Below we create a minimal TFLite model from TensorFlow and import it.
try:
    import tensorflow as tf
    import tflite
    import tflite.Model

    HAS_TFLITE = True
except ImportError:
    HAS_TFLITE = False

if HAS_TFLITE:
    from tvm.relax.frontend.tflite import from_tflite

    # Define a simple TF module and convert to TFLite.
    # We use plain TF ops (not keras layers) to avoid variable-handling ops
    # that some TFLite converter versions do not support cleanly.
    class TFModule(tf.Module):
        @tf.function(
            input_signature=[
                tf.TensorSpec(shape=(1, 784), dtype=tf.float32),
                tf.TensorSpec(shape=(784, 10), dtype=tf.float32),
            ]
        )
        def forward(self, x, weight):
            return tf.matmul(x, weight) + 0.1

    tf_module = TFModule()
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_module.forward.get_concrete_function()], tf_module
    )
    tflite_buf = converter.convert()

    # Parse and import into TVM (API differs between tflite package versions)
    if hasattr(tflite.Model, "Model"):
        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
    else:
        tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)

    mod_tflite = from_tflite(tflite_model)
    mod_tflite.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(serving_default_weight_0: R.Tensor((784, 10), dtype="float32"), serving_default_x_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 2, "params": [metadata["ffi.Tensor"][0]]})
        with R.dataflow():
            lv: R.Tensor((10, 784), dtype="float32") = R.permute_dims(serving_default_weight_0, axes=[1, 0])
            lv1: R.Tensor((784, 10), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
            lv2: R.Tensor((1, 10), dtype="float32") = R.matmul(serving_default_x_0, lv1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = R.add(lv2, metadata["relax.expr.Constant"][0])
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
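Notice the two back-to-back permute_dims on the weight (lv, then lv1) in the printed IRModule: they cancel out, leaving the matmul with the original weight layout. This is an artifact of how TFLite's FULLY_CONNECTED op stores weights, and such a pair can typically be eliminated by simplification passes later in the pipeline. The pure-Python sketch below shows why the pair is a no-op:

```python
def transpose(m):
    # Transpose a matrix held as a list of rows.
    return [list(col) for col in zip(*m)]

w = [[1, 2, 3],
     [4, 5, 6]]

# Applying permute_dims with axes=[1, 0] twice restores the original
# layout, so the matmul in the printed IR sees the weight unchanged.
assert transpose(transpose(w)) == w
print("double transpose is the identity")
```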
Loading from a .tflite file
If you already have a .tflite file on disk, load the raw bytes and parse them:
import tflite
import tflite.Model
from tvm.relax.frontend.tflite import from_tflite

with open("my_model.tflite", "rb") as f:
    tflite_buf = f.read()

if hasattr(tflite.Model, "Model"):
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
else:
    tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)

mod = from_tflite(tflite_model)
Key parameters
- shape_dict / dtype_dict (optional): Override input shapes and dtypes. If not provided, they are inferred from the TFLite model metadata.
- op_converter (class, optional): A custom operator converter class. Subclass OperatorConverter and override its convert_map dictionary to add or replace operator conversions. For example, to add a hypothetical CUSTOM_RELU op:

  from tvm.relax.frontend.tflite.tflite_frontend import OperatorConverter

  class MyConverter(OperatorConverter):
      def __init__(self, model, subgraph, exp_tab, ctx):
          super().__init__(model, subgraph, exp_tab, ctx)
          self.convert_map["CUSTOM_RELU"] = self._convert_custom_relu

      def _convert_custom_relu(self, op):
          # implement your conversion logic here
          ...

  mod = from_tflite(tflite_model, op_converter=MyConverter)
Summary
| Aspect | PyTorch | ONNX | TFLite |
|---|---|---|---|
| Entry function | from_exported_program() | from_onnx() | from_tflite() |
| Input | torch.export ExportedProgram | onnx.ModelProto | TFLite Model object |
| Custom extension | custom_convert_map | — | op_converter |
Which to use? Pick the frontend that matches your model format:

- Have a PyTorch model? Use from_exported_program; it has the broadest operator coverage.
- Have an .onnx file? Use from_onnx.
- Have a .tflite file? Use from_tflite.
The verification workflow (compile → run → compare) demonstrated in the PyTorch section above applies equally to ONNX and TFLite imports.
For the full list of supported operators, see the converter map in each frontend’s source:
PyTorch uses create_convert_map() in exported_program_translator.py, ONNX uses
_get_convert_map() in onnx_frontend.py, and TFLite uses convert_map in
OperatorConverter in tflite_frontend.py.
After importing, refer to End-to-End Optimize Model for optimization and deployment.