
.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "deep_dive/relax/tutorials/relax_transformation.py"

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        You can click :ref:`here <sphx_glr_download_deep_dive_relax_tutorials_relax_transformation.py>` to run the Jupyter notebook locally.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_deep_dive_relax_tutorials_relax_transformation.py:


.. _relax-transform:

Transformation
--------------
In this section, we will dive into the transformation of Relax programs.
Transformations are one of the key ingredients of the compilation flow,
used both for optimizing programs and for integrating with hardware backends.

.. GENERATED FROM PYTHON SOURCE LINES 30-32

Let's first create a simple Relax program, as we did in
the :ref:`previous section <relax-creation>`.

.. GENERATED FROM PYTHON SOURCE LINES 32-57

.. code-block:: Python


    import tvm
    from tvm import IRModule, relax
    from tvm.relax.frontend import nn


    class NNModule(nn.Module):
        """Two-layer MLP: Linear(784 -> 128), ReLU, Linear(128 -> 10)."""

        def __init__(self):
            super().__init__()
            # The attribute names become the prefixes of the exported parameter
            # names (``fc1_weight``, ``fc1_bias``, ...), so keep them stable.
            self.fc1 = nn.Linear(784, 128)
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            hidden = self.fc1(x)
            hidden = self.relu1(hidden)
            return self.fc2(hidden)


    # Export ``forward`` with a symbolic leading dimension ``n``.
    forward_spec = {"x": nn.spec.Tensor(("n", 784), "float32")}
    origin_mod, params = NNModule().export_tvm({"forward": forward_spec})
    origin_mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
            n = T.int64()
            R.func_attr({"num_input": 1})
            with R.dataflow():
                permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
                matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype=None)
                add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
                relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
                permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
                matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype=None)
                add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
                gv: R.Tensor((n, 10), dtype="float32") = add1
                R.output(gv)
            return gv





.. GENERATED FROM PYTHON SOURCE LINES 58-64

Apply transformations
~~~~~~~~~~~~~~~~~~~~~
Passes are the main way to apply transformations to a program.
As a first step, let's apply the built-in pass ``LegalizeOps`` to lower the
high-level operators into low-level operators.

.. GENERATED FROM PYTHON SOURCE LINES 64-68

.. code-block:: Python


    legalize_ops = tvm.relax.transform.LegalizeOps()
    # Lower the high-level Relax operators into ``call_tir`` calls to generated TIR functions.
    mod = legalize_ops(origin_mod)
    mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func(private=True, s_tir=True)
        def add(var_matmul: T.handle, fc1_bias: T.Buffer((T.int64(128),), "float32"), var_T_add: T.handle):
            T.func_attr({"tirx.noalias": True})
            n = T.int64()
            matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
            T_add = T.match_buffer(var_T_add, (n, T.int64(128)))
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(n, T.int64(128)):
                with T.sblock("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(matmul[v_ax0, v_ax1], fc1_bias[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + fc1_bias[v_ax1]

        @T.prim_func(private=True, s_tir=True)
        def add1(var_matmul1: T.handle, fc2_bias: T.Buffer((T.int64(10),), "float32"), var_T_add: T.handle):
            T.func_attr({"tirx.noalias": True})
            n = T.int64()
            matmul1 = T.match_buffer(var_matmul1, (n, T.int64(10)))
            T_add = T.match_buffer(var_T_add, (n, T.int64(10)))
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(n, T.int64(10)):
                with T.sblock("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(matmul1[v_ax0, v_ax1], fc2_bias[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = matmul1[v_ax0, v_ax1] + fc2_bias[v_ax1]

        @T.prim_func(private=True, s_tir=True)
        def matmul(var_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), var_matmul: T.handle):
            T.func_attr({"tirx.noalias": True})
            n = T.int64()
            x = T.match_buffer(var_x, (n, T.int64(784)))
            matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
            # with T.sblock("root"):
            for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
                with T.sblock("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                    T.writes(matmul[v_i0, v_i1])
                    with T.init():
                        matmul[v_i0, v_i1] = T.float32(0.0)
                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]

        @T.prim_func(private=True, s_tir=True)
        def matmul1(var_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), var_matmul: T.handle):
            T.func_attr({"tirx.noalias": True})
            n = T.int64()
            relu = T.match_buffer(var_relu, (n, T.int64(128)))
            matmul = T.match_buffer(var_matmul, (n, T.int64(10)))
            # with T.sblock("root"):
            for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
                with T.sblock("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                    T.writes(matmul[v_i0, v_i1])
                    with T.init():
                        matmul[v_i0, v_i1] = T.float32(0.0)
                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]

        @T.prim_func(private=True, s_tir=True)
        def relu(var_add: T.handle, var_compute: T.handle):
            T.func_attr({"tirx.noalias": True})
            n = T.int64()
            add = T.match_buffer(var_add, (n, T.int64(128)))
            compute = T.match_buffer(var_compute, (n, T.int64(128)))
            # with T.sblock("root"):
            for i0, i1 in T.grid(n, T.int64(128)):
                with T.sblock("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(add[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.max(add[v_i0, v_i1], T.float32(0.0))

        @T.prim_func(private=True, s_tir=True)
        def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
                with T.sblock("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(fc1_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

        @T.prim_func(private=True, s_tir=True)
        def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
                with T.sblock("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(fc2_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

        @R.function
        def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
            n = T.int64()
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_ty=R.Tensor((784, 128), dtype="float32"))
                matmul = R.call_tir(cls.matmul, (x, permute_dims), out_ty=R.Tensor((n, 128), dtype="float32"))
                add = R.call_tir(cls.add, (matmul, fc1_bias), out_ty=R.Tensor((n, 128), dtype="float32"))
                relu = R.call_tir(cls.relu, (add,), out_ty=R.Tensor((n, 128), dtype="float32"))
                permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_ty=R.Tensor((128, 10), dtype="float32"))
                matmul1 = R.call_tir(cls.matmul1, (relu, permute_dims1), out_ty=R.Tensor((n, 10), dtype="float32"))
                add1 = R.call_tir(cls.add1, (matmul1, fc2_bias), out_ty=R.Tensor((n, 10), dtype="float32"))
                gv: R.Tensor((n, 10), dtype="float32") = add1
                R.output(gv)
            return gv





.. GENERATED FROM PYTHON SOURCE LINES 69-75

As we can see from the output, the high-level operators (i.e. ``relax.op``) in the program
are replaced by calls to their corresponding low-level TIR functions (via ``relax.call_tir``).

Next, let's apply operator fusion, a widely used optimization technique
in ML compilers. Note that in Relax, fusion is performed by a set of cooperating
passes, which we can apply in sequence.

.. GENERATED FROM PYTHON SOURCE LINES 75-85

.. code-block:: Python


    # Fusion is a cooperation of three passes: annotate each TIR function's
    # operator pattern, group fusible calls, then merge each group into one PrimFunc.
    fusion_passes = [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
    fuse = tvm.ir.transform.Sequential(fusion_passes)
    mod = fuse(mod)
    mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func(private=True, s_tir=True)
        def fused_matmul1_add1(p_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), p_output0: T.handle):
            T.func_attr({"tirx.noalias": True})
            n = T.int64()
            relu = T.match_buffer(p_relu, (n, T.int64(128)))
            T_add_intermediate = T.match_buffer(p_output0, (n, T.int64(10)))
            # with T.sblock("root"):
            matmul_intermediate = T.sblock_alloc_buffer((n, T.int64(10)))
            for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
                with T.sblock("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                    T.writes(matmul_intermediate[v_i0, v_i1])
                    with T.init():
                        matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                    matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]
            for ax0, ax1 in T.grid(n, T.int64(10)):
                with T.sblock("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])
                    T.writes(T_add_intermediate[v_ax0, v_ax1])
                    T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc2_bias[v_ax1]

        @T.prim_func(private=True, s_tir=True)
        def fused_matmul_add_relu(p_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), fc1_bias: T.Buffer((T.int64(128),), "float32"), p_output0: T.handle):
            T.func_attr({"tirx.noalias": True})
            n = T.int64()
            x = T.match_buffer(p_x, (n, T.int64(784)))
            compute_intermediate = T.match_buffer(p_output0, (n, T.int64(128)))
            # with T.sblock("root"):
            matmul_intermediate = T.sblock_alloc_buffer((n, T.int64(128)))
            T_add_intermediate = T.sblock_alloc_buffer((n, T.int64(128)))
            for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
                with T.sblock("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                    T.writes(matmul_intermediate[v_i0, v_i1])
                    with T.init():
                        matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                    matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]
            for ax0, ax1 in T.grid(n, T.int64(128)):
                with T.sblock("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])
                    T.writes(T_add_intermediate[v_ax0, v_ax1])
                    T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc1_bias[v_ax1]
            for i0, i1 in T.grid(n, T.int64(128)):
                with T.sblock("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_add_intermediate[v_i0, v_i1])
                    T.writes(compute_intermediate[v_i0, v_i1])
                    compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

        @T.prim_func(private=True, s_tir=True)
        def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
            T.func_attr({"op_pattern": 2, "tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
                with T.sblock("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(fc1_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

        @T.prim_func(private=True, s_tir=True)
        def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
            T.func_attr({"op_pattern": 2, "tirx.noalias": True})
            # with T.sblock("root"):
            for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
                with T.sblock("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(fc2_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

        @R.function
        def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
            n = T.int64()
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_ty=R.Tensor((784, 128), dtype="float32"))
                lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_ty=R.Tensor((n, 128), dtype="float32"))
                permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_ty=R.Tensor((128, 10), dtype="float32"))
                gv = R.call_tir(cls.fused_matmul1_add1, (lv, permute_dims1, fc2_bias), out_ty=R.Tensor((n, 10), dtype="float32"))
                R.output(gv)
            return gv





.. GENERATED FROM PYTHON SOURCE LINES 86-97

As a result, we can see that the first ``matmul``, ``add`` and ``relu`` operators are fused
into one kernel (i.e. a single ``call_tir``), while the second ``matmul`` and ``add`` are
fused into another.

For the full list of built-in passes, please refer to :py:class:`relax.transform`.

Custom Passes
~~~~~~~~~~~~~
We can also define our own passes. As an example, let's rewrite the ``relu``
operator into the ``gelu`` operator.

First, we need to write a Relax IR mutator that performs the rewriting.

.. GENERATED FROM PYTHON SOURCE LINES 97-114

.. code-block:: Python


    from tvm.relax.expr_functor import PyExprMutator, mutator


    @mutator
    class ReluRewriter(PyExprMutator):
        """Rewrite every ``relax.nn.relu`` call into a ``relax.nn.gelu`` call.

        All other expressions are handled by the default ``PyExprMutator``
        traversal.
        """

        def __init__(self, mod):
            super().__init__(mod)

        def visit_call_(self, call: relax.Call) -> relax.Expr:
            # ``call.op`` is not always a primitive operator. It may also be a
            # ``GlobalVar`` (a call to another function in the module) or an
            # ``ExternFunc`` (a packed call), neither of which has a ``name``
            # attribute. Check the type before reading ``name`` so that such
            # calls fall through to the default traversal instead of raising.
            if isinstance(call.op, tvm.ir.Op) and call.op.name == "relax.nn.relu":
                # Visit the argument so that any variable remapping done by the
                # mutator is also applied to the rewritten call.
                return relax.op.nn.gelu(self.visit_expr(call.args[0]))

            return super().visit_call_(call)









.. GENERATED FROM PYTHON SOURCE LINES 115-116

Then we can write a pass to apply the mutator to the whole module.

.. GENERATED FROM PYTHON SOURCE LINES 116-133

.. code-block:: Python



    @tvm.transform.module_pass(opt_level=0, name="ReluToGelu")
    class ReluToGelu:  # pylint: disable=too-few-public-methods
        """Module pass that runs ``ReluRewriter`` over every Relax function."""

        def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
            """Rewrite each Relax function in ``mod`` and return the rebuilt module."""
            rewriter = ReluRewriter(mod)
            builder = rewriter.builder_
            # Only Relax functions are rewritten; TIR PrimFuncs are left untouched.
            relax_funcs = [
                (gvar, fn) for gvar, fn in mod.functions_items() if isinstance(fn, relax.Function)
            ]
            for gvar, fn in relax_funcs:
                builder.update_func(gvar, rewriter.visit_expr(fn))
            return builder.get()


    relu_to_gelu = ReluToGelu()
    mod = relu_to_gelu(origin_mod)
    mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
            n = T.int64()
            R.func_attr({"num_input": 1})
            with R.dataflow():
                permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
                matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype=None)
                add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
                relu: R.Tensor((n, 128), dtype="float32") = R.nn.gelu(add)
                permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
                matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype=None)
                add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
                gv: R.Tensor((n, 10), dtype="float32") = add1
                R.output(gv)
            return gv





.. GENERATED FROM PYTHON SOURCE LINES 134-143

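The printed output shows that the ``relax.nn.relu`` operator has been
rewritten into the ``relax.nn.gelu`` operator. Only the bound expression is
replaced, so the variable still keeps its original name ``relu``.

For details on the mutator, please refer to :py:class:`relax.expr_functor.PyExprMutator`.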

Summary
~~~~~~~
In this section, we have shown how to apply transformations to Relax programs.
We have also shown how to define and apply custom transformations.


.. _sphx_glr_download_deep_dive_relax_tutorials_relax_transformation.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: relax_transformation.ipynb <relax_transformation.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: relax_transformation.py <relax_transformation.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: relax_transformation.zip <relax_transformation.zip>`
