Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
IRModule
This tutorial presents the core abstraction of Apache TVM Unity, the IRModule. The IRModule encompasses the entirety of the ML models, incorporating the computational graph, tensor programs, and potential calls to external libraries.
import numpy as np
import tvm
from tvm import relax
Create IRModule
IRModules can be initialized in various ways. We demonstrate a few of them below.
import torch
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program
Import from existing models
The most common way to initialize an IRModule is to import from an existing model. Apache TVM Unity accommodates imports from a range of frameworks, such as PyTorch and ONNX. This tutorial solely demonstrates the import process from PyTorch.
# Create a dummy model
class TorchModel(nn.Module):
    """Dummy two-layer perceptron used as the PyTorch model to import.

    The network is ``Linear(784 -> 256)``, ``ReLU``, ``Linear(256 -> 10)``.
    The attribute names (``fc1``, ``relu1``, ``fc2``) determine the parameter
    names of the module, so they are kept stable.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Apply ``fc1``, ``relu1`` and ``fc2`` in sequence and return the result."""
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x
# Give an example argument to torch.export
example_args = (torch.randn(1, 784, dtype=torch.float32),)

# Convert the model to IRModule
# The export and the import both run with gradient tracking disabled.
with torch.no_grad():
    exported_program = export(TorchModel().eval(), example_args)
    # With these options the printed `main` takes the weights as explicit
    # parameters and returns a single tensor (see the listing shown below).
    mod_from_torch = from_exported_program(
        exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
    )

# Separate the parameter values from the module; they are passed back in
# explicitly (as params_from_torch["main"]) when the compiled module is run.
mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
mod_from_torch.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=None)
lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, p_fc1_bias)
lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=None)
lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, p_fc2_bias)
gv: R.Tensor((1, 10), dtype="float32") = lv6
R.output(gv)
return gv
Write with Relax NN Module
Apache TVM Unity also provides a set of PyTorch-like APIs to help users write the IRModule directly.
from tvm.relax.frontend import nn
class RelaxModel(nn.Module):
    """Two-layer perceptron written with the Relax ``nn`` frontend.

    The network is ``Linear(784 -> 256)``, ``ReLU``, ``Linear(256 -> 10)``.
    Here ``nn`` is ``tvm.relax.frontend.nn`` (imported just above this
    class), not ``torch.nn``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Apply ``fc1``, ``relu1`` and ``fc2`` in sequence and return the result."""
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x
# Export spec: the "forward" method takes one argument "x", a (1, 784)
# float32 tensor.
forward_spec = {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}

