IRModule

This tutorial presents the core abstraction of Apache TVM Unity, the IRModule. The IRModule encompasses the entirety of the ML models, incorporating the computational graph, tensor programs, and potential calls to external libraries.

import numpy as np
import tvm
from tvm import relax

Create IRModule

IRModules can be initialized in various ways. We demonstrate a few of them below.

import torch
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

Import from existing models

The most common way to initialize an IRModule is to import from an existing model. Apache TVM Unity accommodates imports from a range of frameworks, such as PyTorch and ONNX. This tutorial solely demonstrates the import process from PyTorch.

# Create a dummy model
class TorchModel(nn.Module):
    """Small two-layer perceptron used as the PyTorch import example.

    Maps a 784-feature input to 10 outputs through a 256-unit hidden
    layer followed by a ReLU activation.
    """

    def __init__(self):
        super().__init__()
        # The attribute names are part of the model's parameter names
        # (e.g. ``fc1.weight``), so they are kept as-is.
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Apply ``fc1``, ``relu1`` and ``fc2`` in sequence and return the result."""
        hidden = self.relu1(self.fc1(x))
        return self.fc2(hidden)


# Give an example argument to torch.export: a tuple holding a single
# float32 tensor of shape (1, 784), matching the input size of fc1.
example_args = (torch.randn(1, 784, dtype=torch.float32),)

# Convert the model to IRModule. The model is switched to eval mode and the
# export runs under torch.no_grad().
# keep_params_as_input=True: the weights show up as extra arguments of `main`
# in the printed module below.
# NOTE(review): unwrap_unit_return_tuple=True presumably makes `main` return
# the tensor itself instead of a 1-tuple — confirm against the frontend docs.
with torch.no_grad():
    exported_program = export(TorchModel().eval(), example_args)
    mod_from_torch = from_exported_program(
        exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
    )

# Split the result into the module and its parameters; the parameters are
# looked up later as params_from_torch["main"] when the compiled function is called.
mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
mod_from_torch.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=None)
            lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
            lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, p_fc1_bias)
            lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=None)
            lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
            lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, p_fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

Write with Relax NN Module

Apache TVM Unity also provides a set of PyTorch-like APIs to help users write the IRModule directly.

from tvm.relax.frontend import nn


class RelaxModel(nn.Module):
    """Two-layer perceptron written with the Relax NN frontend.

    Uses the same layer sizes as the PyTorch example: 784 -> 256 -> 10,
    with a ReLU between the two linear layers.
    """

    def __init__(self):
        super().__init__()
        # The attribute names appear in the exported parameter names
        # (e.g. ``fc1_weight``), so they are kept as-is.
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Apply ``fc1``, ``relu1`` and ``fc2`` in sequence and return the result."""
        hidden = self.relu1(self.fc1(x))
        return self.fc2(hidden)


# Export the "forward" method. The spec declares its single input `x` as a
# (1, 784) float32 tensor. export_tvm returns two values, bound here as the
# IRModule and the model parameters.
mod_from_relax, params_from_relax = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
# Print the exported IRModule
mod_from_relax.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv

Create via TVMScript

TVMScript is a Python-based DSL for IRModules. We are able to directly output the IRModule in the TVMScript syntax, or alternatively, parse the TVMScript to obtain an IRModule.

from tvm.script import ir as I
from tvm.script import relax as R


