Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
Relax Creation
This tutorial demonstrates how to create Relax functions and programs. We’ll cover various ways to define Relax functions, including TVMScript, the Relax NNModule API, and the Block Builder API.
Create Relax programs using TVMScript
TVMScript is a domain-specific language for representing Apache TVM’s intermediate representation (IR). It is a Python dialect that can be used to define an IRModule, which contains both TensorIR and Relax functions.
In this section, we will show how to define a simple MLP model with only high-level Relax operators using TVMScript.
from tvm import relax, topi
from tvm.relax.frontend import nn
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
# An IRModule holding one Relax function, `forward`: a two-layer MLP
# (784 -> 128 -> 10) written only with high-level Relax operators.
@I.ir_module
class RelaxModule:
    @R.function
    def forward(
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        # "n" is a symbolic batch dimension shared by the input and the result.
        with R.dataflow():
            # Weights are laid out as (out_features, in_features), so they are
            # transposed with permute_dims before each matmul.
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            lv1 = R.nn.relu(lv0)
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            # R.output marks lv2 as the value visible after the dataflow block.
            R.output(lv2)
        return lv2


# Print the module; the printer emits the normalized form shown below.
RelaxModule.show()
# Output printed by RelaxModule.show(). In this normalized form every binding
# holds exactly one operator call with an explicit struct-info annotation, and
# the symbolic dimension "n" is declared as a TIR variable.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def forward(data: R.Tensor(("n", 784), dtype="float32"), w0: R.Tensor((128, 784), dtype="float32"), b0: R.Tensor((128,), dtype="float32"), w1: R.Tensor((10, 128), dtype="float32"), b1: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        with R.dataflow():
            # The single sugared line `lv0 = R.matmul(...) + b0` became these
            # three bindings.
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
            lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
            lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)
            lv1_1: R.Tensor((n, 128), dtype="float32") = R.nn.relu(lv0)
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(w1, axes=None)
            lv5: R.Tensor((n, 10), dtype="float32") = R.matmul(lv1_1, lv4, out_dtype="void")
            lv2: R.Tensor((n, 10), dtype="float32") = R.add(lv5, b1)
            R.output(lv2)
        return lv2
Relax is not only a graph-level IR; it also supports cross-level representation and transformation. Specifically, a Relax function can directly call TensorIR functions.
# A cross-level IRModule: a TensorIR function (`relu`) and a Relax function
# (`forward`) that calls it through R.call_tir.
@I.ir_module
class RelaxModuleWithTIR:
    # Element-wise ReLU over a 2-D float32 buffer of symbolic shape (n, m).
    # X is the input and Y the output buffer; both are passed in as handles.
    @T.prim_func
    def relu(x: T.handle, y: T.handle):
        n, m = T.int64(), T.int64()
        X = T.match_buffer(x, (n, m), "float32")
        Y = T.match_buffer(y, (n, m), "float32")
        for i, j in T.grid(n, m):
            with T.block("relu"):
                # Both loop axes are spatial ("SS"); there is no reduction.
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @R.function
    def forward(
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        # Bind the symbolic batch size "n" so it can be used in the call_tir
        # output annotation below.
        n = T.int64()
        # `cls` refers to this IRModule, so `cls.relu` names the TIR function.
        cls = RelaxModuleWithTIR
        with R.dataflow():
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            # lv0 is passed as `x`; the output `y` is described by the struct
            # info given as the third argument.
            lv1 = R.call_tir(cls.relu, lv0, R.Tensor((n, 128), dtype="float32"))
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            R.output(lv2)
        return lv2


RelaxModuleWithTIR.show()
# Output printed by RelaxModuleWithTIR.show(). The printer adds T.reads /
# T.writes annotations to the TIR block, omits the float32 buffer dtype, and
# prints the call_tir arguments as a tuple with an explicit out_sinfo.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func
    def relu(x: T.handle, y: T.handle):
        n, m = T.int64(), T.int64()
        X = T.match_buffer(x, (n, m))
        Y = T.match_buffer(y, (n, m))
        # with T.block("root"):
        for i, j in T.grid(n, m):
            with T.block("relu"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(X[vi, vj])
                T.writes(Y[vi, vj])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0.0))

    @R.function
    def forward(data: R.Tensor(("n", 784), dtype="float32"), w0: R.Tensor((128, 784), dtype="float32"), b0: R.Tensor((128,), dtype="float32"), w1: R.Tensor((10, 128), dtype="float32"), b1: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        cls = Module
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
            lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
            lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)
            lv1_1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(w1, axes=None)
            lv5: R.Tensor((n, 10), dtype="float32") = R.matmul(lv1_1, lv4, out_dtype="void")
            lv2: R.Tensor((n, 10), dtype="float32") = R.add(lv5, b1)
            R.output(lv2)
        return lv2
Note
You may notice that the printed output is different from the written TVMScript code. This is because the IRModule is printed in a standard (normalized) format, while the parser also accepts syntax sugar in the input.
For example, we can combine multiple operators into a single line, as
lv0 = R.matmul(data, R.permute_dims(w0)) + b0  # sugared: permute_dims, matmul and add in one binding
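However, the normalized expression allows only one operation per binding, so the printed output differs from the written TVMScript code. The single line above is split into three annotated bindings: `lv = R.permute_dims(w0, axes=None)`, `lv1 = R.matmul(data, lv, out_dtype="void")`, and `lv0 = R.add(lv1, b0)`.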
Create Relax programs using NNModule API
Besides TVMScript, we also provide a PyTorch-like API for defining neural networks. It is designed to be more intuitive and easier to use than TVMScript.
In this section, we will show how to define the same MLP model using Relax NNModule API.
After we define the NNModule, we can export it to a TVM IRModule via `export_tvm`.
class NNModule(nn.Module):
    # The same two-layer MLP (784 -> 128 -> 10) built from NNModule layers.
    # Each nn.Linear owns its weight and bias parameters.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


# Export the module. The spec names the function to export ("forward") and
# the shape/dtype of each input; "n" is a symbolic batch dimension.
mod, params = NNModule().export_tvm({"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}})
mod.show()
# Output printed by mod.show() for the exported NNModule. Parameters are named
# after their module attribute paths (fc1.weight -> fc1_weight). "num_input": 1
# presumably marks only the first parameter (x, the one named in the export
# spec) as a model input, with the rest being weights. TODO: confirm.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv
We can also insert customized function calls into the NNModule, such as Tensor Expression (TE) functions, TensorIR functions, or other TVM packed functions.
# Hand-written TensorIR linear layer: Z = X @ W^T + B, with X of shape (M, K),
# W of shape (N, K), B of shape (N,) and the output Z of shape (M, N).
# Z is an output buffer supplied by the caller.
@T.prim_func
def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
    # Symbolic dimensions, bound by the buffer shapes below.
    M, N, K = T.int64(), T.int64(), T.int64()
    X = T.match_buffer(x, (M, K), "float32")
    W = T.match_buffer(w, (N, K), "float32")
    B = T.match_buffer(b, (N,), "float32")
    Z = T.match_buffer(z, (M, N), "float32")
    for i, j, k in T.grid(M, N, K):
        with T.block("linear"):
            # i and j are spatial axes; k is the reduction axis ("SSR").
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            # Initialize the accumulator before reducing over k.
            with T.init():
                Z[vi, vj] = 0
            Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
    # Second pass: add the bias to every row of the result.
    for i, j in T.grid(M, N):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            Z[vi, vj] = Z[vi, vj] + B[vj]
class NNModuleWithTIR(nn.Module):
    # The same MLP as the plain NNModule, but each step goes through a
    # different custom-call mechanism: an external packed function, a TOPI
    # Tensor Expression, and a hand-written TensorIR function.
    def __init__(self):
        super().__init__()
        # The Linear layers are used only as parameter holders here; their
        # weights and biases are passed explicitly to the custom calls below.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        n = x.shape[0]
        # nn.extern calls an external (packed) function by name. "env.linear"
        # is not defined in this file and must be provided at runtime. The
        # output shape/dtype must be declared with a placeholder.
        x = nn.extern(
            "env.linear",
            [x, self.fc1.weight, self.fc1.bias],
            out=nn.Tensor.placeholder((n, 128), "float32"),
        )
        # nn.tensor_expr_op turns a Tensor Expression (TE) function from TOPI
        # into a TensorIR function named "relu" and calls it.
        x = nn.tensor_expr_op(topi.nn.relu, "relu", [x])
        # nn.tensor_ir_op calls a hand-written TensorIR function directly.
        x = nn.tensor_ir_op(
            tir_linear,
            "tir_linear",
            [x, self.fc2.weight, self.fc2.bias],
            out=nn.Tensor.placeholder((n, 10), "float32"),
        )
        return x
# Export spec: function name -> {input name: input tensor spec}; "n" is symbolic.
export_spec = {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
mod, params = NNModuleWithTIR().export_tvm(export_spec)
mod.show()
# Output printed by mod.show() for NNModuleWithTIR:
# - nn.extern became R.call_dps_packed("env.linear", ...).
# - nn.tensor_expr_op generated the private TIR function `relu` from
#   topi.nn.relu.
# - nn.tensor_ir_op added `tir_linear` to the module and called it through
#   R.call_tir.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relu(var_env_linear: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        env_linear = T.match_buffer(var_env_linear, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(env_linear[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(env_linear[v_i0, v_i1], T.float32(0.0))

    @T.prim_func
    def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        M, K = T.int64(), T.int64()
        X = T.match_buffer(x, (M, K))
        N = T.int64()
        W = T.match_buffer(w, (N, K))
        B = T.match_buffer(b, (N,))
        Z = T.match_buffer(z, (M, N))
        # with T.block("root"):
        for i, j, k in T.grid(M, N, K):
            with T.block("linear"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(X[vi, vk], W[vj, vk])
                T.writes(Z[vi, vj])
                with T.init():
                    Z[vi, vj] = T.float32(0.0)
                Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
        for i, j in T.grid(M, N):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Z[vi, vj], B[vj])
                T.writes(Z[vi, vj])
                Z[vi, vj] = Z[vi, vj] + B[vj]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            env_linear = R.call_dps_packed("env.linear", (x, fc1_weight, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv = R.call_tir(cls.relu, (env_linear,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv1 = R.call_tir(cls.tir_linear, (lv, fc2_weight, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            gv: R.Tensor((n, 10), dtype="float32") = lv1
            R.output(gv)
        return gv
Create Relax programs using Block Builder API
In addition to the above APIs, we also provide a Block Builder API for creating Relax programs. It is an IR builder API that is lower-level and widely used in TVM’s internal logic, e.g., when writing a customized pass.
# Build the same MLP imperatively with the BlockBuilder API.
bb = relax.BlockBuilder()
# Symbolic batch dimension. The variable created by T.int64() does not carry
# the Python name `n`, which is why the printed module below calls it `v`.
n = T.int64()
# Function parameters: relax.Var(name_hint, struct_info). These variables and
# `n` are reused by the cross-level BlockBuilder example that follows.
x = relax.Var("x", R.Tensor((n, 784), "float32"))
fc1_weight = relax.Var("fc1_weight", R.Tensor((128, 784), "float32"))
fc1_bias = relax.Var("fc1_bias", R.Tensor((128,), "float32"))
fc2_weight = relax.Var("fc2_weight", R.Tensor((10, 128), "float32"))
fc2_bias = relax.Var("fc2_bias", R.Tensor((10,), "float32"))
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        # bb.emit normalizes nested expressions, so one emit can produce
        # several single-operation bindings (see lv, lv1, lv2 below).
        lv0 = bb.emit(relax.op.matmul(x, relax.op.permute_dims(fc1_weight)) + fc1_bias)
        lv1 = bb.emit(relax.op.nn.relu(lv0))
        gv = bb.emit(relax.op.matmul(lv1, relax.op.permute_dims(fc2_weight)) + fc2_bias)
        # emit_output binds a new, non-dataflow variable to `gv`. That
        # variable is the one the function must return: a dataflow variable
        # may not be referenced outside its dataflow block.
        gv = bb.emit_output(gv)
    bb.emit_func_output(gv)
mod = bb.get()
mod.show()
# Output printed by mod.show():
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("v", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("v", 10), dtype="float32"):
        v = T.int64()
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            lv1: R.Tensor((v, 128), dtype="float32") = R.matmul(x, lv, out_dtype="void")
            lv2: R.Tensor((v, 128), dtype="float32") = R.add(lv1, fc1_bias)
            lv3: R.Tensor((v, 128), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            lv5: R.Tensor((v, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
            lv6: R.Tensor((v, 10), dtype="float32") = R.add(lv5, fc2_bias)
            gv: R.Tensor((v, 10), dtype="float32") = lv6
            R.output(gv)
        return gv
The Block Builder API also supports building a cross-level IRModule that combines Relax functions, TensorIR functions, and other TVM packed functions.
# Cross-level example: one Relax function that calls an external packed
# function, a TOPI Tensor Expression, and the hand-written TensorIR function
# tir_linear. It reuses the parameters (x, fc1_weight, ...) and the symbolic
# `n` defined in the previous BlockBuilder example.
bb = relax.BlockBuilder()
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        # Call the external function "env.linear" (not defined in this file);
        # its output struct info has to be given explicitly.
        lv0 = bb.emit(
            relax.call_dps_packed(
                "env.linear",
                [x, fc1_weight, fc1_bias],
                out_sinfo=relax.TensorStructInfo((n, 128), "float32"),
            )
        )
        # emit_te turns the TOPI TE function into a TensorIR function in the
        # module and emits a call_tir to it.
        lv1 = bb.emit_te(topi.nn.relu, lv0)
        # Add the hand-written TensorIR function to the module and keep a
        # reference to it for call_tir.
        tir_gv = bb.add_func(tir_linear, "tir_linear")
        gv = bb.emit(
            relax.call_tir(
                tir_gv,
                [lv1, fc2_weight, fc2_bias],
                out_sinfo=relax.TensorStructInfo((n, 10), "float32"),
            )
        )
        # Return the non-dataflow variable created by emit_output; the
        # dataflow-local `gv` may not be referenced outside the dataflow block.
        gv = bb.emit_output(gv)
    bb.emit_func_output(gv)
mod = bb.get()
mod.show()
# Output printed by mod.show():
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relu(var_lv: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        v = T.int64()
        lv = T.match_buffer(var_lv, (v, T.int64(128)))
        compute = T.match_buffer(var_compute, (v, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(v, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv[v_i0, v_i1], T.float32(0.0))

    @T.prim_func
    def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        M, K = T.int64(), T.int64()
        X = T.match_buffer(x, (M, K))
        N = T.int64()
        W = T.match_buffer(w, (N, K))
        B = T.match_buffer(b, (N,))
        Z = T.match_buffer(z, (M, N))
        # with T.block("root"):
        for i, j, k in T.grid(M, N, K):
            with T.block("linear"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(X[vi, vk], W[vj, vk])
                T.writes(Z[vi, vj])
                with T.init():
                    Z[vi, vj] = T.float32(0.0)
                Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
        for i, j in T.grid(M, N):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Z[vi, vj], B[vj])
                T.writes(Z[vi, vj])
                Z[vi, vj] = Z[vi, vj] + B[vj]

    @R.function
    def forward(x: R.Tensor(("v", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("v", 10), dtype="float32"):
        v = T.int64()
        cls = Module
        with R.dataflow():
            lv = R.call_dps_packed("env.linear", (x, fc1_weight, fc1_bias), out_sinfo=R.Tensor((v, 128), dtype="float32"))
            lv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((v, 128), dtype="float32"))
            lv2 = R.call_tir(cls.tir_linear, (lv1, fc2_weight, fc2_bias), out_sinfo=R.Tensor((v, 10), dtype="float32"))
            gv: R.Tensor((v, 10), dtype="float32") = lv2
            R.output(gv)
        return gv
Note that the Block Builder API is not as user-friendly as the above APIs, but it is the lowest-level API and works closely with the IR definition. We recommend the above APIs for users who only want to define and transform an ML model. For those who want to build more complex transformations, the Block Builder API is a more flexible choice.
Summary
This tutorial demonstrates how to create Relax programs for different use cases using TVMScript, the NNModule API, and the Block Builder API. It also shows how to call TensorIR functions and TVM packed functions (PackedFunc) from each of them.