# Transformation¶

In this section, we will get to the main ingredients of the compilation flows - transformations of primitive tensor functions.

In the previous section, we have given an example of how to write `mm_relu` using TensorIR. In practice, there can be multiple ways to implement the same functionality, and each implementation can result in different performance.

Note

This tutorial primarily illustrates the application of TensorIR Transformation, rather than delving into optimization techniques.

First, let’s take a look at the implementation of `mm_relu` in the previous section:

```import tvm
from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class MyModule:
@T.prim_func
def main(
A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32"),
):
T.func_attr({"tir.noalias": T.bool(True)})
Y = T.alloc_buffer((128, 128))
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
```

Before we transform the function, let’s first evaluate the performance of the original implementation.

```import numpy as np

a_np = np.random.uniform(size=(128, 128)).astype("float32")
b_np = np.random.uniform(size=(128, 128)).astype("float32")
c_np = a_np @ b_np

a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32"))

def evaluate(mod: tvm.IRModule):
lib = tvm.build(mod, target="llvm")
# check correctness
lib(a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5)
# evaluate performance
f_timer = lib.time_evaluator("main", tvm.cpu())
print(f_timer(a_nd, b_nd, c_nd))

evaluate(MyModule)
```
```Execution time summary:
mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
2.2397       2.2397       2.2397       2.2397       0.0000
```

## Initialization Schedule¶

We initiate the process of code transformation by establishing a Schedule helper class, utilizing the provided MyModule as input.

```sch = tvm.tir.Schedule(MyModule)
```

## Loop Tiling¶

Subsequently, we execute the requisite operations to acquire a reference to block Y and its associated loops.

```block_Y = sch.get_block("Y")
i, j, k = sch.get_loops(block_Y)
```

We now proceed to execute the transformations. The initial modification involves splitting loop `j` into two separate loops, with the inner loop possessing a length of 8. It is crucial to understand that the transformation process is procedural; thus, inadvertently executing the code block below a second time will yield an error, because loop `j` no longer exists once it has been split.

```j0, j1 = sch.split(j, factors=[None, 8])
```

The outcome of the transformation can be examined, as it is retained within `sch.mod`.

```sch.mod.show()
```
```# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0, j_1, k in T.grid(128, 16, 8, 128):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0.0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
```

Following the initial transformation phase, two supplementary loops, `j_0` and `j_1`, have been generated with respective ranges of 16 and 8. The subsequent action involves reordering these two loops.

```sch.reorder(j0, k, j1)
sch.mod.show()
evaluate(sch.mod)
```
```# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0, k, j_1 in T.grid(128, 16, 128, 8):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0.0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

Execution time summary:
mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
0.8570       0.8570       0.8570       0.8570       0.0000
```

## Leverage Localities¶

Subsequently, we will execute two additional transformation steps to achieve a different variant. First, we employ a primitive known as `reverse_compute_at` to relocate block `C` under loop `j_0` of block `Y`, so that each tile of `Y` is consumed right after it is computed.

```block_C = sch.get_block("C")
sch.reverse_compute_at(block_C, j0)
sch.mod.show()
```
```# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 16):
for k, j_1 in T.grid(128, 8):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0.0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(8):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
```

## Rewrite Reduction¶

Until now, the reduction initialization and update step have been maintained together within a single block body. This amalgamated form facilitates loop transformations, as the outer loops `i`, `j` of initialization and updates generally need to remain synchronized.

Following the loop transformations, we can segregate the initialization of Y’s elements from the reduction update via the decompose_reduction primitive.

```sch.decompose_reduction(block_Y, k)
sch.mod.show()
evaluate(sch.mod)
```
```# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 16):
for j_1_init in range(8):
with T.block("Y_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
T.writes(Y[vi, vj])
Y[vi, vj] = T.float32(0.0)
for k, j_1 in T.grid(128, 8):
with T.block("Y_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(8):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

Execution time summary:
mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
0.3215       0.3215       0.3215       0.3215       0.0000
```

## Trace the Transformation¶

TensorIR schedule is a procedural language, and the transformation is executed in a step-by-step manner. We can trace the transformation by printing the schedule or the history of the schedule.

We’ve already seen the scheduled module by printing `sch.mod`. We can also print the history of the schedule with `sch.trace`.

```sch.trace.show()
```
```# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
b0 = sch.get_block(name="Y", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l4, l3, l5)
b6 = sch.get_block(name="C", func_name="main")
sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
b7 = sch.decompose_reduction(block=b0, loop=l3)
```

Alternatively, we can output the IRModule in conjunction with the historical trace.

```sch.show()
```
```# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 16):
for j_1_init in range(8):
with T.block("Y_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
T.writes(Y[vi, vj])
Y[vi, vj] = T.float32(0.0)
for k, j_1 in T.grid(128, 8):
with T.block("Y_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(8):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
b0 = sch.get_block(name="Y", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l4, l3, l5)
b6 = sch.get_block(name="C", func_name="main")
sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
b7 = sch.decompose_reduction(block=b0, loop=l3)
```

Gallery generated by Sphinx-Gallery