Relax Creation

This tutorial demonstrates how to create Relax functions and programs. We’ll cover several ways to define Relax functions, including TVMScript, the Relax NNModule API, and the Block Builder API.

Create Relax programs using TVMScript

TVMScript is a domain-specific language for representing Apache TVM’s intermediate representation (IR). It is a Python dialect that can be used to define an IRModule, which contains both TensorIR and Relax functions.

In this section, we will show how to define a simple MLP model with only high-level Relax operators using TVMScript.

from tvm import relax, topi
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T


# A two-layer MLP (linear -> ReLU -> linear) written only with high-level Relax operators.
@I.ir_module
class RelaxModule:
    @R.function
    def forward(
        # "n" is a symbolic dimension: it appears in both the input and the return shape.
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        with R.dataflow():
            # First linear layer: w0 is (128, 784), so it is transposed to (784, 128)
            # before the matmul; the bias is then added.
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            lv1 = R.nn.relu(lv0)
            # Second linear layer, same pattern with w1 (10, 128) and b1.
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            # Mark lv2 as the output of this dataflow block so it can be returned.
            R.output(lv2)
        return lv2


# Print the module as TVMScript.
RelaxModule.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(data: R.Tensor(("n", 784), dtype="float32"), w0: R.Tensor((128, 784), dtype="float32"), b0: R.Tensor((128,), dtype="float32"), w1: R.Tensor((10, 128), dtype="float32"), b1: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
            lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
            lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)
            lv1_1: R.Tensor((n, 128), dtype="float32") = R.nn.relu(lv0)
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(w1, axes=None)
            lv5: R.Tensor((n, 10), dtype="float32") = R.matmul(lv1_1, lv4, out_dtype="void")
            lv2: R.Tensor((n, 10), dtype="float32") = R.add(lv5, b1)
            R.output(lv2)
        return lv2

Relax is not only a graph-level IR, but also supports cross-level representation and transformation. Specifically, we can directly call TensorIR functions from a Relax function.

# The same MLP, but the ReLU is a hand-written TensorIR function called from Relax.
@I.ir_module
class RelaxModuleWithTIR:
    # Element-wise ReLU over a 2-D float32 buffer: reads from `x`, writes to `y`.
    @T.prim_func
    def relu(x: T.handle, y: T.handle):
        # Symbolic extents; both buffers share the same (n, m) shape.
        n, m = T.int64(), T.int64()
        X = T.match_buffer(x, (n, m), "float32")
        Y = T.match_buffer(y, (n, m), "float32")
        for i, j in T.grid(n, m):
            with T.block("relu"):
                # "SS": both loop axes are spatial (no reduction).
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @R.function
    def forward(
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        # Declare the symbolic variable `n` so it can be referenced in the body below.
        n = T.int64()
        # Alias for the enclosing module, used to refer to sibling functions.
        cls = RelaxModuleWithTIR
        with R.dataflow():
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            # Call the TensorIR function: only the input is passed explicitly; the
            # output buffer `y` is described by the struct info given as last argument.
            lv1 = R.call_tir(cls.relu, lv0, R.Tensor((n, 128), dtype="float32"))
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            R.output(lv2)
        return lv2


# Print the module as TVMScript.
RelaxModuleWithTIR.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def relu(x: T.handle, y: T.handle):
        n, m = T.int64(), T.int64()
        X = T.match_buffer(x, (n, m))
        Y = T.match_buffer(y, (n, m))
        # with T.block("root"):
        for i, j in T.grid(n, m):
            with T.block("relu"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(X[vi, vj])
                T.writes(Y[vi, vj])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0.0))

    @R.function
    def forward(data: R.Tensor(("n", 784), dtype="float32"), w0: R.Tensor((128, 784), dtype="float32"), b0: R.Tensor((128,), dtype="float32"), w1: R.Tensor((10, 128), dtype="float32"), b1: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        cls = Module
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
            lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
            lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)
            lv1_1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(w1, axes=None)
            lv5: R.Tensor((n, 10), dtype="float32") = R.matmul(lv1_1, lv4, out_dtype="void")
            lv2: R.Tensor((n, 10), dtype="float32") = R.add(lv5, b1)
            R.output(lv2)
        return lv2

Note

You may notice that the printed output differs from the TVMScript code we wrote. This is because the IRModule is printed in a standard (normalized) format, while the parser accepts syntactic sugar in the input.

For example, we can combine multiple operators into a single line, as

lv0 = R.matmul(data, R.permute_dims(w0)) + b0

However, the normalized form allows only one operation per binding. The printed output therefore splits the expression into separate bindings, as

lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)

Create Relax programs using NNModule API

Besides TVMScript, we also provide a PyTorch-like API for defining neural networks. It is designed to be more intuitive and easier to use than TVMScript.

In this section, we will show how to define the same MLP model using Relax NNModule API.

from tvm.relax.frontend import nn