# The same two-layer perceptron written by hand in TVMScript.
# The decorators parse the class body into an IRModule, so local variable
# names below become the variable names in the printed IR.
@I.ir_module
class TVMScriptModule:
    @R.function
    def main(
        x: R.Tensor((1, 784), dtype="float32"),
        fc1_weight: R.Tensor((256, 784), dtype="float32"),
        fc1_bias: R.Tensor((256,), dtype="float32"),
        fc2_weight: R.Tensor((10, 256), dtype="float32"),
        fc2_bias: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor((1, 10), dtype="float32"):
        # NOTE(review): "num_input": 1 presumably marks only the first
        # argument (x) as a runtime input, the rest being weights — confirm.
        R.func_attr({"num_input": 1})
        with R.dataflow():
            # First linear layer: transpose the (256, 784) weight, multiply, add bias.
            permute_dims = R.permute_dims(fc1_weight, axes=None)
            # NOTE(review): out_dtype="void" presumably leaves the output dtype
            # to be inferred; the printed module annotates the result as float32.
            matmul = R.matmul(x, permute_dims, out_dtype="void")
            add = R.add(matmul, fc1_bias)
            relu = R.nn.relu(add)
            # Second linear layer, same pattern with the (10, 256) weight.
            permute_dims1 = R.permute_dims(fc2_weight, axes=None)
            matmul1 = R.matmul(relu, permute_dims1, out_dtype="void")
            add1 = R.add(matmul1, fc2_bias)
            gv = add1
            # Declare gv as the output of the dataflow block.
            R.output(gv)
        return gv


# The decorated class is used directly as the module object.
mod_from_script = TVMScriptModule
mod_from_script.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv

Attributes of an IRModule

An IRModule is a collection of functions, indexed by GlobalVars.

# Use the module imported from PyTorch and list the GlobalVars it contains.
mod = mod_from_torch
print(mod.get_global_vars())
[I.GlobalVar("main")]

We can access the functions in the IRModule by indexing with the GlobalVars or their names.

# Index by global var name
print(mod["main"])
# Index by the GlobalVar itself, and check that both lookups give the same function.
# The single-element tuple unpacking only succeeds when the module holds
# exactly one global var.
(gv,) = mod.get_global_vars()
assert mod[gv] == mod["main"]
# from tvm.script import relax as R

@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
    R.func_attr({"num_input": 1})
    with R.dataflow():
        lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=None)
        lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
        lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, p_fc1_bias)
        lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
        lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=None)
        lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
        lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, p_fc2_bias)
        gv: R.Tensor((1, 10), dtype="float32") = lv6
        R.output(gv)
    return gv

Transformations on IRModules

Transformations are the important component of Apache TVM Unity. One transformation takes in an IRModule and outputs another IRModule. We can apply a sequence of transformations to an IRModule to obtain a new IRModule. That is the common way to optimize a model.

In this getting started tutorial, we only demonstrate how to apply transformations to an IRModule. For details of each transformation, please refer to the Transformation API Reference.

We first apply the LegalizeOps transformation to the IRModule. This transformation converts the Relax module into a mixed stage, with both Relax and TensorIR functions within the same module. Meanwhile, the Relax operators are converted into call_tir.

