Importing Models from ML Frameworks

Apache TVM supports importing models from popular ML frameworks including PyTorch, ONNX, and TensorFlow Lite. This tutorial walks through each import path with a minimal working example and explains the key parameters. The PyTorch section additionally demonstrates how to handle unsupported operators via a custom converter map.

For end-to-end optimization and deployment after importing, see End-to-End Optimize Model.

Note

The ONNX section requires the onnx package. The TFLite section requires tensorflow and tflite. Sections whose dependencies are missing are skipped automatically.

Importing from ONNX

TVM can import ONNX models via from_onnx(). The function accepts an onnx.ModelProto object, so you need to load the model with onnx.load() first.

Here we export the same CNN model to ONNX format and then import it into TVM.

try:
    import onnx
    import onnxscript  # noqa: F401  # required by torch.onnx.export

    HAS_ONNX = True
except ImportError:
    onnx = None  # type: ignore[assignment]
    HAS_ONNX = False

if HAS_ONNX:
    from tvm.relax.frontend.onnx import from_onnx

    # Export the PyTorch model to ONNX
    dummy_input = torch.randn(1, 3, 32, 32)
    onnx_path = "simple_cnn.onnx"
    torch.onnx.export(torch_model, dummy_input, onnx_path, input_names=["input"])

    # Load and import into TVM
    onnx_model = onnx.load(onnx_path)
    mod_onnx = from_onnx(onnx_model, keep_params_in_input=True)
    mod_onnx, params_onnx = relax.frontend.detach_params(mod_onnx)
    mod_onnx.show()

If you already have an .onnx file on disk, the workflow is even simpler:

import onnx
from tvm.relax.frontend.onnx import from_onnx

onnx_model = onnx.load("my_model.onnx")
mod = from_onnx(onnx_model)

Key parameters

  • shape_dict (dict, optional): Maps input names to shapes. Auto-inferred from the model if not provided. Useful when the ONNX model has dynamic dimensions that you want to fix to concrete sizes:

    mod = from_onnx(onnx_model, shape_dict={"input": [1, 3, 224, 224]})
    
  • dtype_dict (str or dict, default "float32"): Input dtypes. A single string applies to all inputs, or use a dict to set per-input dtypes:

    mod = from_onnx(onnx_model, dtype_dict={"input": "float16"})
    
  • keep_params_in_input (bool, default False): Same semantics as in the PyTorch frontend: when True, model weights are exposed as function parameters (detach them afterwards with relax.frontend.detach_params()); when False they are embedded as constants.

  • opset (int, optional): Override the opset version auto-detected from the model. Each ONNX op may have different semantics across opset versions; TVM’s converter selects the appropriate implementation automatically. You rarely need to set this unless the model metadata is incorrect:

    mod = from_onnx(onnx_model, opset=17)

Importing from TensorFlow Lite

TVM can import TFLite FlatBuffer models via from_tflite(). The function expects a TFLite Model object parsed from the FlatBuffer bytes via GetRootAsModel.

Note

The tflite Python package has changed its module layout across versions. Older versions use tflite.Model.Model.GetRootAsModel, while newer versions use tflite.Model.GetRootAsModel. The code below handles both.
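If you parse TFLite models in more than one place, the version check can be factored into a small helper. This is an illustrative sketch; the helper name is our own, not part of TVM or the tflite package:

```python
def get_root_parser(model_entry):
    """Return the GetRootAsModel entry point for either tflite package layout.

    ``model_entry`` is the ``tflite.Model`` attribute: a module containing a
    ``Model`` class in older tflite releases, or the ``Model`` class itself in
    newer ones.
    """
    inner = getattr(model_entry, "Model", None)
    if inner is not None and hasattr(inner, "GetRootAsModel"):
        return inner.GetRootAsModel  # old layout: tflite.Model.Model.GetRootAsModel
    return model_entry.GetRootAsModel  # new layout: tflite.Model.GetRootAsModel
```

With this in place, parsing becomes `get_root_parser(tflite.Model)(tflite_buf, 0)` regardless of the installed version.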

Below we create a minimal TFLite model from TensorFlow and import it.

try:
    import tensorflow as tf
    import tflite
    import tflite.Model

    HAS_TFLITE = True
except ImportError:
    HAS_TFLITE = False

if HAS_TFLITE:
    from tvm.relax.frontend.tflite import from_tflite

    # Define a simple TF module and convert to TFLite.
    # We use plain TF ops (not keras layers) to avoid variable-handling ops
    # that some TFLite converter versions do not support cleanly.
    class TFModule(tf.Module):
        @tf.function(
            input_signature=[
                tf.TensorSpec(shape=(1, 784), dtype=tf.float32),
                tf.TensorSpec(shape=(784, 10), dtype=tf.float32),
            ]
        )
        def forward(self, x, weight):
            return tf.matmul(x, weight) + 0.1

    tf_module = TFModule()
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_module.forward.get_concrete_function()], tf_module
    )
    tflite_buf = converter.convert()

    # Parse and import into TVM (API differs between tflite package versions)
    if hasattr(tflite.Model, "Model"):
        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
    else:
        tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
    mod_tflite = from_tflite(tflite_model)
    mod_tflite.show()
Running the snippet prints the imported module as TVMScript:

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(serving_default_weight_0: R.Tensor((784, 10), dtype="float32"), serving_default_x_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 2, "params": [metadata["ffi.Tensor"][0]]})
        with R.dataflow():
            lv: R.Tensor((10, 784), dtype="float32") = R.permute_dims(serving_default_weight_0, axes=[1, 0])
            lv1: R.Tensor((784, 10), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
            lv2: R.Tensor((1, 10), dtype="float32") = R.matmul(serving_default_x_0, lv1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = R.add(lv2, metadata["relax.expr.Constant"][0])
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

Loading from a .tflite file

If you already have a .tflite file on disk, load the raw bytes and parse them:

import tflite
import tflite.Model
from tvm.relax.frontend.tflite import from_tflite

with open("my_model.tflite", "rb") as f:
    tflite_buf = f.read()

if hasattr(tflite.Model, "Model"):
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
else:
    tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
mod = from_tflite(tflite_model)

Key parameters

  • shape_dict / dtype_dict (optional): Override input shapes and dtypes. If not provided, they are inferred from the TFLite model metadata.

  • op_converter (class, optional): A custom operator converter class. Subclass OperatorConverter and override its convert_map dictionary to add or replace operator conversions. For example, to add a hypothetical CUSTOM_RELU op:

    from tvm.relax.frontend.tflite.tflite_frontend import OperatorConverter
    
    class MyConverter(OperatorConverter):
        def __init__(self, model, subgraph, exp_tab, ctx):
            super().__init__(model, subgraph, exp_tab, ctx)
            self.convert_map["CUSTOM_RELU"] = self._convert_custom_relu
    
        def _convert_custom_relu(self, op):
            # implement your conversion logic here
            ...
    
    mod = from_tflite(tflite_model, op_converter=MyConverter)
    

Summary

  Aspect             PyTorch                ONNX                 TFLite
  -----------------  ---------------------  -------------------  --------------------
  Entry function     from_exported_program  from_onnx            from_tflite
  Input              ExportedProgram        onnx.ModelProto      TFLite Model object
  Custom extension   custom_convert_map     N/A                  op_converter class

Which to use? Pick the frontend that matches your model format:

  • Have a PyTorch model? Use from_exported_program — it has the broadest operator coverage.

  • Have an .onnx file? Use from_onnx.

  • Have a .tflite file? Use from_tflite.

The verification workflow (compile → run → compare) demonstrated in the PyTorch section above applies equally to ONNX and TFLite imports.
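The compare step usually reduces to an element-wise tolerance check between the source framework's output and TVM's. A minimal helper (pure NumPy; the function name and default tolerances are our own choices):

```python
import numpy as np

def outputs_match(reference, candidate, rtol=1e-5, atol=1e-5):
    """True when the imported model's output matches the framework reference."""
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    return reference.shape == candidate.shape and np.allclose(
        reference, candidate, rtol=rtol, atol=atol
    )

# Typical use after compiling and running the imported module:
# assert outputs_match(torch_output.numpy(), tvm_output.numpy())
```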

For the full list of supported operators, see the converter map in each frontend’s source: PyTorch uses create_convert_map() in exported_program_translator.py, ONNX uses _get_convert_map() in onnx_frontend.py, and TFLite uses convert_map in OperatorConverter in tflite_frontend.py.

After importing, refer to End-to-End Optimize Model for optimization and deployment.