class NNModule(nn.Module):
    """Two-layer MLP: Linear(784 -> 128), ReLU, Linear(128 -> 10)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Apply fc1, relu1 and fc2 in sequence to ``x`` and return the result."""
        hidden = self.relu1(self.fc1(x))
        return self.fc2(hidden)

After we define the NNModule, we can export it to TVM IRModule via export_tvm.

# Export the `forward` method; its input `x` is a float32 tensor of shape (n, 784),
# where "n" is a symbolic dimension.
mod, params = NNModule().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
# Print the exported IRModule as TVMScript.
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv

We can also insert customized function calls into the NNModule, such as Tensor Expression (TE) functions, TensorIR functions, or other TVM packed functions.

# TensorIR linear layer: Z = X @ W^T + B, with X (M, K), W (N, K), B (N,), Z (M, N).
# `z` is the output buffer; the other three handles are inputs.
@T.prim_func
def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
    # Symbolic extents shared across the buffers below.
    M, N, K = T.int64(), T.int64(), T.int64()
    X = T.match_buffer(x, (M, K), "float32")
    W = T.match_buffer(w, (N, K), "float32")
    B = T.match_buffer(b, (N,), "float32")
    Z = T.match_buffer(z, (M, N), "float32")
    # Matrix multiply: accumulate X[i, k] * W[j, k] over k into Z[i, j].
    for i, j, k in T.grid(M, N, K):
        with T.block("linear"):
            # "SSR": i and j are spatial axes, k is the reduction axis.
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            # Zero the accumulator before the reduction starts.
            with T.init():
                Z[vi, vj] = 0
            Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
    # Bias add, in place on Z: B[j] is added to every row.
    for i, j in T.grid(M, N):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            Z[vi, vj] = Z[vi, vj] + B[vj]


# MLP whose layers are computed by three different kinds of custom calls:
# an external packed function, a TE/TOPI operator, and a TensorIR function.
class NNModuleWithTIR(nn.Module):
    def __init__(self):
        super().__init__()
        # The Linear layers are only used as holders of weight/bias parameters here;
        # forward() never calls them directly.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Leading dimension of the input, reused in the output shapes below.
        n = x.shape[0]
        # We can call an external packed function by name using nn.extern
        x = nn.extern(
            "env.linear",
            [x, self.fc1.weight, self.fc1.bias],
            out=nn.Tensor.placeholder((n, 128), "float32"),
        )
        # We can also call a Tensor Expression (TE) operator from TOPI
        x = nn.tensor_expr_op(topi.nn.relu, "relu", [x])
        # We can also call a TensorIR function directly (tir_linear, defined above)
        x = nn.tensor_ir_op(
            tir_linear,
            "tir_linear",
            [x, self.fc2.weight, self.fc2.bias],
            out=nn.Tensor.placeholder((n, 10), "float32"),
        )
        return x


