Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
Transformation
In this section, we will dive into the transformation of Relax programs. Transformations are one of the key ingredients of the compilation flow, used both for optimizing programs and for integrating with hardware backends.
Let’s first create a simple Relax program, as we did in the previous section.
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn
class NNModule(nn.Module):
    """Two-layer MLP: Linear(784 -> 128), then ReLU, then Linear(128 -> 10)."""

    def __init__(self):
        super().__init__()
        # Attribute names (and their declaration order) determine the names of
        # the exported parameters, e.g. fc1_weight, fc1_bias, fc2_weight, fc2_bias.
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Apply fc1, relu1 and fc2 to ``x`` in sequence and return the result."""
        hidden = self.relu1(self.fc1(x))
        return self.fc2(hidden)
# Export the model's ``forward`` method to a Relax IRModule. The input ``x`` has a
# symbolic first dimension "n" and 784 float32 columns.
forward_spec = {"x": nn.spec.Tensor(("n", 784), "float32")}
# ``params`` is the second value returned by export_tvm; it is not used below.
origin_mod, params = NNModule().export_tvm({"forward": forward_spec})
origin_mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
n = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
gv: R.Tensor((n, 10), dtype="float32") = add1
R.output(gv)
return gv
Apply transformations
Passes are the main way to apply transformations to a program. As a first step, let’s apply the built-in pass LegalizeOps to lower the high-level operators into low-level operators.
# LegalizeOps replaces each high-level Relax operator with a TIR PrimFunc that
# is invoked through R.call_tir.
legalize_ops = relax.transform.LegalizeOps()
mod = legalize_ops(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def add(var_matmul: T.handle, fc1_bias: T.Buffer((T.int64(128),), "float32"), var_T_add: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
n = T.int64()
matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
T_add = T.match_buffer(var_T_add, (n, T.int64(128)))
# with T.block("root"):
for ax0, ax1 in T.grid(n, T.int64(128)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul[v_ax0, v_ax1], fc1_bias[v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + fc1_bias[v_ax1]
@T.prim_func(private=True)
def add1(var_matmul1: T.handle, fc2_bias: T.Buffer((T.int64(10),), "float32"), var_T_add: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
n = T.int64()
matmul1 = T.match_buffer(var_matmul1, (n, T.int64(10)))
T_add = T.match_buffer(var_T_add, (n, T.int64(10)))
# with T.block("root"):
for ax0, ax1 in T.grid(n, T.int64(10)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul1[v_ax0, v_ax1], fc2_bias[v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = matmul1[v_ax0, v_ax1] + fc2_bias[v_ax1]
@T.prim_func(private=True)
def matmul(var_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), var_matmul: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
n = T.int64()
x = T.match_buffer(var_x, (n, T.int64(784)))
matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
# with T.block("root"):
for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
T.writes(matmul[v_i0, v_i1])
with T.init():
matmul[v_i0, v_i1] = T.float32(0.0)
matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]
@T.prim_func(private=True)
def matmul1(var_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), var_matmul: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
n = T.int64()
relu = T.match_buffer(var_relu, (n, T.int64(128)))
matmul = T.match_buffer(var_matmul, (n, T.int64(10)))
# with T.block("root"):
for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
T.writes(matmul[v_i0, v_i1])
with T.init():
matmul[v_i0, v_i1] = T.float32(0.0)
matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]
@T.prim_func(private=True)
def relu(var_add: T.handle, var_compute: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
n = T.int64()
add = T.match_buffer(var_add, (n, T.int64(128)))
compute = T.match_buffer(var_compute, (n, T.int64(128)))
# with T.block("root"):
for i0, i1 in T.grid(n, T.int64(128)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(add[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.max(add[v_i0, v_i1], T.float32(0.0))
@T.prim_func(private=True)
def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(fc1_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]
@T.prim_func(private=True)
def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(fc2_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]
@R.function
def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
n = T.int64()
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
matmul = R.call_tir(cls.matmul, (x, permute_dims), out_sinfo=R.Tensor((n, 128), dtype="float32"))
add = R.call_tir(cls.add, (matmul, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
relu = R.call_tir(cls.relu, (add,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
matmul1 = R.call_tir(cls.matmul1, (relu, permute_dims1), out_sinfo=R.Tensor((n, 10), dtype="float32"))
add1 = R.call_tir(cls.add1, (matmul1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
gv: R.Tensor((n, 10), dtype="float32") = add1
R.output(gv)
return gv
As we can see from the output, the high-level operators (i.e., relax.op) in the program are replaced by their corresponding low-level operators (i.e., relax.call_tir).
Next, let’s apply operator fusion, a widely used optimization technique in ML compilers. Note that in Relax, fusion is performed by a set of passes working together, which we can apply in sequence.
# Fusion in Relax is a three-pass pipeline, run in this order:
#   1. AnnotateTIROpPattern attaches an "op_pattern" attribute to each PrimFunc,
#   2. FuseOps groups fusable operators together,
#   3. FuseTIR merges each group into a single fused PrimFunc.
fusion_passes = [
    relax.transform.AnnotateTIROpPattern(),
    relax.transform.FuseOps(),
    relax.transform.FuseTIR(),
]
fusion_pipeline = tvm.ir.transform.Sequential(fusion_passes)
mod = fusion_pipeline(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def fused_matmul1_add1(p_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), p_output0: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
n = T.int64()
relu = T.match_buffer(p_relu, (n, T.int64(128)))
T_add_intermediate = T.match_buffer(p_output0, (n, T.int64(10)))
# with T.block("root"):
matmul_intermediate = T.alloc_buffer((n, T.int64(10)))
for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
T.writes(matmul_intermediate[v_i0, v_i1])
with T.init():
matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]
for ax0, ax1 in T.grid(n, T.int64(10)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])
T.writes(T_add_intermediate[v_ax0, v_ax1])
T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc2_bias[v_ax1]
@T.prim_func(private=True)
def fused_matmul_add_relu(p_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), fc1_bias: T.Buffer((T.int64(128),), "float32"), p_output0: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
n = T.int64()
x = T.match_buffer(p_x, (n, T.int64(784)))
compute_intermediate = T.match_buffer(p_output0, (n, T.int64(128)))
# with T.block("root"):
matmul_intermediate = T.alloc_buffer((n, T.int64(128)))
T_add_intermediate = T.alloc_buffer((n, T.int64(128)))
for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
T.writes(matmul_intermediate[v_i0, v_i1])
with T.init():
matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]
for ax0, ax1 in T.grid(n, T.int64(128)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])
T.writes(T_add_intermediate[v_ax0, v_ax1])
T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc1_bias[v_ax1]
for i0, i1 in T.grid(n, T.int64(128)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(T_add_intermediate[v_i0, v_i1])
T.writes(compute_intermediate[v_i0, v_i1])
compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))
@T.prim_func(private=True)
def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(fc1_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]
@T.prim_func(private=True)
def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(fc2_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]
@R.function
def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
n = T.int64()
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
gv = R.call_tir(cls.fused_matmul1_add1, (lv, permute_dims1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
R.output(gv)
return gv
As a result, we can see that the first matmul, add, and relu operators are fused into one kernel (i.e., one call_tir), while the second matmul and add are fused into another.
For the full list of built-in passes, please refer to relax.transform.
Custom Passes
We can also define our own passes. As an example, let’s rewrite the relu operator into the gelu operator.
First, we need to write a Relax IR Mutator to do the rewriting.
from tvm.relax.expr_functor import PyExprMutator, mutator
@mutator
class ReluRewriter(PyExprMutator):
    """Relax expression mutator that replaces ``relax.nn.relu`` calls with ``relax.nn.gelu``.

    Parameters
    ----------
    mod : IRModule
        The module being rewritten; it is passed to ``PyExprMutator.__init__``.
    """

    def __init__(self, mod):
        super().__init__(mod)

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        """Rewrite ``relax.nn.relu(x)`` to ``relax.nn.gelu(x)``; defer every other call.

        Parameters
        ----------
        call : relax.Call
            The call expression being visited.

        Returns
        -------
        relax.Expr
            A ``gelu`` call on the same argument when ``call`` is a relu call,
            otherwise whatever the default ``PyExprMutator`` visit produces.
        """
        # The callee of a Relax call is not always an ``Op``. Calls to other
        # functions use a GlobalVar/Var/ExternFunc callee, which has no ``name``
        # attribute, so check the type before reading ``.name``.
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "relax.nn.relu":
            return relax.op.nn.gelu(call.args[0])
        return super().visit_call_(call)
Then we can write a pass to apply the mutator to the whole module.
@tvm.transform.module_pass(opt_level=0, name="ReluToGelu")
class ReluToGelu:  # pylint: disable=too-few-public-methods
    """Module pass that runs ``ReluRewriter`` over every Relax function in a module."""

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Rewrite each ``relax.Function`` in ``mod`` and return the rebuilt module.

        Non-Relax functions (e.g. TIR PrimFuncs) are left untouched. The rewritten
        functions are written back through the rewriter's block builder, and the
        builder's module is returned.
        """
        rewriter = ReluRewriter(mod)
        relax_funcs = [
            (gvar, fn) for gvar, fn in mod.functions_items() if isinstance(fn, relax.Function)
        ]
        for gvar, fn in relax_funcs:
            rewritten = rewriter.visit_expr(fn)
            rewriter.builder_.update_func(gvar, rewritten)
        return rewriter.builder_.get()
# Apply the custom pass to the exported module (before LegalizeOps was run).
relu_to_gelu = ReluToGelu()
mod = relu_to_gelu(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
n = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
relu: R.Tensor((n, 128), dtype="float32") = R.nn.gelu(add)
permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
gv: R.Tensor((n, 10), dtype="float32") = add1
R.output(gv)
return gv
The printed output shows that the relax.nn.relu operator is rewritten to the relax.nn.gelu operator (the bound variable keeps its original name, relu).
For the details of the mutator, please refer to relax.expr_functor.PyExprMutator.
Summary
In this section, we have shown how to apply transformations to the Relax program. We have also shown how to define and apply custom transformations.