
.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/tutorials/import_model.py"

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_how_to_tutorials_import_model.py>` to download this tutorial as a Jupyter notebook or Python script and run it locally.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_tutorials_import_model.py:


.. _import_model:

Importing Models from ML Frameworks
====================================
Apache TVM supports importing models from popular ML frameworks including PyTorch, ONNX,
and TensorFlow Lite. This tutorial walks through each import path with a minimal working
example and explains the key parameters. The PyTorch section additionally demonstrates
how to handle unsupported operators via a custom converter map.

For end-to-end optimization and deployment after importing, see :ref:`optimize_model`.

.. note::

    The ONNX section requires the ``onnx`` and ``onnxscript`` packages (``onnxscript`` is
    needed by ``torch.onnx.export``). The TFLite section requires ``tensorflow`` and
    ``tflite``. Sections whose dependencies are missing are skipped automatically.

.. contents:: Table of Contents
    :local:
    :depth: 2

.. GENERATED FROM PYTHON SOURCE LINES 43-51

Importing from PyTorch (Recommended)
-------------------------------------
Of the three frontends covered here, TVM's PyTorch frontend is the most feature-complete.
The recommended entry point is
:py:func:`~tvm.relax.frontend.torch.from_exported_program`, which works with PyTorch's
``torch.export`` API.

We start by defining a small CNN model for demonstration. No pretrained weights are
needed — we only care about the graph structure.

.. GENERATED FROM PYTHON SOURCE LINES 51-80

.. code-block:: Python


    import numpy as np
    import torch
    from torch import nn
    from torch.export import export

    import tvm
    from tvm import relax
    from tvm.relax.frontend.torch import from_exported_program


    class SimpleCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
            self.bn = nn.BatchNorm2d(16)
            self.pool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(16, 10)

        def forward(self, x):
            x = torch.relu(self.bn(self.conv(x)))
            x = self.pool(x).flatten(1)
            x = self.fc(x)
            return x


    torch_model = SimpleCNN().eval()
    example_args = (torch.randn(1, 3, 32, 32),)








.. GENERATED FROM PYTHON SOURCE LINES 81-85

Basic import
~~~~~~~~~~~~
The standard workflow is: ``torch.export.export()`` → ``from_exported_program()`` →
``detach_params()``.

.. GENERATED FROM PYTHON SOURCE LINES 85-97

.. code-block:: Python


    with torch.no_grad():
        exported_program = export(torch_model, example_args)
        mod = from_exported_program(
            exported_program,
            keep_params_as_input=True,
            unwrap_unit_return_tuple=True,
        )

    mod, params = relax.frontend.detach_params(mod)
    mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((1, 3, 32, 32), dtype="float32"), p_conv_weight: R.Tensor((16, 3, 3, 3), dtype="float32"), p_conv_bias: R.Tensor((16,), dtype="float32"), p_bn_weight: R.Tensor((16,), dtype="float32"), p_bn_bias: R.Tensor((16,), dtype="float32"), p_fc_weight: R.Tensor((10, 16), dtype="float32"), p_fc_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                lv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(x, p_conv_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv1: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(p_conv_bias, R.shape([1, 16, 1, 1]))
                lv2: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv, lv1)
                lv3: R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32"), R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")) = R.nn.batch_norm(lv2, p_bn_weight, p_bn_bias, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
                lv4: R.Tensor((1, 16, 32, 32), dtype="float32") = lv3[0]
                lv5: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv4)
                lv6: R.Tensor((1, 16, 1, 1), dtype="float32") = R.mean(lv5, axis=[-1, -2], keepdims=True)
                lv7: R.Tensor((1, 16), dtype="float32") = R.reshape(lv6, R.shape([1, 16]))
                lv8: R.Tensor((16, 10), dtype="float32") = R.permute_dims(p_fc_weight, axes=[1, 0])
                lv9: R.Tensor((1, 10), dtype="float32") = R.matmul(lv7, lv8, out_dtype="float32")
                lv10: R.Tensor((1, 10), dtype="float32") = R.add(p_fc_bias, lv9)
                gv: R.Tensor((1, 10), dtype="float32") = lv10
                R.output(gv)
            return gv

    # Metadata omitted. Use show_meta=True in script() method to show it.





.. GENERATED FROM PYTHON SOURCE LINES 98-116

Key parameters
~~~~~~~~~~~~~~
``from_exported_program`` accepts several parameters that control how the model is
translated:

- **keep_params_as_input** (bool, default ``False``): When ``True``, model weights become
  extra parameters of the Relax function; call ``relax.frontend.detach_params()``
  afterwards to split them out of the IRModule into a separate dictionary. When ``False``,
  weights are embedded as constants inside the IRModule. Use ``True`` when you want to
  manage weights independently (e.g., for weight sharing or quantization).

- **unwrap_unit_return_tuple** (bool, default ``False``): PyTorch ``export`` always wraps
  the return value in a tuple. Set ``True`` to unwrap single-element return tuples for a
  cleaner Relax function signature.

- **run_ep_decomposition** (bool, default ``True``): Runs PyTorch's built-in operator
  decomposition before translation. This breaks high-level ops (e.g., ``batch_norm``) into
  lower-level primitives, which generally improves TVM's coverage and optimization
  opportunities. Set ``False`` if you want to preserve the original op granularity.
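
The effect of ``keep_params_as_input`` can be observed by counting the parameters of the
imported ``main`` function. The following is a minimal sketch, not part of the executed
tutorial code: ``TinyLinear`` and ``count_main_params`` are hypothetical names introduced
for illustration, and the block does nothing when ``torch`` or ``tvm`` is not installed.

.. code-block:: python

   try:
       import torch
       from torch import nn
       from torch.export import export
       from tvm.relax.frontend.torch import from_exported_program

       HAS_DEPS = True
   except ImportError:
       HAS_DEPS = False


   def count_main_params(**import_kwargs):
       """Import a tiny model and return how many parameters ``main`` takes."""

       class TinyLinear(nn.Module):  # hypothetical model, for illustration only
           def __init__(self):
               super().__init__()
               self.fc = nn.Linear(4, 2)

           def forward(self, x):
               return self.fc(x)

       with torch.no_grad():
           ep = export(TinyLinear().eval(), (torch.randn(1, 4),))
           mod = from_exported_program(ep, **import_kwargs)
       return len(mod["main"].params)


   if HAS_DEPS:
       # Weights embedded as constants: only the data input remains.
       print("embedded:", count_main_params(keep_params_as_input=False))
       # Weights lifted to function parameters: data input plus weight and bias.
       print("as input:", count_main_params(keep_params_as_input=True))

Comparing the two counts is a quick way to confirm which mode a given import used before
calling ``detach_params()``.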

.. GENERATED FROM PYTHON SOURCE LINES 118-137

Handling unsupported operators with ``custom_convert_map``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When TVM encounters a PyTorch operator it does not recognize, it raises an error
indicating the unsupported operator name. You can extend the frontend by providing a
**custom converter map** — a dictionary mapping operator names to your own conversion
functions.

A custom converter function receives two arguments:

- **node** (``torch.fx.Node``): The FX graph node being converted, carrying operator
  info and references to input nodes.
- **importer** (``ExportedProgramImporter``): The importer instance, giving access to:

  - ``importer.env``: Dict mapping FX nodes to their converted Relax expressions.
  - ``importer.block_builder``: The Relax ``BlockBuilder`` for emitting operations.
  - ``importer.retrieve_args(node)``: Helper to look up converted args.

The function must return a ``relax.Var`` — the Relax expression for this node's output.
Here is an example that maps an operator to ``relax.op.sigmoid``:

.. GENERATED FROM PYTHON SOURCE LINES 137-147

.. code-block:: Python


    from tvm.relax.frontend.torch.exported_program_translator import ExportedProgramImporter


    def convert_sigmoid(node: torch.fx.Node, importer: ExportedProgramImporter) -> relax.Var:
        """Custom converter: map an op to relax.op.sigmoid."""
        args = importer.retrieve_args(node)
        return importer.block_builder.emit(relax.op.sigmoid(args[0]))









.. GENERATED FROM PYTHON SOURCE LINES 148-164

To use the custom converter, pass it via the ``custom_convert_map`` parameter. The key
is the ATen operator name in ``"op_name.variant"`` format (e.g., ``"sigmoid.default"``):

.. code-block:: python

   mod = from_exported_program(
       exported_program,
       custom_convert_map={"sigmoid.default": convert_sigmoid},
   )

.. note::

   To find the correct operator name, check the error message TVM raises when encountering
   the unsupported op — it includes the exact ATen name. You can also inspect the exported
   program's graph via ``print(exported_program.graph_module.graph)`` to see all operator
   names.
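
As a complement to printing the whole graph, the operator names can be collected
programmatically. This is a small sketch that assumes the ``__name__`` attribute of each
call target follows the ``"op_name.variant"`` key format described above; the ``Gate``
module and the ``list_aten_op_names`` helper are hypothetical and exist only for this
example. It needs ``torch`` but not TVM.

.. code-block:: python

   try:
       import torch
       from torch import nn
       from torch.export import export

       HAS_TORCH = True
   except ImportError:
       HAS_TORCH = False


   def list_aten_op_names(exported_program):
       """Return the sorted, de-duplicated names of all call_function targets."""
       names = set()
       for node in exported_program.graph_module.graph.nodes:
           if node.op == "call_function":
               # Fall back to str() for targets that do not define __name__.
               names.add(getattr(node.target, "__name__", str(node.target)))
       return sorted(names)


   if HAS_TORCH:

       class Gate(nn.Module):  # hypothetical module, for illustration only
           def forward(self, x):
               return torch.sigmoid(x) * x

       ep = export(Gate().eval(), (torch.randn(2, 3),))
       print(list_aten_op_names(ep))

Any name in this list that TVM reports as unsupported is a candidate key for
``custom_convert_map``.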

.. GENERATED FROM PYTHON SOURCE LINES 166-183

Alternative PyTorch import methods
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Besides ``from_exported_program``, TVM also provides:

- :py:func:`~tvm.relax.frontend.torch.from_fx`: Works with ``torch.fx.GraphModule``
  from ``torch.fx.symbolic_trace()``. Requires explicit ``input_info`` (shapes and dtypes).
  Use this when ``torch.export`` fails on certain Python control flow patterns.

- :py:func:`~tvm.relax.frontend.torch.relax_dynamo`: A ``torch.compile`` backend that
  compiles and executes the model through TVM in one step. Useful for integrating TVM
  into an existing PyTorch training or inference loop.

- :py:func:`~tvm.relax.frontend.torch.dynamo_capture_subgraphs`: Captures subgraphs from
  a PyTorch model into an IRModule via ``torch.compile``. Each subgraph becomes a separate
  function in the IRModule.
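
For comparison, a ``from_fx`` import might look like the sketch below. It assumes
``input_info`` is a list with one ``(shape, dtype)`` pair per model input, as described
above; ``TinyMLP`` is a hypothetical model used only here, and the block does nothing when
``torch`` or ``tvm`` is missing.

.. code-block:: python

   try:
       import torch
       from torch import fx, nn
       from tvm.relax.frontend.torch import from_fx

       HAS_DEPS = True
   except ImportError:
       HAS_DEPS = False

   if HAS_DEPS:

       class TinyMLP(nn.Module):  # hypothetical model, for illustration only
           def __init__(self):
               super().__init__()
               self.fc = nn.Linear(8, 4)

           def forward(self, x):
               return torch.relu(self.fc(x))

       # symbolic_trace records the forward pass as a torch.fx.GraphModule.
       graph_module = fx.symbolic_trace(TinyMLP().eval())
       # One (shape, dtype) pair per model input, in positional order.
       input_info = [((1, 8), "float32")]
       with torch.no_grad():
           mod_fx = from_fx(graph_module, input_info)
       mod_fx.show()

Unlike ``torch.export``, symbolic tracing does not record input shapes, which is why they
have to be supplied explicitly.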

For most use cases, ``from_exported_program`` is the recommended path.

.. GENERATED FROM PYTHON SOURCE LINES 185-191

Verifying the imported model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
After importing, it is good practice to verify that TVM produces the same output as the
original framework. We compile with the minimal ``"zero"`` pipeline (no tuning), run both
models on the same random input, and compare the outputs within a small tolerance. The same
approach applies to models imported via the ONNX and TFLite frontends shown below.

.. GENERATED FROM PYTHON SOURCE LINES 191-210

.. code-block:: Python


    mod_compiled = relax.get_pipeline("zero")(mod)
    exec_module = tvm.compile(mod_compiled, target="llvm")
    dev = tvm.cpu()
    vm = relax.VirtualMachine(exec_module, dev)

    # Run inference
    input_data = np.random.rand(1, 3, 32, 32).astype("float32")
    tvm_input = tvm.runtime.tensor(input_data, dev)
    tvm_params = [tvm.runtime.tensor(p, dev) for p in params["main"]]
    tvm_out = vm["main"](tvm_input, *tvm_params).numpy()

    # Compare with PyTorch
    with torch.no_grad():
        pt_out = torch_model(torch.from_numpy(input_data)).numpy()

    np.testing.assert_allclose(tvm_out, pt_out, rtol=1e-5, atol=1e-5)
    print("PyTorch vs TVM outputs match!")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    PyTorch vs TVM outputs match!
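
``np.testing.assert_allclose`` passes when ``|actual - desired| <= atol + rtol * |desired|``
holds for every element. When a comparison fails, it can help to inspect the largest
deviation directly. The helper below is a small sketch using only NumPy; the name
``report_mismatch`` and the sample arrays are made up for illustration.

.. code-block:: python

   import numpy as np


   def report_mismatch(actual, desired, rtol=1e-5, atol=1e-5):
       """Summarize how far ``actual`` deviates from ``desired``."""
       actual = np.asarray(actual, dtype="float64")
       desired = np.asarray(desired, dtype="float64")
       abs_err = np.abs(actual - desired)
       # Same elementwise criterion as np.testing.assert_allclose.
       within_tol = abs_err <= atol + rtol * np.abs(desired)
       return {
           "max_abs_err": float(abs_err.max()),
           "num_mismatched": int((~within_tol).sum()),
           "all_close": bool(within_tol.all()),
       }


   reference = np.array([1.0, 2.0, 3.0])
   # The last element is perturbed well beyond the default tolerances.
   candidate = reference + np.array([0.0, 1e-7, 1e-3])
   print(report_mismatch(candidate, reference))

In practice ``actual`` would be the TVM output and ``desired`` the output of the original
framework.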




.. GENERATED FROM PYTHON SOURCE LINES 211-218

Importing from ONNX
--------------------
TVM can import ONNX models via :py:func:`~tvm.relax.frontend.onnx.from_onnx`. The
function accepts an ``onnx.ModelProto`` object, so you need to load the model with
``onnx.load()`` first.

Here we export the same CNN model to ONNX format and then import it into TVM.

.. GENERATED FROM PYTHON SOURCE LINES 218-242

.. code-block:: Python


    try:
        import onnx
        import onnxscript  # noqa: F401  # required by torch.onnx.export

        HAS_ONNX = True
    except ImportError:
        onnx = None  # type: ignore[assignment]
        HAS_ONNX = False

    if HAS_ONNX:
        from tvm.relax.frontend.onnx import from_onnx

        # Export the PyTorch model to ONNX
        dummy_input = torch.randn(1, 3, 32, 32)
        onnx_path = "simple_cnn.onnx"
        torch.onnx.export(torch_model, dummy_input, onnx_path, input_names=["input"])

        # Load and import into TVM
        onnx_model = onnx.load(onnx_path)
        mod_onnx = from_onnx(onnx_model, keep_params_in_input=True)
        mod_onnx, params_onnx = relax.frontend.detach_params(mod_onnx)
        mod_onnx.show()








.. GENERATED FROM PYTHON SOURCE LINES 243-253

If you already have an ``.onnx`` file on disk, the workflow is even simpler:

.. code-block:: python

   import onnx
   from tvm.relax.frontend.onnx import from_onnx

   onnx_model = onnx.load("my_model.onnx")
   mod = from_onnx(onnx_model)


.. GENERATED FROM PYTHON SOURCE LINES 255-279

Key parameters
~~~~~~~~~~~~~~
- **shape_dict** (dict, optional): Maps input names to shapes. Auto-inferred from the
  model if not provided. Useful when the ONNX model has dynamic dimensions that you
  want to fix to concrete sizes:

  .. code-block:: python

     mod = from_onnx(onnx_model, shape_dict={"input": [1, 3, 224, 224]})

- **dtype_dict** (str or dict, default ``"float32"``): Input dtypes. A single string
  applies to all inputs, or use a dict to set per-input dtypes:

  .. code-block:: python

     mod = from_onnx(onnx_model, dtype_dict={"input": "float16"})

- **keep_params_in_input** (bool, default ``False``): Same semantics as
  ``keep_params_as_input`` in the PyTorch frontend (note the different spelling): whether
  model weights are function parameters or embedded constants.

- **opset** (int, optional): Override the opset version auto-detected from the model.
  Each ONNX op may have different semantics across opset versions; TVM's converter
  selects the appropriate implementation automatically. You rarely need to set this
  unless the model metadata is incorrect.
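
To decide what to pass as ``shape_dict`` or ``opset``, it helps to look at what the model
itself declares. The sketch below uses only the ``onnx`` package; ``build_shape_dict`` is a
hypothetical helper, and the one-node model with a symbolic batch dimension is constructed
purely for illustration. Replacing every symbolic dimension with the same value is a
simplifying assumption.

.. code-block:: python

   try:
       import onnx
       from onnx import TensorProto, helper

       HAS_ONNX = True
   except ImportError:
       HAS_ONNX = False


   def build_shape_dict(model, dynamic_dim_value=1):
       """Map each graph input to a concrete shape, filling symbolic dims."""
       initializer_names = {init.name for init in model.graph.initializer}
       shape_dict = {}
       for inp in model.graph.input:
           if inp.name in initializer_names:
               continue  # weights listed as inputs are not runtime inputs
           dims = inp.type.tensor_type.shape.dim
           shape_dict[inp.name] = [
               d.dim_value if d.HasField("dim_value") else dynamic_dim_value for d in dims
           ]
       return shape_dict


   if HAS_ONNX:
       # Hypothetical one-node model whose batch dimension is symbolic.
       shape = ["batch", 3, 224, 224]
       inp = helper.make_tensor_value_info("input", TensorProto.FLOAT, shape)
       out = helper.make_tensor_value_info("output", TensorProto.FLOAT, shape)
       relu = helper.make_node("Relu", ["input"], ["output"])
       graph = helper.make_graph([relu], "tiny_graph", [inp], [out])
       model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])

       print(build_shape_dict(model))
       # Opset versions declared by the model, keyed by operator domain.
       print({op.domain: op.version for op in model.opset_import})

The resulting dictionary can then be passed as ``shape_dict`` to ``from_onnx``.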

.. GENERATED FROM PYTHON SOURCE LINES 281-294

Importing from TensorFlow Lite
-------------------------------
TVM can import TFLite flat buffer models via
:py:func:`~tvm.relax.frontend.tflite.from_tflite`. The function expects a TFLite
``Model`` object parsed from flat buffer bytes via ``GetRootAsModel``.

.. note::

   The ``tflite`` Python package has changed its module layout across versions.
   Older versions use ``tflite.Model.Model.GetRootAsModel``, while newer versions use
   ``tflite.Model.GetRootAsModel``. The code below handles both.

Below we create a minimal TFLite model from TensorFlow and import it.

.. GENERATED FROM PYTHON SOURCE LINES 294-334

.. code-block:: Python


    try:
        import tensorflow as tf
        import tflite
        import tflite.Model

        HAS_TFLITE = True
    except ImportError:
        HAS_TFLITE = False

    if HAS_TFLITE:
        from tvm.relax.frontend.tflite import from_tflite

        # Define a simple TF module and convert to TFLite.
        # We use plain TF ops (not keras layers) to avoid variable-handling ops
        # that some TFLite converter versions do not support cleanly.
        class TFModule(tf.Module):
            @tf.function(
                input_signature=[
                    tf.TensorSpec(shape=(1, 784), dtype=tf.float32),
                    tf.TensorSpec(shape=(784, 10), dtype=tf.float32),
                ]
            )
            def forward(self, x, weight):
                return tf.matmul(x, weight) + 0.1

        tf_module = TFModule()
        converter = tf.lite.TFLiteConverter.from_concrete_functions(
            [tf_module.forward.get_concrete_function()], tf_module
        )
        tflite_buf = converter.convert()

        # Parse and import into TVM (API differs between tflite package versions)
        if hasattr(tflite.Model, "Model"):
            tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
        else:
            tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
        mod_tflite = from_tflite(tflite_model)
        mod_tflite.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(serving_default_weight_0: R.Tensor((784, 10), dtype="float32"), serving_default_x_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"num_input": 2, "params": [metadata["ffi.Tensor"][0]]})
            with R.dataflow():
                lv: R.Tensor((10, 784), dtype="float32") = R.permute_dims(serving_default_weight_0, axes=[1, 0])
                lv1: R.Tensor((784, 10), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
                lv2: R.Tensor((1, 10), dtype="float32") = R.matmul(serving_default_x_0, lv1, out_dtype=None)
                gv: R.Tensor((1, 10), dtype="float32") = R.add(lv2, metadata["relax.expr.Constant"][0])
                R.output(gv)
            return gv

    # Metadata omitted. Use show_meta=True in script() method to show it.





.. GENERATED FROM PYTHON SOURCE LINES 335-353

Loading from a ``.tflite`` file
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you already have a ``.tflite`` file on disk, load the raw bytes and parse them:

.. code-block:: python

   import tflite
   import tflite.Model
   from tvm.relax.frontend.tflite import from_tflite

   with open("my_model.tflite", "rb") as f:
       tflite_buf = f.read()

   if hasattr(tflite.Model, "Model"):
       tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
   else:
       tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
   mod = from_tflite(tflite_model)

.. GENERATED FROM PYTHON SOURCE LINES 355-378

Key parameters
~~~~~~~~~~~~~~
- **shape_dict** / **dtype_dict** (optional): Override input shapes and dtypes. If not
  provided, they are inferred from the TFLite model metadata.

- **op_converter** (class, optional): A custom operator converter class. Subclass
  ``OperatorConverter`` and extend its ``convert_map`` dictionary in ``__init__`` to add
  or replace operator conversions. For example, to add a hypothetical ``CUSTOM_RELU`` op:

  .. code-block:: python

     from tvm.relax.frontend.tflite.tflite_frontend import OperatorConverter

     class MyConverter(OperatorConverter):
         def __init__(self, model, subgraph, exp_tab, ctx):
             super().__init__(model, subgraph, exp_tab, ctx)
             self.convert_map["CUSTOM_RELU"] = self._convert_custom_relu

         def _convert_custom_relu(self, op):
             # implement your conversion logic here
             ...

     mod = from_tflite(tflite_model, op_converter=MyConverter)

.. GENERATED FROM PYTHON SOURCE LINES 380-408

Summary
-------

+---------------------+----------------------------+-------------------------------+-----------------------------+
| Aspect              | PyTorch                    | ONNX                          | TFLite                      |
+=====================+============================+===============================+=============================+
| Entry function      | ``from_exported_program``  | ``from_onnx``                 | ``from_tflite``             |
+---------------------+----------------------------+-------------------------------+-----------------------------+
| Input               | ``ExportedProgram``        | ``onnx.ModelProto``           | TFLite ``Model`` object     |
+---------------------+----------------------------+-------------------------------+-----------------------------+
| Custom extension    | ``custom_convert_map``     | —                             | ``op_converter`` class      |
+---------------------+----------------------------+-------------------------------+-----------------------------+

**Which to use?** Pick the frontend that matches your model format:

- Have a PyTorch model? Use ``from_exported_program`` — it has the broadest operator coverage.
- Have an ``.onnx`` file? Use ``from_onnx``.
- Have a ``.tflite`` file? Use ``from_tflite``.

The verification workflow (compile → run → compare) demonstrated in the PyTorch section
above applies equally to ONNX and TFLite imports.

For the full list of supported operators, see the converter map in each frontend's source:
PyTorch uses ``create_convert_map()`` in ``exported_program_translator.py``, ONNX uses
``_get_convert_map()`` in ``onnx_frontend.py``, and TFLite uses ``convert_map`` in
``OperatorConverter`` in ``tflite_frontend.py``.

After importing, refer to :ref:`optimize_model` for optimization and deployment.


.. _sphx_glr_download_how_to_tutorials_import_model.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: import_model.ipynb <import_model.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: import_model.py <import_model.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: import_model.zip <import_model.zip>`