# Export `forward` with a float32 input of shape (n, 784) and print the IRModule.
mod, params = NNModuleWithTIR().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relu(var_env_linear: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        env_linear = T.match_buffer(var_env_linear, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(env_linear[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(env_linear[v_i0, v_i1], T.float32(0.0))

    @T.prim_func
    def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        M, K = T.int64(), T.int64()
        X = T.match_buffer(x, (M, K))
        N = T.int64()
        W = T.match_buffer(w, (N, K))
        B = T.match_buffer(b, (N,))
        Z = T.match_buffer(z, (M, N))
        # with T.block("root"):
        for i, j, k in T.grid(M, N, K):
            with T.block("linear"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(X[vi, vk], W[vj, vk])
                T.writes(Z[vi, vj])
                with T.init():
                    Z[vi, vj] = T.float32(0.0)
                Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
        for i, j in T.grid(M, N):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Z[vi, vj], B[vj])
                T.writes(Z[vi, vj])
                Z[vi, vj] = Z[vi, vj] + B[vj]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            env_linear = R.call_dps_packed("env.linear", (x, fc1_weight, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv = R.call_tir(cls.relu, (env_linear,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv1 = R.call_tir(cls.tir_linear, (lv, fc2_weight, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            gv: R.Tensor((n, 10), dtype="float32") = lv1
            R.output(gv)
        return gv

Create Relax programs using Block Builder API

In addition to the above APIs, we also provide a Block Builder API for creating Relax programs. It is an IR builder API, which is more low-level and widely used in TVM’s internal logic, e.g., when writing a customized pass.

# Build the same MLP with the Block Builder API.
bb = relax.BlockBuilder()
# Symbolic leading dimension of the input (printed as `v` in the output below).
n = T.int64()
# Function parameters, each a relax.Var with its tensor struct info.
x = relax.Var("x", R.Tensor((n, 784), "float32"))
fc1_weight = relax.Var("fc1_weight", R.Tensor((128, 784), "float32"))
fc1_bias = relax.Var("fc1_bias", R.Tensor((128,), "float32"))
fc2_weight = relax.Var("fc2_weight", R.Tensor((10, 128), "float32"))
fc2_bias = relax.Var("fc2_bias", R.Tensor((10,), "float32"))
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        lv0 = bb.emit(relax.op.matmul(x, relax.op.permute_dims(fc1_weight)) + fc1_bias)
        lv1 = bb.emit(relax.op.nn.relu(lv0))
        gv = bb.emit(relax.op.matmul(lv1, relax.op.permute_dims(fc2_weight)) + fc2_bias)
        # Variables emitted inside a dataflow block are local to that block.
        # emit_output returns the output variable that is visible outside of it,
        # so rebind `gv` to that variable before using it as the function result.
        gv = bb.emit_output(gv)
    bb.emit_func_output(gv)

# Retrieve the built IRModule and print it as TVMScript.
mod = bb.get()
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("v", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("v", 10), dtype="float32"):
        v = T.int64()
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            lv1: R.Tensor((v, 128), dtype="float32") = R.matmul(x, lv, out_dtype="void")
            lv2: R.Tensor((v, 128), dtype="float32") = R.add(lv1, fc1_bias)
            lv3: R.Tensor((v, 128), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            lv5: R.Tensor((v, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
            lv6: R.Tensor((v, 10), dtype="float32") = R.add(lv5, fc2_bias)
            gv: R.Tensor((v, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

The Block Builder API also supports building a cross-level IRModule that contains Relax functions, TensorIR functions, and calls to other TVM packed functions.

# Build a cross-level module with the Block Builder API. The parameters
# (x, fc1_weight, fc1_bias, fc2_weight, fc2_bias) and the symbolic dimension `n`
# are reused from the previous Block Builder example.
bb = relax.BlockBuilder()
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        # Call an external packed function by name; the output is described by out_sinfo.
        lv0 = bb.emit(
            relax.call_dps_packed(
                "env.linear",
                [x, fc1_weight, fc1_bias],
                out_sinfo=relax.TensorStructInfo((n, 128), "float32"),
            )
        )
        # Emit a call to a Tensor Expression (TE) operator from TOPI.
        lv1 = bb.emit_te(topi.nn.relu, lv0)
        # Register the TensorIR function in the module and call it via call_tir.
        tir_gv = bb.add_func(tir_linear, "tir_linear")
        gv = bb.emit(
            relax.call_tir(
                tir_gv,
                [lv1, fc2_weight, fc2_bias],
                out_sinfo=relax.TensorStructInfo((n, 10), "float32"),
            )
        )
        # Variables emitted inside a dataflow block are local to that block.
        # emit_output returns the output variable that is visible outside of it,
        # so rebind `gv` to that variable before using it as the function result.
        gv = bb.emit_output(gv)
    bb.emit_func_output(gv)
mod = bb.get()
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relu(var_lv: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        v = T.int64()
        lv = T.match_buffer(var_lv, (v, T.int64(128)))
        compute = T.match_buffer(var_compute, (v, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(v, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv[v_i0, v_i1], T.float32(0.0))

    @T.prim_func
    def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        M, K = T.int64(), T.int64()
        X = T.match_buffer(x, (M, K))
        N = T.int64()
        W = T.match_buffer(w, (N, K))
        B = T.match_buffer(b, (N,))
        Z = T.match_buffer(z, (M, N))
        # with T.block("root"):
        for i, j, k in T.grid(M, N, K):
            with T.block("linear"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(X[vi, vk], W[vj, vk])
                T.writes(Z[vi, vj])
                with T.init():
                    Z[vi, vj] = T.float32(0.0)
                Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
        for i, j in T.grid(M, N):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Z[vi, vj], B[vj])
                T.writes(Z[vi, vj])
                Z[vi, vj] = Z[vi, vj] + B[vj]

    @R.function
    def forward(x: R.Tensor(("v", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("v", 10), dtype="float32"):
        v = T.int64()
        cls = Module
        with R.dataflow():
            lv = R.call_dps_packed("env.linear", (x, fc1_weight, fc1_bias), out_sinfo=R.Tensor((v, 128), dtype="float32"))
            lv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((v, 128), dtype="float32"))
            lv2 = R.call_tir(cls.tir_linear, (lv1, fc2_weight, fc2_bias), out_sinfo=R.Tensor((v, 10), dtype="float32"))
            gv: R.Tensor((v, 10), dtype="float32") = lv2
            R.output(gv)
        return gv

Note that the Block Builder API is not as user-friendly as the above APIs, but it is the lowest-level API and works closely with the IR definition. We recommend the above APIs for users who only want to define and transform an ML model. For those who want to build more complex transformations, the Block Builder API is a more flexible choice.

Summary

This tutorial demonstrates how to create Relax programs using TVMScript, the NNModule API, and the Block Builder API, including how to call TensorIR functions and packed functions (PackedFunc) from each of them, for different use cases.

Gallery generated by Sphinx-Gallery