# Start again from the module imported from PyTorch.
mod = mod_from_torch
# LegalizeOps() builds the pass object; calling it on a module returns the
# transformed module, which is rebound to `mod`.
mod = relax.transform.LegalizeOps()(mod)
# Print the transformed IRModule
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(lv1: T.Buffer((T.int64(1), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv1[v_ax0, v_ax1], p_fc1_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = lv1[v_ax0, v_ax1] + p_fc1_bias[v_ax1]

    @T.prim_func(private=True)
    def add1(lv5: T.Buffer((T.int64(1), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv5[v_ax0, v_ax1], p_fc2_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = lv5[v_ax0, v_ax1] + p_fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def matmul(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]

    @T.prim_func(private=True)
    def matmul1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]

    @T.prim_func(private=True)
    def relu(lv2: T.Buffer((T.int64(1), T.int64(256)), "float32"), compute: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv2[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv2[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv1 = R.call_tir(cls.matmul, (x, lv), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv2 = R.call_tir(cls.add, (lv1, p_fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            lv5 = R.call_tir(cls.matmul1, (lv3, lv4), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            lv6 = R.call_tir(cls.add1, (lv5, p_fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

After the transformation, there are many more functions inside the module. Let’s print the global vars again.

# List the GlobalVars of the legalized module.
print(mod.get_global_vars())
[I.GlobalVar("add"), I.GlobalVar("add1"), I.GlobalVar("main"), I.GlobalVar("matmul"), I.GlobalVar("matmul1"), I.GlobalVar("relu"), I.GlobalVar("transpose"), I.GlobalVar("transpose1")]

Next, Apache TVM Unity provides a set of default transformation pipelines for users, to simplify the transformation process. We can then apply the default pipeline to the module. The default zero pipeline contains very fundamental transformations, including:

  • LegalizeOps: This transform converts the Relax operators into call_tir functions with the corresponding TensorIR Functions. After this transform, the IRModule will contain both Relax functions and TensorIR functions.

  • AnnotateTIROpPattern: This transform annotates the pattern of the TensorIR functions, preparing them for subsequent operator fusion.

  • FoldConstant: This pass performs constant folding, optimizing operations involving constants.

  • FuseOps and FuseTIR: These two passes work together to fuse operators based on the patterns annotated in the previous step (AnnotateTIROpPattern). These passes transform both Relax functions and TensorIR functions.

Note

Here, we have applied LegalizeOps twice in the flow. The second time is useless but harmless.

Every pass can be duplicated in the flow, since we ensure the passes can handle all legal IRModule inputs. This design helps users construct their own pipelines.

# Look up the default pipeline named "zero" and apply it to the (already
# legalized) module; the result is rebound to `mod`.
mod = relax.get_pipeline("zero")(mod)
# Print the optimized IRModule
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], p_fc2_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + p_fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], p_fc1_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + p_fc1_bias[v_ax1]
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv_1 = R.call_tir(cls.fused_matmul_add_relu, (x, lv, p_fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv_1, lv4, p_fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

Deploy the IRModule Universally

After the optimization, we can compile the model into a TVM runtime module. Notably, Apache TVM Unity provides the ability of universal deployment, which means we can deploy the same IRModule on different backends, including CPU, GPU, and other emerging backends.

Deploy on CPU

We can deploy the IRModule on CPU by specifying the target as llvm.

# Compile the optimized module with the "llvm" target (CPU).
# The result is named `executable` rather than `exec` so that the Python
# builtin `exec` is not shadowed at module level.
executable = relax.build(mod, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(executable, dev)

# Random float32 input of shape (1, 784), matching the `x` argument of `main`.
raw_data = np.random.rand(1, 784).astype("float32")
data = tvm.nd.array(raw_data, dev)
# The detached parameters of `main` are passed after the input.
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
print(cpu_out)
[[-0.16684939 -0.11660337 -0.06417739 -0.01259948 -0.02337324  0.17474481
  -0.0235335   0.15044701  0.06827082 -0.14142653]]

Deploy on GPU

Besides the CPU backend, we can also deploy the IRModule on GPU. GPU requires programs containing extra information, such as the thread bindings and shared memory allocations. We need a further transformation to generate the GPU programs.

We use DLight to generate the GPU programs. In this tutorial, we won’t go into the details of DLight.

from tvm import dlight as dl

# Apply DLight's default schedules with "cuda" as the current target.
# NOTE(review): the rules are presumably tried in the order given (Matmul
# first, Fallback for everything else) — confirm against the DLight docs.
with tvm.target.Target("cuda"):
    gpu_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)

Now we can compile the IRModule on GPU, in the same way as we did on CPU.

# Compile the scheduled module for CUDA and run it on GPU 0.
# The result is named `executable` rather than `exec` so that the Python
# builtin `exec` is not shadowed at module level.
executable = relax.build(gpu_mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(executable, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(raw_data, dev)
gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]]
gpu_out = vm["main"](data, *gpu_params).numpy()
print(gpu_out)

# Check the correctness of the results.
# assert_allclose is used instead of a bare `assert np.allclose(...)` so the
# check still runs under `python -O` and reports the mismatching values on
# failure. rtol and equal_nan are set to np.allclose's defaults, so the
# tolerance is the same as before.
np.testing.assert_allclose(cpu_out, gpu_out, rtol=1e-5, atol=1e-3, equal_nan=False)
[[-0.16684945 -0.11660337 -0.06417744 -0.01259948 -0.02337325  0.17474481
  -0.02353344  0.15044704  0.06827084 -0.1414266 ]]

Deploy on Other Backends

Apache TVM Unity also supports other backends, such as different kinds of GPUs (Metal, ROCm, Vulkan and OpenCL), different kinds of CPUs (x86, ARM), and other emerging backends (e.g., WebAssembly). The deployment process is similar to the GPU backend.

Gallery generated by Sphinx-Gallery