
.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "deep_dive/tensor_ir/tutorials/tir_transformation.py"

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        You can click :ref:`here <sphx_glr_download_deep_dive_tensor_ir_tutorials_tir_transformation.py>` to run the Jupyter notebook locally.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_deep_dive_tensor_ir_tutorials_tir_transformation.py:


.. _tirx-transform:

Transformation
--------------
In this section, we get to the main ingredient of compilation flows: transformations
of primitive tensor functions.

.. GENERATED FROM PYTHON SOURCE LINES 29-38

In the :ref:`previous section <tirx-learning>`, we gave an example of how to write
``mm_relu`` using TensorIR. In practice, the same functionality can be implemented in
multiple ways, and each implementation can perform differently.

.. note::
  This tutorial focuses on how to apply TensorIR transformations rather than on
  optimization techniques.

First, let's take a look at the implementation of ``mm_relu`` in the previous section:

.. GENERATED FROM PYTHON SOURCE LINES 38-69

.. code-block:: Python


    import tvm
    from tvm.script import ir as I
    from tvm.script import tirx as T


    @I.ir_module
    class MyModule:
        @T.prim_func(s_tir=True)
        def main(
            A: T.Buffer((128, 128), "float32"),
            B: T.Buffer((128, 128), "float32"),
            C: T.Buffer((128, 128), "float32"),
        ):
            T.func_attr({"tirx.noalias": True})
            with T.sblock("root"):
                T.reads()
                T.writes()
                Y = T.sblock_alloc_buffer((128, 128))
                for i, j, k in T.grid(128, 128, 128):
                    with T.sblock("Y"):
                        vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                        with T.init():
                            Y[vi, vj] = T.float32(0)
                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
                for i, j in T.grid(128, 128):
                    with T.sblock("C"):
                        vi, vj = T.axis.remap("SS", [i, j])
                        C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
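
The module above can be read as ordinary nested loops. As a rough plain-Python sketch
(``mm_relu_reference`` is a hypothetical helper, not part of TVM),
``T.axis.remap("SSR", ...)`` marks ``i`` and ``j`` as spatial axes and ``k`` as a
reduction axis, and the ``T.init()`` statement runs on the first iteration of the
reduction axis:

.. code-block:: Python

    def mm_relu_reference(A, B):
        """Plain-Python equivalent of blocks "Y" and "C" for square list-of-lists inputs."""
        n = len(A)
        Y = [[None] * n for _ in range(n)]  # contents undefined until initialized
        # Block "Y": i, j are spatial ("S"), k is a reduction axis ("R").
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    if k == 0:
                        # T.init(): executed once per (i, j), before the first update.
                        Y[i][j] = 0.0
                    Y[i][j] = Y[i][j] + A[i][k] * B[k][j]
        # Block "C": elementwise ReLU over the spatial axes i, j.
        return [[max(Y[i][j], 0.0) for j in range(n)] for i in range(n)]

For example, ``mm_relu_reference([[1.0, -2.0], [3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]])``
returns ``[[1.0, 0.0], [3.0, 4.0]]``: multiplying by the identity leaves ``A``
unchanged, and ReLU clamps the negative entry to zero.
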









.. GENERATED FROM PYTHON SOURCE LINES 70-72

Before we transform the function, let's first evaluate the performance of the
original implementation.

.. GENERATED FROM PYTHON SOURCE LINES 72-96

.. code-block:: Python


    import numpy as np

    a_np = np.random.uniform(size=(128, 128)).astype("float32")
    b_np = np.random.uniform(size=(128, 128)).astype("float32")
    c_np = np.maximum(a_np @ b_np, 0)  # reference result of mm_relu: ReLU(A @ B)

    a_nd = tvm.runtime.tensor(a_np)
    b_nd = tvm.runtime.tensor(b_np)
    c_nd = tvm.runtime.tensor(np.zeros((128, 128), dtype="float32"))


    def evaluate(mod: tvm.IRModule):
        lib = tvm.tirx.build(mod, target="llvm")
        # check correctness
        lib(a_nd, b_nd, c_nd)
        np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5)
        # evaluate performance
        f_timer = lib.time_evaluator("main", tvm.cpu())
        print(f_timer(a_nd, b_nd, c_nd))


    evaluate(MyModule)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Execution time summary:
     mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
       2.7644       2.7644       2.7644       2.7644       0.0000                  
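
The summary reports statistics over the collected measurements. Here all five columns
coincide and the standard deviation is zero, which is what one would expect if only a
single measurement were summarized. A minimal standard-library sketch of how such a
summary can be computed (``summarize`` is a hypothetical helper, not TVM's
implementation, and it uses the population standard deviation):

.. code-block:: Python

    import statistics


    def summarize(samples_ms):
        """Return the five columns of an execution time summary, in milliseconds."""
        return {
            "mean": statistics.mean(samples_ms),
            "median": statistics.median(samples_ms),
            "max": max(samples_ms),
            "min": min(samples_ms),
            # Population standard deviation; it is 0.0 for a single sample.
            "std": statistics.pstdev(samples_ms),
        }

With more samples, for instance ``summarize([1.0, 2.0, 3.0, 10.0])``, the mean (4.0)
and the median (2.5) diverge, which is why it helps to report both when a few runs are
much slower than the rest.
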




.. GENERATED FROM PYTHON SOURCE LINES 97-101

Initialize the Schedule
***********************
We start the transformation process by creating a ``tvm.s_tir.Schedule`` helper object,
using **MyModule** as input.

.. GENERATED FROM PYTHON SOURCE LINES 101-104

.. code-block:: Python


    sch = tvm.s_tir.Schedule(MyModule)








.. GENERATED FROM PYTHON SOURCE LINES 105-109

Loop Tiling
***********
Next, we obtain a reference to block **Y** and the loops that enclose it.

.. GENERATED FROM PYTHON SOURCE LINES 109-113

.. code-block:: Python


    block_Y = sch.get_sblock("Y")
    i, j, k = sch.get_loops(block_Y)








.. GENERATED FROM PYTHON SOURCE LINES 114-119

We can now apply transformations. The first one splits loop ``j`` into two loops, with
an inner loop of length 8. Note that the transformation process is procedural: each
primitive modifies the schedule in place, so accidentally running the code block below
twice raises an error saying that the variable ``j`` no longer exists.

.. GENERATED FROM PYTHON SOURCE LINES 119-122

.. code-block:: Python


    j0, j1 = sch.split(j, factors=[None, 8])
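
Splitting rewrites the original index as ``j = j_0 * 8 + j_1``, and the ``None`` entry in
``factors`` is inferred so that the product of the factors equals the loop extent
(128 / 8 = 16). The following plain-Python sketch (``split_extents`` is a hypothetical
helper that only handles exact splits) checks that the two new loops visit every
original index exactly once and in the original order:

.. code-block:: Python

    def split_extents(extent, factors):
        """Infer the missing factor of a two-way split and enumerate the original indices."""
        outer, inner = factors
        if outer is None:
            outer = extent // inner
        elif inner is None:
            inner = extent // outer
        # This sketch assumes an exact split; uneven extents are not modeled here.
        if outer * inner != extent:
            raise ValueError("factors must multiply to the loop extent")
        # Iterating j_0 outside and j_1 inside reproduces the original index j.
        indices = [j0 * inner + j1 for j0 in range(outer) for j1 in range(inner)]
        return outer, inner, indices


    outer, inner, indices = split_extents(128, [None, 8])
    print(outer, inner)                 # 16 8
    print(indices == list(range(128)))  # True
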








.. GENERATED FROM PYTHON SOURCE LINES 123-124

The result of the transformation is stored in ``sch.mod`` and can be inspected:

.. GENERATED FROM PYTHON SOURCE LINES 124-127

.. code-block:: Python


    sch.mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis

    @I.ir_module
    class Module:
        @T.prim_func(s_tir=True)
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            Y = T.sblock_alloc_buffer((128, 128))
            for i, j_0, j_1, k in T.grid(128, 16, 8, 128):
                with T.sblock("Y"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    with T.init():
                        Y[vi, vj] = T.float32(0.0)
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for i, j in T.grid(128, 128):
                with T.sblock("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))





.. GENERATED FROM PYTHON SOURCE LINES 128-131

After the split, loop ``j`` has been replaced by two loops, ``j_0`` and ``j_1``, with
extents 16 and 8 respectively. Next, we reorder the loops so that the reduction loop
``k`` sits between ``j_0`` and ``j_1``.

.. GENERATED FROM PYTHON SOURCE LINES 131-136

.. code-block:: Python


    sch.reorder(j0, k, j1)
    sch.mod.show()
    evaluate(sch.mod)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis

    @I.ir_module
    class Module:
        @T.prim_func(s_tir=True)
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            Y = T.sblock_alloc_buffer((128, 128))
            for i, j_0, k, j_1 in T.grid(128, 16, 128, 8):
                with T.sblock("Y"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    with T.init():
                        Y[vi, vj] = T.float32(0.0)
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for i, j in T.grid(128, 128):
                with T.sblock("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

    Execution time summary:
     mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
       0.8633       0.8633       0.8633       0.8633       0.0000                  
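
Reordering does not change what is computed: for every ``(i, j)`` pair, ``k`` still runs
from 0 to 127 in the same order. What changes is the memory access pattern within each
``(i, j_0)`` tile. The sketch below (hypothetical helpers in plain Python, assuming
``B`` is stored row-major) lists the flat offsets into ``B`` touched in one tile under
each loop order and lets you check that both orders produce the same result:

.. code-block:: Python

    def tile_b_offsets(n, inner, j0, reordered):
        """Row-major offsets into an n x n matrix B touched within one (i, j_0) tile."""
        if reordered:
            # After sch.reorder(j0, k, j1): loops ..., k, j_1 (j_1 innermost).
            return [k * n + j0 * inner + j1 for k in range(n) for j1 in range(inner)]
        # Right after sch.split: loops ..., j_1, k (k innermost).
        return [k * n + j0 * inner + j1 for j1 in range(inner) for k in range(n)]


    def matmul_tiled(A, B, inner, reordered):
        """Compute Y = A @ B by walking each tile in the chosen loop order."""
        n = len(A)
        Y = [[0.0] * n for _ in range(n)]
        for i in range(n):
            for j0 in range(n // inner):
                for offset in tile_b_offsets(n, inner, j0, reordered):
                    k, j = divmod(offset, n)  # recover (k, j) from the flat offset
                    Y[i][j] += A[i][k] * B[k][j]
        return Y


    print(tile_b_offsets(4, 2, 0, reordered=False))  # [0, 4, 8, 12, 1, 5, 9, 13]
    print(tile_b_offsets(4, 2, 0, reordered=True))   # [0, 1, 4, 5, 8, 9, 12, 13]

In the reordered nest, consecutive iterations of ``j_1`` touch adjacent elements of a
row of ``B`` instead of jumping a full row at a time. This contiguous access is likely a
large part of the speedup measured above, although the exact gain depends on the target
CPU.
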




.. GENERATED FROM PYTHON SOURCE LINES 137-142

Leverage Locality
*******************
Next, we apply two more transformations to obtain a different variant. First, we use the
**reverse_compute_at** primitive to move block **C** under loop ``j_0`` of block **Y**,
so that ReLU is computed right after each tile of **Y** is produced.

.. GENERATED FROM PYTHON SOURCE LINES 142-147

.. code-block:: Python


    block_C = sch.get_sblock("C")
    sch.reverse_compute_at(block_C, j0)
    sch.mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis

    @I.ir_module
    class Module:
        @T.prim_func(s_tir=True)
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            Y = T.sblock_alloc_buffer((128, 128))
            for i, j_0 in T.grid(128, 16):
                for k, j_1 in T.grid(128, 8):
                    with T.sblock("Y"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + j_1)
                        vk = T.axis.reduce(128, k)
                        T.reads(A[vi, vk], B[vk, vj])
                        T.writes(Y[vi, vj])
                        with T.init():
                            Y[vi, vj] = T.float32(0.0)
                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
                for ax0 in range(8):
                    with T.sblock("C"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + ax0)
                        T.reads(Y[vi, vj])
                        T.writes(C[vi, vj])
                        C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
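
After ``reverse_compute_at``, block **C** no longer waits for the whole of ``Y``: inside
each ``(i, j_0)`` iteration, it applies ReLU to the 8 elements of ``Y`` that were just
computed, while they are likely still in cache. The same loop structure in plain Python
(``mm_relu_fused`` is a hypothetical helper, with the tile width as a parameter):

.. code-block:: Python

    def mm_relu_fused(A, B, inner=8):
        """Loop structure after reverse_compute_at(block_C, j0), for square inputs."""
        n = len(A)
        if n % inner != 0:
            raise ValueError("this sketch assumes inner divides n")
        Y = [[0.0] * n for _ in range(n)]
        C = [[0.0] * n for _ in range(n)]
        for i in range(n):
            for j0 in range(n // inner):
                # Block "Y": accumulate one 1 x inner tile of Y.
                for k in range(n):
                    for j1 in range(inner):
                        j = j0 * inner + j1
                        Y[i][j] += A[i][k] * B[k][j]
                # Block "C", now nested under j_0: consume the tile just produced.
                for ax0 in range(inner):
                    j = j0 * inner + ax0
                    C[i][j] = max(Y[i][j], 0.0)
        return C
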





.. GENERATED FROM PYTHON SOURCE LINES 148-157

Rewrite Reduction
*****************
So far, the reduction initialization and update steps have been kept together in a
single block body. This combined form makes loop transformations easier, because the
outer loops ``i`` and ``j`` of the initialization and the update generally need to stay
synchronized.

After the loop transformations, we can separate the initialization of Y's elements from
the reduction update with the **decompose_reduction** primitive.

.. GENERATED FROM PYTHON SOURCE LINES 157-162

.. code-block:: Python


    sch.decompose_reduction(block_Y, k)
    sch.mod.show()
    evaluate(sch.mod)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis

    @I.ir_module
    class Module:
        @T.prim_func(s_tir=True)
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            Y = T.sblock_alloc_buffer((128, 128))
            for i, j_0 in T.grid(128, 16):
                for j_1_init in range(8):
                    with T.sblock("Y_init"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
                        T.reads()
                        T.writes(Y[vi, vj])
                        Y[vi, vj] = T.float32(0.0)
                for k, j_1 in T.grid(128, 8):
                    with T.sblock("Y_update"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + j_1)
                        vk = T.axis.reduce(128, k)
                        T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
                        T.writes(Y[vi, vj])
                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
                for ax0 in range(8):
                    with T.sblock("C"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + ax0)
                        T.reads(Y[vi, vj])
                        T.writes(C[vi, vj])
                        C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

    Execution time summary:
     mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
       0.3396       0.3396       0.3396       0.3396       0.0000                  
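
Conceptually, in the combined form the initialization is guarded by a check on the first
reduction iteration, and that check sits inside the innermost loops.
``decompose_reduction`` hoists the initialization into its own loop (``Y_init``) placed
before loop ``k``, leaving a guard-free update loop (``Y_update``). A plain-Python sketch
of one ``(i, j_0)`` tile in both forms (hypothetical helpers; the guard counter only
makes the difference visible):

.. code-block:: Python

    def tile_with_init(A, B, i, j0, inner):
        """Block "Y" before decompose_reduction: the init step is guarded inside loop k."""
        n = len(A)
        tile = [None] * inner          # contents are undefined until initialized
        guard_checks = 0
        for k in range(n):
            for j1 in range(inner):
                guard_checks += 1
                if k == 0:             # T.init(): first iteration of the reduction axis
                    tile[j1] = 0.0
                tile[j1] += A[i][k] * B[k][j0 * inner + j1]
        return tile, guard_checks


    def tile_decomposed(A, B, i, j0, inner):
        """Blocks "Y_init" and "Y_update" after decompose_reduction(block_Y, k)."""
        n = len(A)
        tile = [None] * inner
        for j1_init in range(inner):   # "Y_init": runs once per tile, before loop k
            tile[j1_init] = 0.0
        for k in range(n):             # "Y_update": no per-iteration guard
            for j1 in range(inner):
                tile[j1] += A[i][k] * B[k][j0 * inner + j1]
        return tile

For the 128 x 128 example, the guarded form conceptually evaluates the check
``128 * 128 * 128`` times in total, while the decomposed form evaluates none, and a
branch-free update loop is typically easier for the compiler to vectorize. Note that the
timing above reflects both ``reverse_compute_at`` and ``decompose_reduction``, since the
module was not evaluated between the two steps.
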




.. GENERATED FROM PYTHON SOURCE LINES 163-171

Trace the Transformation
************************
A TensorIR schedule is procedural: transformations are applied one step at a time. We
can trace them by printing either the transformed module or the history of the schedule.

We have already seen the transformed module by printing ``sch.mod``. The history of the
schedule is available through ``sch.trace``.

.. GENERATED FROM PYTHON SOURCE LINES 171-174

.. code-block:: Python


    sch.trace.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm import s_tir
    def apply_trace(sch: s_tir.Schedule) -> None:
      b0 = sch.get_sblock(name="Y", func_name="main")
      l1, l2, l3 = sch.get_loops(block=b0)
      l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
      sch.reorder(l4, l3, l5)
      b6 = sch.get_sblock(name="C", func_name="main")
      sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
      b7 = sch.decompose_reduction(block=b0, loop=l3)
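
The trace is a list of primitive applications that refers to blocks and loops through
handles such as ``b0`` and ``l2``, so it can be replayed as a function on a schedule. The
toy class below is a hypothetical illustration of that idea, not TVM's ``Schedule``: it
tracks loops by name and extent (rather than through handles), records each primitive it
applies, and can replay the recorded trace. It also shows why re-running ``split`` on
``j`` fails: ``j`` no longer exists after the first split.

.. code-block:: Python

    class ToyLoopSchedule:
        """Hypothetical toy schedule over a single loop nest, for illustration only."""

        def __init__(self, loops):
            self.loops = list(loops)  # [(name, extent), ...], outermost first
            self.trace = []           # recorded (primitive name, arguments) pairs

        def _names(self):
            return [name for name, _ in self.loops]

        def split(self, name, inner):
            if name not in self._names():
                raise ValueError(f"loop {name!r} does not exist")
            pos = self._names().index(name)
            extent = self.loops[pos][1]
            # Replace the loop by an outer and an inner loop (exact splits only).
            self.loops[pos:pos + 1] = [(f"{name}_0", extent // inner), (f"{name}_1", inner)]
            self.trace.append(("split", (name, inner)))

        def reorder(self, *names):
            extents = dict(self.loops)
            missing = [n for n in names if n not in extents]
            if missing:
                raise ValueError(f"loops {missing} do not exist")
            # The listed loops keep their slots but take the requested relative order.
            slots = sorted(self._names().index(n) for n in names)
            for slot, name in zip(slots, names):
                self.loops[slot] = (name, extents[name])
            self.trace.append(("reorder", names))


    def replay(trace, loops):
        """Apply a recorded trace to a fresh toy schedule built from the original loops."""
        sch = ToyLoopSchedule(loops)
        for primitive, args in trace:
            getattr(sch, primitive)(*args)
        return sch


    original = [("i", 128), ("j", 128), ("k", 128)]
    sch = ToyLoopSchedule(original)
    sch.split("j", 8)
    sch.reorder("j_0", "k", "j_1")
    print(sch.loops)  # [('i', 128), ('j_0', 16), ('k', 128), ('j_1', 8)]
    print(replay(sch.trace, original).loops == sch.loops)  # True
    # Calling sch.split("j", 8) again would raise ValueError: loop 'j' does not exist.
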





.. GENERATED FROM PYTHON SOURCE LINES 175-176

Alternatively, ``sch.show()`` prints the IRModule together with the trace.

.. GENERATED FROM PYTHON SOURCE LINES 176-178

.. code-block:: Python


    sch.show()




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis

    @I.ir_module
    class Module:
        @T.prim_func(s_tir=True)
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"tirx.noalias": True})
            # with T.sblock("root"):
            Y = T.sblock_alloc_buffer((128, 128))
            for i, j_0 in T.grid(128, 16):
                for j_1_init in range(8):
                    with T.sblock("Y_init"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
                        T.reads()
                        T.writes(Y[vi, vj])
                        Y[vi, vj] = T.float32(0.0)
                for k, j_1 in T.grid(128, 8):
                    with T.sblock("Y_update"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + j_1)
                        vk = T.axis.reduce(128, k)
                        T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
                        T.writes(Y[vi, vj])
                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
                for ax0 in range(8):
                    with T.sblock("C"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + ax0)
                        T.reads(Y[vi, vj])
                        T.writes(C[vi, vj])
                        C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

    # from tvm import s_tir
    def apply_trace(sch: s_tir.Schedule) -> None:
      b0 = sch.get_sblock(name="Y", func_name="main")
      l1, l2, l3 = sch.get_loops(block=b0)
      l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
      sch.reorder(l4, l3, l5)
      b6 = sch.get_sblock(name="C", func_name="main")
      sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
      b7 = sch.decompose_reduction(block=b0, loop=l3)






.. _sphx_glr_download_deep_dive_tensor_ir_tutorials_tir_transformation.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tir_transformation.ipynb <tir_transformation.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tir_transformation.py <tir_transformation.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tir_transformation.zip <tir_transformation.zip>`
