
.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "get_started/tutorials/ir_module.py"

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        You can click :ref:`here <sphx_glr_download_get_started_tutorials_ir_module.py>` to run the Jupyter notebook locally.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_get_started_tutorials_ir_module.py:


.. _ir_module:

IRModule
========
This tutorial presents the core abstraction of Apache TVM, the IRModule.
An IRModule holds the **entirety** of an ML model: the computational graph,
the tensor programs, and any calls to external libraries.

.. contents:: Table of Contents
    :local:
    :depth: 1

.. GENERATED FROM PYTHON SOURCE LINES 32-35

.. code-block:: Python


    import numpy as np








.. GENERATED FROM PYTHON SOURCE LINES 36-40

Create IRModule
---------------
IRModules can be initialized in various ways. We demonstrate a few of them
below.

.. GENERATED FROM PYTHON SOURCE LINES 40-48

.. code-block:: Python

    import torch
    from torch import nn
    from torch.export import export

    import tvm
    from tvm import relax
    from tvm.relax.frontend.torch import from_exported_program








.. GENERATED FROM PYTHON SOURCE LINES 49-55

Import from existing models
~~~~~~~~~~~~~~~~~~~~~~~~~~~
The most common way to initialize an IRModule is to import from an existing
model. Apache TVM supports imports from a range of frameworks, such as PyTorch
and ONNX. This tutorial demonstrates only the import process from PyTorch.

.. GENERATED FROM PYTHON SOURCE LINES 55-86

.. code-block:: Python



    # Create a dummy model
    class TorchModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(784, 256)
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Linear(256, 10)

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu1(x)
            x = self.fc2(x)
            return x


    # Give an example argument to torch.export
    example_args = (torch.randn(1, 784, dtype=torch.float32),)

    # Convert the model to IRModule
    with torch.no_grad():
        exported_program = export(TorchModel().eval(), example_args)
        mod_from_torch = from_exported_program(
            exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
        )

    mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
    # Print the IRModule
    mod_from_torch.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=[1, 0])
                lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
                lv2: R.Tensor((1, 256), dtype="float32") = R.add(p_fc1_bias, lv1)
                lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
                lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=[1, 0])
                lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
                lv6: R.Tensor((1, 10), dtype="float32") = R.add(p_fc2_bias, lv5)
                gv: R.Tensor((1, 10), dtype="float32") = lv6
                R.output(gv)
            return gv

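To make the printed dataflow block easier to read, the following sketch replays
the same sequence of bindings (``lv`` through ``lv6``) in NumPy. It is only a
reference illustration, not TVM code: ``numpy_main`` is a hypothetical name, and
the weights are random placeholders with the shapes from the function signature.

.. code-block:: Python

    import numpy as np


    def numpy_main(x, p_fc1_weight, p_fc1_bias, p_fc2_weight, p_fc2_bias):
        """Replay the Relax bindings of ``main`` in NumPy (illustration only)."""
        lv = np.transpose(p_fc1_weight, axes=[1, 0])  # R.permute_dims
        lv1 = np.matmul(x, lv)  # R.matmul
        lv2 = p_fc1_bias + lv1  # R.add
        lv3 = np.maximum(lv2, 0.0)  # R.nn.relu
        lv4 = np.transpose(p_fc2_weight, axes=[1, 0])  # R.permute_dims
        lv5 = np.matmul(lv3, lv4)  # R.matmul
        lv6 = p_fc2_bias + lv5  # R.add
        return lv6


    # Random placeholders with the shapes from the function signature.
    rng = np.random.default_rng(0)
    x = rng.standard_normal((1, 784)).astype("float32")
    fc1_weight = rng.standard_normal((256, 784)).astype("float32")
    fc1_bias = rng.standard_normal((256,)).astype("float32")
    fc2_weight = rng.standard_normal((10, 256)).astype("float32")
    fc2_bias = rng.standard_normal((10,)).astype("float32")

    out = numpy_main(x, fc1_weight, fc1_bias, fc2_weight, fc2_bias)
    print(out.shape)  # (1, 10)

Each intermediate array has the shape annotated on the matching Relax binding,
which is a convenient way to sanity-check an imported module by eye.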




.. GENERATED FROM PYTHON SOURCE LINES 87-91

Write with Relax NN Module
~~~~~~~~~~~~~~~~~~~~~~~~~~
Apache TVM also provides a set of PyTorch-like APIs to help users
write an IRModule directly.

.. GENERATED FROM PYTHON SOURCE LINES 91-114

.. code-block:: Python


    from tvm.relax.frontend import nn


    class RelaxModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(784, 256)
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Linear(256, 10)

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu1(x)
            x = self.fc2(x)
            return x


    mod_from_relax, params_from_relax = RelaxModel().export_tvm(
        {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
    )
    mod_from_relax.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
                matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype=None)
                add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
                relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
                permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
                matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype=None)
                add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
                gv: R.Tensor((1, 10), dtype="float32") = add1
                R.output(gv)
            return gv

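The module printed here shows ``axes=None`` where the PyTorch import showed
``axes=[1, 0]``. For a 2-D weight both describe the same transpose: as the
annotated shapes suggest (``(256, 784)`` becomes ``(784, 256)``), omitting the
axes reverses the axis order. NumPy's ``transpose`` follows the same convention,
so we use it below as an analogy. This is a NumPy sketch, not a TVM API call.

.. code-block:: Python

    import numpy as np

    # Placeholder weight with the same shape as fc1_weight.
    fc1_weight = np.arange(256 * 784, dtype="float32").reshape(256, 784)

    reversed_axes = np.transpose(fc1_weight)  # axes omitted: reverse all axes
    explicit_axes = np.transpose(fc1_weight, axes=[1, 0])

    print(reversed_axes.shape)  # (784, 256)
    assert np.array_equal(reversed_axes, explicit_axes)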




.. GENERATED FROM PYTHON SOURCE LINES 115-120

Create via TVMScript
~~~~~~~~~~~~~~~~~~~~
TVMScript is a Python-based DSL for IRModules. We can print an IRModule
in TVMScript syntax or, in the other direction, parse TVMScript
to obtain an IRModule.

.. GENERATED FROM PYTHON SOURCE LINES 120-152

.. code-block:: Python


    from tvm.script import ir as I
    from tvm.script import relax as R


    @I.ir_module
    class TVMScriptModule:
        @R.function
        def main(
            x: R.Tensor((1, 784), dtype="float32"),
            fc1_weight: R.Tensor((256, 784), dtype="float32"),
            fc1_bias: R.Tensor((256,), dtype="float32"),
            fc2_weight: R.Tensor((10, 256), dtype="float32"),
            fc2_bias: R.Tensor((10,), dtype="float32"),
        ) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                permute_dims = R.permute_dims(fc1_weight, axes=None)
                matmul = R.matmul(x, permute_dims, out_dtype=None)
                add = R.add(matmul, fc1_bias)
                relu = R.nn.relu(add)
                permute_dims1 = R.permute_dims(fc2_weight, axes=None)
                matmul1 = R.matmul(relu, permute_dims1, out_dtype=None)
                add1 = R.add(matmul1, fc2_bias)
                gv = add1
                R.output(gv)
            return gv


    mod_from_script = TVMScriptModule
    mod_from_script.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
                matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype=None)
                add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
                relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
                permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
                matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype=None)
                add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
                gv: R.Tensor((1, 10), dtype="float32") = add1
                R.output(gv)
            return gv





.. GENERATED FROM PYTHON SOURCE LINES 153-156

Attributes of an IRModule
-------------------------
An IRModule is a collection of functions, indexed by GlobalVars.

.. GENERATED FROM PYTHON SOURCE LINES 156-160

.. code-block:: Python


    mod = mod_from_torch
    print(mod.get_global_vars())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (I.GlobalVar("main"),)




.. GENERATED FROM PYTHON SOURCE LINES 161-163

We can access the functions in the IRModule by indexing with either the
GlobalVars or their names.

.. GENERATED FROM PYTHON SOURCE LINES 163-170

.. code-block:: Python


    # index by global var name
    print(mod["main"])
    # index by global var, and checking they are the same function
    (gv,) = mod.get_global_vars()
    assert mod[gv] == mod["main"]





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import relax as R

    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=[1, 0])
            lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
            lv2: R.Tensor((1, 256), dtype="float32") = R.add(p_fc1_bias, lv1)
            lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=[1, 0])
            lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
            lv6: R.Tensor((1, 10), dtype="float32") = R.add(p_fc2_bias, lv5)
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

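As a rough mental model of this indexing behavior, the sketch below builds a toy
container in plain Python. ``ToyGlobalVar`` and ``ToyModule`` are hypothetical
stand-ins written for this illustration; they are not TVM classes, and the toy
compares handles by name, which is a simplification.

.. code-block:: Python

    from dataclasses import dataclass


    @dataclass(frozen=True)
    class ToyGlobalVar:
        """Hypothetical stand-in for a GlobalVar: a hashable handle with a name."""

        name_hint: str


    class ToyModule:
        """Hypothetical stand-in for an IRModule: functions keyed by global var."""

        def __init__(self, functions):
            self._functions = {ToyGlobalVar(name): fn for name, fn in functions.items()}

        def get_global_vars(self):
            return tuple(self._functions)

        def __getitem__(self, key):
            # Accept either a ToyGlobalVar or its name.
            if isinstance(key, str):
                key = ToyGlobalVar(key)
            return self._functions[key]


    toy_mod = ToyModule({"main": lambda x: x + 1})
    (toy_gv,) = toy_mod.get_global_vars()
    assert toy_mod[toy_gv] is toy_mod["main"]
    print(toy_gv.name_hint)  # main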



.. GENERATED FROM PYTHON SOURCE LINES 171-181

Transformations on IRModules
----------------------------
Transformations are an important component of Apache TVM. A transformation
takes in an IRModule and outputs another IRModule. We can apply a sequence of
transformations to an IRModule to obtain a new IRModule. This is the common way to
optimize a model.

In this getting started tutorial, we only demonstrate how to apply transformations
to an IRModule. For details of each transformation, please refer to the
:ref:`Transformation API Reference <api-relax-transformation>`.

.. GENERATED FROM PYTHON SOURCE LINES 183-186

We first apply the **LegalizeOps** transformation to the IRModule. This transformation
converts the Relax module into a mixed stage, with both Relax and TensorIR functions
within the same module. In the process, the Relax operators are converted into ``call_tir``.

.. GENERATED FROM PYTHON SOURCE LINES 186-191

.. code-block:: Python


    mod = mod_from_torch
    mod = relax.transform.LegalizeOps()(mod)
    mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func(private=True, s_tir=True)
        def add(p_fc1_bias: T.Buffer((T.int64(256),), "float32"), lv1: T.Buffer((T.int64(1), T.int64(256)), "float32"), T_add: T.Buffer((T.int64(1), T.int64(256)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
                with T.sblock("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(p_fc1_bias[v_ax1], lv1[v_ax0, v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = p_fc1_bias[v_ax1] + lv1[v_ax0, v_ax1]

        @T.prim_func(private=True, s_tir=True)
        def add1(p_fc2_bias: T.Buffer((T.int64(10),), "float32"), lv5: T.Buffer((T.int64(1), T.int64(10)), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
                with T.sblock("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(p_fc2_bias[v_ax1], lv5[v_ax0, v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = p_fc2_bias[v_ax1] + lv5[v_ax0, v_ax1]

        @T.prim_func(private=True, s_tir=True)
        def matmul(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(256)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
                with T.sblock("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(x[v_i0, v_k], lv[v_k, v_i1])
                    T.writes(matmul[v_i0, v_i1])
                    with T.init():
                        matmul[v_i0, v_i1] = T.float32(0.0)
                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]

        @T.prim_func(private=True, s_tir=True)
        def matmul1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
                with T.sblock("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                    T.writes(matmul[v_i0, v_i1])
                    with T.init():
                        matmul[v_i0, v_i1] = T.float32(0.0)
                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]

        @T.prim_func(private=True, s_tir=True)
        def relu(lv2: T.Buffer((T.int64(1), T.int64(256)), "float32"), compute: T.Buffer((T.int64(1), T.int64(256)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            for i0, i1 in T.grid(T.int64(1), T.int64(256)):
                with T.sblock("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(lv2[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.max(lv2[v_i0, v_i1], T.float32(0.0))

        @T.prim_func(private=True, s_tir=True)
        def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
                with T.sblock("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(p_fc1_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]

        @T.prim_func(private=True, s_tir=True)
        def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
                with T.sblock("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(p_fc2_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]

        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_ty=R.Tensor((784, 256), dtype="float32"))
                lv1 = R.call_tir(cls.matmul, (x, lv), out_ty=R.Tensor((1, 256), dtype="float32"))
                lv2 = R.call_tir(cls.add, (p_fc1_bias, lv1), out_ty=R.Tensor((1, 256), dtype="float32"))
                lv3 = R.call_tir(cls.relu, (lv2,), out_ty=R.Tensor((1, 256), dtype="float32"))
                lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_ty=R.Tensor((256, 10), dtype="float32"))
                lv5 = R.call_tir(cls.matmul1, (lv3, lv4), out_ty=R.Tensor((1, 10), dtype="float32"))
                lv6 = R.call_tir(cls.add1, (p_fc2_bias, lv5), out_ty=R.Tensor((1, 10), dtype="float32"))
                gv: R.Tensor((1, 10), dtype="float32") = lv6
                R.output(gv)
            return gv

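The generated TensorIR functions are loop nests that write into an output buffer
passed as their last parameter, and ``call_tir`` is what connects them to the
Relax function. The plain-Python sketch below mirrors the ``matmul`` function:
two spatial loops, one reduction loop, and an initialization step on the first
reduction iteration. ``toy_call_tir`` is a hypothetical helper that imitates the
calling convention by allocating the output before the call; it is not the TVM
implementation.

.. code-block:: Python

    from itertools import product


    def matmul(x, lv, out):
        """Plain-Python analogue of the generated ``matmul`` loop nest."""
        rows, inner, cols = len(x), len(lv), len(lv[0])
        # Same loop order as T.grid: i0 and i1 are spatial, k is the reduction.
        for i0, i1, k in product(range(rows), range(cols), range(inner)):
            if k == 0:
                out[i0][i1] = 0.0  # plays the role of T.init()
            out[i0][i1] = out[i0][i1] + x[i0][k] * lv[k][i1]


    def toy_call_tir(func, args, out_shape):
        """Hypothetical helper: allocate the output, then let ``func`` fill it."""
        rows, cols = out_shape
        out = [[None] * cols for _ in range(rows)]
        func(*args, out)
        return out


    x = [[1.0, 2.0, 3.0]]
    lv = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
    result = toy_call_tir(matmul, (x, lv), out_shape=(1, 2))
    print(result)  # [[7.0, 8.0]]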




.. GENERATED FROM PYTHON SOURCE LINES 192-194

After the transformation, there are many more functions inside the module. Let's print
the global vars again.

.. GENERATED FROM PYTHON SOURCE LINES 194-197

.. code-block:: Python


    print(mod.get_global_vars())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (I.GlobalVar("add"), I.GlobalVar("add1"), I.GlobalVar("main"), I.GlobalVar("matmul"), I.GlobalVar("matmul1"), I.GlobalVar("relu"), I.GlobalVar("transpose"), I.GlobalVar("transpose1"))




.. GENERATED FROM PYTHON SOURCE LINES 198-220

Next, Apache TVM provides a set of default transformation pipelines that
simplify the transformation process. We can apply one of them to the module.
The default **zero** pipeline contains a few fundamental transformations, including:

- **LegalizeOps**: This transform converts the Relax operators into ``call_tir`` calls
  to the corresponding TensorIR functions. After this transform, the IRModule will
  contain both Relax functions and TensorIR functions.
- **AnnotateTIROpPattern**: This transform annotates the pattern of the TensorIR functions,
  preparing them for subsequent operator fusion.
- **FoldConstant**: This pass performs constant folding, evaluating operations
  whose inputs are all constants at compile time.
- **FuseOps and FuseTIR**: These two passes work together to fuse operators based on the
  patterns annotated in the previous step (AnnotateTIROpPattern). These passes transform
  both Relax functions and TensorIR functions.

.. note::

  Here, **LegalizeOps** is applied twice in the flow: once explicitly above and once
  inside the **zero** pipeline. The second application has no effect but is harmless.

  Any pass can be repeated in the flow, since every pass is designed to handle all legal
  IRModule inputs. This design helps users construct their own pipelines.

.. GENERATED FROM PYTHON SOURCE LINES 220-224

.. code-block:: Python


    mod = relax.get_pipeline("zero")(mod)
    mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func(private=True, s_tir=True)
        def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            matmul_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(10)))
            for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
                with T.sblock("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                    T.writes(matmul_intermediate[v_i0, v_i1])
                    with T.init():
                        matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                    matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
            for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
                with T.sblock("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(p_fc2_bias[v_ax1], matmul_intermediate[v_ax0, v_ax1])
                    T.writes(T_add_intermediate[v_ax0, v_ax1])
                    T_add_intermediate[v_ax0, v_ax1] = p_fc2_bias[v_ax1] + matmul_intermediate[v_ax0, v_ax1]

        @T.prim_func(private=True, s_tir=True)
        def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            matmul_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(256)))
            T_add_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(256)))
            for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
                with T.sblock("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(x[v_i0, v_k], lv[v_k, v_i1])
                    T.writes(matmul_intermediate[v_i0, v_i1])
                    with T.init():
                        matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                    matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]
            for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
                with T.sblock("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(p_fc1_bias[v_ax1], matmul_intermediate[v_ax0, v_ax1])
                    T.writes(T_add_intermediate[v_ax0, v_ax1])
                    T_add_intermediate[v_ax0, v_ax1] = p_fc1_bias[v_ax1] + matmul_intermediate[v_ax0, v_ax1]
            for i0, i1 in T.grid(T.int64(1), T.int64(256)):
                with T.sblock("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_add_intermediate[v_i0, v_i1])
                    T.writes(compute_intermediate[v_i0, v_i1])
                    compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

        @T.prim_func(private=True, s_tir=True)
        def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
            T.func_attr({"op_pattern": 2, "tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
                with T.sblock("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(p_fc1_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]

        @T.prim_func(private=True, s_tir=True)
        def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
            T.func_attr({"op_pattern": 2, "tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
                with T.sblock("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(p_fc2_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]

        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_ty=R.Tensor((784, 256), dtype="float32"))
                lv_1 = R.call_tir(cls.fused_matmul_add_relu, (x, lv, p_fc1_bias), out_ty=R.Tensor((1, 256), dtype="float32"))
                lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_ty=R.Tensor((256, 10), dtype="float32"))
                gv = R.call_tir(cls.fused_matmul1_add1, (lv_1, lv4, p_fc2_bias), out_ty=R.Tensor((1, 10), dtype="float32"))
                R.output(gv)
            return gv

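Comparing the two printed modules, ``main`` made seven ``call_tir`` calls after
**LegalizeOps** alone and makes four after the **zero** pipeline. The NumPy
sketch below illustrates the idea for the first layer: the fused function
computes the same values as the three separate functions, but keeps the
intermediates local instead of returning them to the caller. The function names
mirror the generated ones for readability; the bodies are illustrative NumPy,
not the generated TensorIR.

.. code-block:: Python

    import numpy as np


    def matmul(x, lv):
        return x @ lv


    def add(bias, lv1):
        return bias + lv1


    def relu(lv2):
        return np.maximum(lv2, 0.0)


    def fused_matmul_add_relu(x, lv, bias):
        # Intermediates stay local to the fused function.
        matmul_intermediate = x @ lv
        t_add_intermediate = bias + matmul_intermediate
        return np.maximum(t_add_intermediate, 0.0)


    # Random placeholders with the shapes used by the first layer.
    rng = np.random.default_rng(0)
    x = rng.standard_normal((1, 784)).astype("float32")
    lv = rng.standard_normal((784, 256)).astype("float32")
    bias = rng.standard_normal(256).astype("float32")

    unfused = relu(add(bias, matmul(x, lv)))  # three separate calls
    fused = fused_matmul_add_relu(x, lv, bias)  # one call
    assert np.array_equal(unfused, fused)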


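The pipeline idea itself, a sequence of module-to-module functions in which a
repeated pass is harmless, can be sketched in a few lines of plain Python. All
names below (``toy_legalize``, ``toy_annotate``, ``toy_sequential``) are
hypothetical, and the "module" is just a dictionary mapping a function name to a
label; this is a conceptual model rather than how TVM implements passes.

.. code-block:: Python

    from functools import reduce


    def toy_legalize(module):
        """Hypothetical pass: lower every 'relax_op' entry to 'call_tir'."""
        return {
            name: ("call_tir" if kind == "relax_op" else kind)
            for name, kind in module.items()
        }


    def toy_annotate(module):
        """Hypothetical pass: tag lowered entries for a later fusion step."""
        return {
            name: ("call_tir+pattern" if kind == "call_tir" else kind)
            for name, kind in module.items()
        }


    def toy_sequential(passes):
        """Compose passes left to right into one module-to-module function."""
        return lambda module: reduce(lambda mod, p: p(mod), passes, module)


    module = {"matmul": "relax_op", "add": "relax_op"}
    once = toy_sequential([toy_legalize, toy_annotate])(module)
    twice = toy_sequential([toy_legalize, toy_legalize, toy_annotate])(module)

    assert once == twice  # the repeated pass changes nothing
    assert module == {"matmul": "relax_op", "add": "relax_op"}  # input untouched

Because each toy pass returns a new dictionary and accepts any input, passes can
be reordered or repeated freely, which is the property the note above relies on.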


.. GENERATED FROM PYTHON SOURCE LINES 225-235

Deploy the IRModule Universally
-------------------------------
After optimization, we can compile the model into a TVM runtime module.
Notably, Apache TVM supports universal deployment: the same IRModule can be
deployed on different backends, including CPU, GPU, and other emerging
backends.

Deploy on CPU
~~~~~~~~~~~~~
We can deploy the IRModule on CPU by specifying the target as ``llvm``.

.. GENERATED FROM PYTHON SOURCE LINES 235-245

.. code-block:: Python


    # Named ``ex`` to avoid shadowing the Python builtin ``exec``
    ex = tvm.compile(mod, target="llvm")
    dev = tvm.cpu()
    vm = relax.VirtualMachine(ex, dev)

    raw_data = np.random.rand(1, 784).astype("float32")
    data = tvm.runtime.tensor(raw_data, dev)
    cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
    print(cpu_out)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [[-0.10166795  0.03385016 -0.14397773  0.12001663 -0.0545135  -0.14803053
      -0.14423272  0.04252397  0.04188291 -0.27159798]]




.. GENERATED FROM PYTHON SOURCE LINES 246-255

Deploy on GPU
~~~~~~~~~~~~~
Besides the CPU backend, we can also deploy the IRModule on GPU. GPU programs
require extra information, such as thread bindings and shared memory
allocations, so we need a further transformation to generate them.

We use ``DLight`` to generate the GPU programs. In this tutorial, we won't go into
the details of ``DLight``.


.. GENERATED FROM PYTHON SOURCE LINES 255-264

.. code-block:: Python


    from tvm.s_tir import dlight as dl

    with tvm.target.Target("cuda"):
        gpu_mod = dl.ApplyDefaultSchedule(
            dl.gpu.Matmul(),
            dl.gpu.Fallback(),
        )(mod)








.. GENERATED FROM PYTHON SOURCE LINES 265-266

Now we can compile the IRModule for GPU in the same way as we did for CPU.

.. GENERATED FROM PYTHON SOURCE LINES 266-279

.. code-block:: Python


    # Named ``ex`` to avoid shadowing the Python builtin ``exec``
    ex = tvm.compile(gpu_mod, target="cuda")
    dev = tvm.device("cuda", 0)
    vm = relax.VirtualMachine(ex, dev)
    # Need to allocate data and params on GPU device
    data = tvm.runtime.tensor(raw_data, dev)
    gpu_params = [tvm.runtime.tensor(p, dev) for p in params_from_torch["main"]]
    gpu_out = vm["main"](data, *gpu_params).numpy()
    print(gpu_out)

    # Check the correctness of the results
    assert np.allclose(cpu_out, gpu_out, atol=1e-3)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [[-0.10166799  0.03385017 -0.14397773  0.1200166  -0.05451344 -0.14803056
      -0.14423269  0.04252395  0.04188287 -0.27159795]]

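The correctness check uses a tolerance rather than exact equality. Different
backends typically accumulate a reduction in a different order, and
floating-point addition is not associative, so the last digits can differ. The
sketch below copies the CPU and GPU values printed in this tutorial into NumPy
arrays to show the size of the gap; the numbers will differ on another run
because the input is random.

.. code-block:: Python

    import numpy as np

    # Values copied from the CPU and GPU outputs printed in this tutorial.
    cpu_out = np.array(
        [[-0.10166795, 0.03385016, -0.14397773, 0.12001663, -0.0545135,
          -0.14803053, -0.14423272, 0.04252397, 0.04188291, -0.27159798]],
        dtype="float32",
    )
    gpu_out = np.array(
        [[-0.10166799, 0.03385017, -0.14397773, 0.1200166, -0.05451344,
          -0.14803056, -0.14423269, 0.04252395, 0.04188287, -0.27159795]],
        dtype="float32",
    )

    max_abs_diff = np.abs(cpu_out - gpu_out).max()
    print(np.array_equal(cpu_out, gpu_out))  # False
    print(np.allclose(cpu_out, gpu_out, atol=1e-3))  # True

Here the largest difference is far below the ``atol=1e-3`` used in the check, so
a tighter tolerance would also pass for this particular run.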



.. GENERATED FROM PYTHON SOURCE LINES 280-286

Deploy on Other Backends
~~~~~~~~~~~~~~~~~~~~~~~~
Apache TVM also supports other backends, such as different kinds of GPUs
(Metal, ROCm, Vulkan and OpenCL), different kinds of CPUs (x86, ARM), and other
emerging backends (e.g., WebAssembly). The deployment process is similar to the
GPU backend.


.. _sphx_glr_download_get_started_tutorials_ir_module.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: ir_module.ipynb <ir_module.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: ir_module.py <ir_module.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: ir_module.zip <ir_module.zip>`
