Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
Transformation
In this section, we will get to the main ingredients of the compilation flows - transformations of primitive tensor functions.
In the previous section, we have given an example of how to write
mm_relu
using TensorIR. In practice, there can be multiple ways to implement
the same functionality, and each implementation can result in different performance.
Note
This tutorial primarily illustrates the application of TensorIR Transformation, rather than delving into optimization techniques.
First, let’s take a look at the implementation of mm_relu
in the previous section:
import tvm
from tvm.script import ir as I
from tvm.script import tir as T
@I.ir_module
class MyModule:
    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        # mm_relu on 128x128 float32 buffers: C = max(A @ B, 0).
        T.func_attr({"tir.noalias": T.bool(True)})
        # Intermediate buffer that holds the matmul result before the ReLU.
        Y = T.alloc_buffer((128, 128))
        # Stage 1: matrix multiplication, Y[i, j] = sum_k A[i, k] * B[k, j].
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                # "SSR": i and j are bound as spatial axes, k as the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # The init statement zeroes the accumulator for the reduction.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        # Stage 2: element-wise ReLU, C[i, j] = max(Y[i, j], 0).
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
Before we transform the function, let’s first evaluate the performance of the original implementation.
import numpy as np

# Random float32 inputs matching the (128, 128) buffers declared by MyModule.main.
a_np = np.random.uniform(size=(128, 128)).astype("float32")
b_np = np.random.uniform(size=(128, 128)).astype("float32")
# Reference result for mm_relu: matrix multiplication followed by ReLU.
# (With non-negative uniform inputs the ReLU is a no-op, but the reference
# should still describe the same computation as the kernel under test.)
c_np = np.maximum(a_np @ b_np, 0.0)

# Device-side copies of the inputs, plus a zero-initialised output buffer
# that the compiled function writes into.
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
def evaluate(mod: tvm.IRModule):
    """Build ``mod`` for the LLVM target, check its result, and print its timing.

    The compiled module is run on the module-level arrays ``a_nd``, ``b_nd``
    and ``c_nd``; the output written to ``c_nd`` is compared against the
    NumPy reference ``c_np``.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to build; its ``main`` function is the one that is timed.

    Raises
    ------
    AssertionError
        If the computed output does not match ``c_np`` within ``rtol=1e-5``.
    """
    lib = tvm.build(mod, target="llvm")
    # check correctness
    lib(a_nd, b_nd, c_nd)
    np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5)
    # evaluate performance
    f_timer = lib.time_evaluator("main", tvm.cpu())
    print(f_timer(a_nd, b_nd, c_nd))
evaluate(MyModule)
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
2.3157 2.3157 2.3157 2.3157 0.0000
Initialization Schedule
We initiate the process of code transformation by establishing a Schedule helper class, utilizing the provided MyModule as input.
sch = tvm.tir.Schedule(MyModule)
Loop Tiling
Subsequently, we execute the requisite operations to acquire a reference to block Y and its associated loops.
We now proceed to execute the transformations. The initial modification involves
splitting loop j
into two separate loops, with the inner loop possessing a
length of 8. It is crucial to understand that the transformation process is procedural;
thus, inadvertently executing the same code block twice will yield an error stating that
variable j
no longer exists.
The outcome of the transformation can be examined, as it is retained within sch.mod
.
sch.mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0, j_1, k in T.grid(128, 16, 8, 128):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0.0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
Following the initial transformation phase, two supplementary loops, j_0
and j_1
,
have been generated with respective ranges of 16 and 8. The subsequent
action involves reordering the loops so that k is placed between j_0 and j_1.
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0, k, j_1 in T.grid(128, 16, 128, 8):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0.0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
0.9731 0.9731 0.9731 0.9731 0.0000
Leverage Localities
Subsequently, we will execute two additional transformation steps to achieve a different variant. First, we employ a primitive known as reverse_compute_at to relocate block C to an inner loop of Y.
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 16):
for k, j_1 in T.grid(128, 8):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0.0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(8):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
Rewrite Reduction
Until now, the reduction initialization and update step have been maintained together
within a single block body. This amalgamated form facilitates loop transformations,
as the outer loops i
, j
of initialization and updates generally need to remain
synchronized.
Following the loop transformations, we can segregate the initialization of Y’s elements from the reduction update via the decompose_reduction primitive.
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 16):
for j_1_init in range(8):
with T.block("Y_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
T.reads()
T.writes(Y[vi, vj])
Y[vi, vj] = T.float32(0.0)
for k, j_1 in T.grid(128, 8):
with T.block("Y_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(8):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
0.3303 0.3303 0.3303 0.3303 0.0000
Trace the Transformation
TensorIR schedule is a procedural language, and the transformation is executed in a step-by-step manner. We can trace the transformation by printing the schedule or the history of the schedule.
We’ve already seen the schedule by printing sch.mod
. We can also print the history
of the schedule by sch.trace
.
sch.trace.show()
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
b0 = sch.get_block(name="Y", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l4, l3, l5)
b6 = sch.get_block(name="C", func_name="main")
sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
b7 = sch.decompose_reduction(block=b0, loop=l3)
Alternatively, we can output the IRModule in conjunction with the historical trace.
sch.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 16):
for j_1_init in range(8):
with T.block("Y_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
T.reads()
T.writes(Y[vi, vj])
Y[vi, vj] = T.float32(0.0)
for k, j_1 in T.grid(128, 8):
with T.block("Y_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(8):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
b0 = sch.get_block(name="Y", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l4, l3, l5)
b6 = sch.get_block(name="C", func_name="main")
sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
b7 = sch.decompose_reduction(block=b0, loop=l3)