Transformation

In this section, we will dive into the transformation of Relax programs. Transformation is one of the key ingredients of the compilation flow, both for optimizing programs and for integrating with hardware backends.

Let’s first create a simple Relax program, as we did in the previous section.

import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn


class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


# Export the nn.Module to a TVM IRModule and its parameters; "n" is a symbolic batch size.
origin_mod, params = NNModule().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
origin_mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv

Apply transformations

Passes are the main way to apply transformations to a program. As a first step, let’s apply the built-in pass LegalizeOps, which lowers the high-level operators into low-level operators.

mod = tvm.relax.transform.LegalizeOps()(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(var_matmul: T.handle, fc1_bias: T.Buffer((T.int64(128),), "float32"), var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(128)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + fc1_bias[v_ax1]

    @T.prim_func(private=True)
    def add1(var_matmul1: T.handle, fc2_bias: T.Buffer((T.int64(10),), "float32"), var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        matmul1 = T.match_buffer(var_matmul1, (n, T.int64(10)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(10)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul1[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = matmul1[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def matmul(var_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        x = T.match_buffer(var_x, (n, T.int64(784)))
        matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]

    @T.prim_func(private=True)
    def matmul1(var_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        relu = T.match_buffer(var_relu, (n, T.int64(128)))
        matmul = T.match_buffer(var_matmul, (n, T.int64(10)))
        # with T.block("root"):
        for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]

    @T.prim_func(private=True)
    def relu(var_add: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        add = T.match_buffer(var_add, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(add[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(add[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
            matmul = R.call_tir(cls.matmul, (x, permute_dims), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            add = R.call_tir(cls.add, (matmul, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            relu = R.call_tir(cls.relu, (add,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
            matmul1 = R.call_tir(cls.matmul1, (relu, permute_dims1), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            add1 = R.call_tir(cls.add1, (matmul1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv

As we can see from the output, the high-level operators (i.e. relax.op calls) in the program are replaced by calls to their corresponding low-level TIR functions (via relax.call_tir).
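
Beyond reading the printed script, we can also inspect the module programmatically. The following is a minimal illustrative sketch (not part of the original flow) that separates the generated TIR functions from the remaining Relax function:

from tvm import tir

tir_names = sorted(gv.name_hint for gv, func in mod.functions_items() if isinstance(func, tir.PrimFunc))
relax_names = sorted(gv.name_hint for gv, func in mod.functions_items() if isinstance(func, relax.Function))
print("TIR functions:  ", tir_names)    # add, add1, matmul, matmul1, relu, transpose, transpose1
print("Relax functions:", relax_names)  # forward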

Next, let’s try applying operator fusion, a widely used optimization technique in ML compilers. Note that in Relax, fusion is carried out by a set of cooperating passes, applied in sequence: AnnotateTIROpPattern annotates each TIR function with its operator pattern kind, FuseOps groups fusable Relax operators into fused functions, and FuseTIR fuses the corresponding TIR functions into single kernels.

mod = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(p_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        relu = T.match_buffer(p_relu, (n, T.int64(128)))
        T_add_intermediate = T.match_buffer(p_output0, (n, T.int64(10)))
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((n, T.int64(10)))
        for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(p_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), fc1_bias: T.Buffer((T.int64(128),), "float32"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        x = T.match_buffer(p_x, (n, T.int64(784)))
        compute_intermediate = T.match_buffer(p_output0, (n, T.int64(128)))
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((n, T.int64(128)))
        T_add_intermediate = T.alloc_buffer((n, T.int64(128)))
        for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc1_bias[v_ax1]
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
            lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv, permute_dims1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            R.output(gv)
        return gv

As a result, we can see that the matmul, add, and relu operators are fused into one kernel (i.e. one call_tir).
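
One rough way to see the effect of fusion programmatically (an illustrative sketch, not part of the original tutorial) is to count the TIR functions in the module: the legalized module above contained seven, while the fused module contains only four.

from tvm import tir

num_prim_funcs = sum(1 for _, func in mod.functions_items() if isinstance(func, tir.PrimFunc))
print(num_prim_funcs)  # 4: fused_matmul_add_relu, fused_matmul1_add1, transpose, transpose1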

For all built-in passes, please refer to relax.transform.
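
For a quick interactive overview (an illustrative shortcut, not a substitute for the documentation), you can also list the public names exposed by the tvm.relax.transform namespace:

# Print the names of the passes (and helpers) exposed by tvm.relax.transform.
print([name for name in dir(tvm.relax.transform) if not name.startswith("_")])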

Custom Passes

We can also define our own passes. As an example, let’s rewrite the relu operator into the gelu operator.

First, we need to write a Relax IR Mutator to do the rewriting.

from tvm.relax.expr_functor import PyExprMutator, mutator


@mutator
class ReluRewriter(PyExprMutator):
    def __init__(self, mod):
        super().__init__(mod)

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        # visit the relax.Call expr, and only handle the case when op is relax.nn.relu
        if call.op.name == "relax.nn.relu":
            return relax.op.nn.gelu(call.args[0])

        return super().visit_call_(call)

Then we can write a pass to apply the mutator to the whole module.

@tvm.transform.module_pass(opt_level=0, name="ReluToGelu")
class ReluToGelu:  # pylint: disable=too-few-public-methods
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """IRModule-level transformation"""
        rewriter = ReluRewriter(mod)
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                func = rewriter.visit_expr(func)
                rewriter.builder_.update_func(g_var, func)
        return rewriter.builder_.get()


mod = ReluToGelu()(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.gelu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv

The printed output shows that the relax.nn.relu operator has been rewritten to the relax.nn.gelu operator.
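
Besides reading the printed script, we can verify the rewrite programmatically; the check below is a small illustrative sketch, not part of the original tutorial:

# The rewritten "forward" function should call gelu and no longer call relu.
script = mod["forward"].script()
assert "R.nn.gelu" in script
assert "R.nn.relu" not in script
print("relu successfully rewritten to gelu")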

For the details of the mutator, please refer to relax.expr_functor.PyExprMutator.

Summary

In this section, we have shown how to apply built-in transformations to a Relax program, and how to define and apply custom transformations.
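
As a recap, the passes demonstrated in this section compose into a single pipeline. The sketch below assumes the NNModule export and the ReluToGelu pass defined earlier; it applies the custom rewrite first, then legalizes and fuses the result:

pipeline = tvm.ir.transform.Sequential(
    [
        ReluToGelu(),                                # custom pass defined above
        tvm.relax.transform.LegalizeOps(),           # lower relax.op into call_tir
        tvm.relax.transform.AnnotateTIROpPattern(),  # annotate TIR op pattern kinds
        tvm.relax.transform.FuseOps(),               # group fusable Relax operators
        tvm.relax.transform.FuseTIR(),               # fuse the underlying TIR functions
    ]
)
final_mod = pipeline(origin_mod)
final_mod.show()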