# export_tvm returns the IRModule together with the model parameters.
mod_from_relax, params_from_relax = RelaxModel().export_tvm(forward_spec)
mod_from_relax.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
gv: R.Tensor((1, 10), dtype="float32") = add1
R.output(gv)
return gv
Create via TVMScript
TVMScript is a Python-based DSL for IRModules. We are able to directly output the IRModule in the TVMScript syntax, or alternatively, parse the TVMScript to obtain an IRModule.
from tvm.script import ir as I
from tvm.script import relax as R
# The same two-layer network, written by hand in TVMScript. The decorators
# parse the Python syntax below into an IRModule with one Relax function.
@I.ir_module
class TVMScriptModule:
    @R.function
    def main(
        x: R.Tensor((1, 784), dtype="float32"),
        fc1_weight: R.Tensor((256, 784), dtype="float32"),
        fc1_bias: R.Tensor((256,), dtype="float32"),
        fc2_weight: R.Tensor((10, 256), dtype="float32"),
        fc2_bias: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            # First linear layer: x @ fc1_weight^T + fc1_bias, then ReLU.
            permute_dims = R.permute_dims(fc1_weight, axes=None)
            matmul = R.matmul(x, permute_dims, out_dtype="void")
            add = R.add(matmul, fc1_bias)
            relu = R.nn.relu(add)
            # Second linear layer: relu @ fc2_weight^T + fc2_bias.
            permute_dims1 = R.permute_dims(fc2_weight, axes=None)
            matmul1 = R.matmul(relu, permute_dims1, out_dtype="void")
            add1 = R.add(matmul1, fc2_bias)
            gv = add1
            R.output(gv)
        return gv


# The decorated class is used directly as the IRModule.
mod_from_script = TVMScriptModule
mod_from_script.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
gv: R.Tensor((1, 10), dtype="float32") = add1
R.output(gv)
return gv
Attributes of an IRModule
An IRModule is a collection of functions, indexed by GlobalVars.
# Inspect the module imported from PyTorch: list the GlobalVars that index
# its functions.
mod = mod_from_torch
global_vars = mod.get_global_vars()
print(global_vars)
[I.GlobalVar("main")]
We can access the functions in the IRModule by indexing with the GlobalVars or their names, e.g. mod["main"]:
# from tvm.script import relax as R
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=None)
lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, p_fc1_bias)
lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=None)
lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, p_fc2_bias)
gv: R.Tensor((1, 10), dtype="float32") = lv6
R.output(gv)
return gv
Transformations on IRModules
Transformations are an important component of Apache TVM Unity. A transformation takes in an IRModule and outputs another IRModule. We can apply a sequence of transformations to an IRModule to obtain a new IRModule. This is the common way to optimize a model.
In this getting started tutorial, we only demonstrate how to apply transformations to an IRModule. For details of each transformation, please refer to the Transformation API Reference
We first apply the LegalizeOps transformation to the IRModule. This transformation converts the Relax module into a mixed stage, with both Relax and TensorIR functions within the same module. Meanwhile, the Relax operators are converted into call_tir.
# Build the LegalizeOps pass, then run it on the module imported from PyTorch.
legalize = relax.transform.LegalizeOps()
mod = legalize(mod_from_torch)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def add(lv1: T.Buffer((T.int64(1), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(lv1[v_ax0, v_ax1], p_fc1_bias[v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = lv1[v_ax0, v_ax1] + p_fc1_bias[v_ax1]
@T.prim_func(private=True)
def add1(lv5: T.Buffer((T.int64(1), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(lv5[v_ax0, v_ax1], p_fc2_bias[v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = lv5[v_ax0, v_ax1] + p_fc2_bias[v_ax1]
@T.prim_func(private=True)
def matmul(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(x[v_i0, v_k], lv[v_k, v_i1])
T.writes(matmul[v_i0, v_i1])
with T.init():
matmul[v_i0, v_i1] = T.float32(0.0)
matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]
@T.prim_func(private=True)
def matmul1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
T.writes(matmul[v_i0, v_i1])
with T.init():
matmul[v_i0, v_i1] = T.float32(0.0)
matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
@T.prim_func(private=True)
def relu(lv2: T.Buffer((T.int64(1), T.int64(256)), "float32"), compute: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1 in T.grid(T.int64(1), T.int64(256)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(lv2[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.max(lv2[v_i0, v_i1], T.float32(0.0))
@T.prim_func(private=True)
def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(p_fc1_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]
@T.prim_func(private=True)
def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(p_fc2_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
lv1 = R.call_tir(cls.matmul, (x, lv), out_sinfo=R.Tensor((1, 256), dtype="float32"))
lv2 = R.call_tir(cls.add, (lv1, p_fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 256), dtype="float32"))
lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
lv5 = R.call_tir(cls.matmul1, (lv3, lv4), out_sinfo=R.Tensor((1, 10), dtype="float32"))
lv6 = R.call_tir(cls.add1, (lv5, p_fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
gv: R.Tensor((1, 10), dtype="float32") = lv6
R.output(gv)
return gv
After the transformation, there are many more functions inside the module. Let’s print the global vars again.
# List the GlobalVars of the current `mod` again.
print(mod.get_global_vars())
[I.GlobalVar("add"), I.GlobalVar("add1"), I.GlobalVar("main"), I.GlobalVar("matmul"), I.GlobalVar("matmul1"), I.GlobalVar("relu"), I.GlobalVar("transpose"), I.GlobalVar("transpose1")]
Next, Apache TVM Unity provides a set of default transformation pipelines for users, to simplify the transformation process. We can then apply the default pipeline to the module. The default zero pipeline contains very fundamental transformations, including:
LegalizeOps: This transform converts the Relax operators into call_tir functions with the corresponding TensorIR Functions. After this transform, the IRModule will contain both Relax functions and TensorIR functions.
AnnotateTIROpPattern: This transform annotates the pattern of the TensorIR functions, preparing them for subsequent operator fusion.
FoldConstant: This pass performs constant folding, optimizing operations involving constants.
FuseOps and FuseTIR: These two passes work together to fuse operators based on the patterns annotated in the previous step (AnnotateTIROpPattern). These passes transform both Relax functions and TensorIR functions.
Note
Here, we have applied LegalizeOps twice in the flow. The second application is redundant but harmless.
Every pass can be duplicated in the flow, since we ensure the passes can handle all legal IRModule inputs. This design helps users construct their own pipelines.
# Look up the default "zero" pipeline and run it on the current module.
zero_pipeline = relax.get_pipeline("zero")
mod = zero_pipeline(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
T.writes(matmul_intermediate[v_i0, v_i1])
with T.init():
matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul_intermediate[v_ax0, v_ax1], p_fc2_bias[v_ax1])
T.writes(T_add_intermediate[v_ax0, v_ax1])
T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + p_fc2_bias[v_ax1]
@T.prim_func(private=True)
def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(x[v_i0, v_k], lv[v_k, v_i1])
T.writes(matmul_intermediate[v_i0, v_i1])
with T.init():
matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]
for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul_intermediate[v_ax0, v_ax1], p_fc1_bias[v_ax1])
T.writes(T_add_intermediate[v_ax0, v_ax1])
T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + p_fc1_bias[v_ax1]
for i0, i1 in T.grid(T.int64(1), T.int64(256)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(T_add_intermediate[v_i0, v_i1])
T.writes(compute_intermediate[v_i0, v_i1])
compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))
@T.prim_func(private=True)
def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(p_fc1_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]
@T.prim_func(private=True)
def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(p_fc2_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
lv_1 = R.call_tir(cls.fused_matmul_add_relu, (x, lv, p_fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
gv = R.call_tir(cls.fused_matmul1_add1, (lv_1, lv4, p_fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
R.output(gv)
return gv
Deploy the IRModule Universally
After the optimization, we can compile the model into a TVM runtime module. Notably, Apache TVM Unity provides the ability of universal deployment, which means we can deploy the same IRModule on different backends, including CPU, GPU, and other emerging backends.
Deploy on CPU
We can deploy the IRModule on CPU by specifying the target as llvm.
# Compile the optimized IRModule for the "llvm" target and load it into a
# Relax virtual machine on the CPU device. The executable is named `ex`
# rather than `exec` so the Python builtin `exec` is not shadowed.
ex = relax.build(mod, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(ex, dev)

# Random (1, 784) float32 input, matching the signature of `main`.
raw_data = np.random.rand(1, 784).astype("float32")
data = tvm.nd.array(raw_data, dev)
# The detached weights are passed explicitly after the input tensor.
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
print(cpu_out)
[[-0.16684939 -0.11660337 -0.06417739 -0.01259948 -0.02337324 0.17474481
-0.0235335 0.15044701 0.06827082 -0.14142653]]
Deploy on GPU
Besides the CPU backend, we can also deploy the IRModule on GPU. GPU programs require extra information, such as thread bindings and shared memory allocations. We need a further transformation to generate the GPU programs.
We use DLight to generate the GPU programs. In this tutorial, we won’t go into the details of DLight.
from tvm import dlight as dl
# Apply the DLight default schedules with "cuda" as the current target.
# Two rules are supplied: dl.gpu.Matmul() and dl.gpu.Fallback().
with tvm.target.Target("cuda"):
    gpu_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)
Now we can compile the IRModule for GPU, in the same way as we did for CPU.
# Compile the scheduled module for CUDA and load it on GPU 0. The executable
# is named `ex` rather than `exec` so the Python builtin `exec` is not
# shadowed.
ex = relax.build(gpu_mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(raw_data, dev)
gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]]
gpu_out = vm["main"](data, *gpu_params).numpy()
print(gpu_out)

# Check the correctness of the results: the GPU output must agree with the
# CPU output within an absolute tolerance of 1e-3.
assert np.allclose(cpu_out, gpu_out, atol=1e-3)
[[-0.16684945 -0.11660337 -0.06417744 -0.01259948 -0.02337325 0.17474481
-0.02353344 0.15044704 0.06827084 -0.1414266 ]]
Deploy on Other Backends
Apache TVM Unity also supports other backends, such as different kinds of GPUs (Metal, ROCm, Vulkan and OpenCL), different kinds of CPUs (x86, ARM), and other emerging backends (e.g., WebAssembly). The deployment process is similar to the GPU backend.