Customize Optimization
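One main design goal of Apache TVM is to enable easy customization of the optimization pipeline for both research and development purposes, and to support fast iteration on engineering optimizations. In this tutorial we will review the overall flow, compose library dispatch, auto-tuning, and DLight scheduling on a small model, and deploy the optimized result.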
Review Overall Flow
The overall flow consists of the following steps:
Construct or Import a Model: Construct a neural network model or import a pre-trained model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains all the information needed for compilation, including high-level Relax functions for the computational graph and low-level TensorIR functions for tensor programs.
Perform Composable Optimizations: Perform a series of optimization transformations, such as graph optimizations, tensor program optimizations, and library dispatching.
Build and Universal Deployment: Build the optimized model into a module deployable on the universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators.
import os
import tempfile
import numpy as np
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn
Composable IRModule Optimization
Apache TVM Unity provides a flexible way to optimize the IRModule. Because every optimization is a transformation of the IRModule, it can be composed with existing pipelines. Note that each optimization can focus on part of the computation graph, enabling partial lowering or partial optimization.
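To make "composable" and "partial" concrete, here is a minimal pure-Python sketch that is not TVM's API. An IRModule is modeled as a dict from function names to lists of op names, a pass is any function from module to module, and a hypothetical `sequential` helper chains passes in order, in the spirit of `tvm.ir.transform.Sequential`. The hypothetical `lower_only` pass rewrites only the functions it is told to, which mirrors partial lowering.

```python
from typing import Callable, Dict, List

# Toy stand-in for an IRModule: function name -> list of op names.
ToyModule = Dict[str, List[str]]
ToyPass = Callable[[ToyModule], ToyModule]


def sequential(passes: List[ToyPass]) -> ToyPass:
    """Chain passes left to right, like a Sequential pass container."""
    def run(mod: ToyModule) -> ToyModule:
        for p in passes:
            mod = p(mod)
        return mod
    return run


def lower_only(func_names: List[str]) -> ToyPass:
    """Hypothetical pass: mark ops as lowered, but only in the named functions."""
    def run(mod: ToyModule) -> ToyModule:
        return {
            name: ([f"tir.{op}" for op in ops] if name in func_names else list(ops))
            for name, ops in mod.items()
        }
    return run


def drop_identity(mod: ToyModule) -> ToyModule:
    """Hypothetical cleanup pass: remove no-op 'identity' ops in every function."""
    return {name: [op for op in ops if op != "identity"] for name, ops in mod.items()}


mod = {"forward": ["matmul", "identity", "relu"], "helper": ["add"]}
pipeline = sequential([drop_identity, lower_only(["forward"])])
print(pipeline(mod))
# {'forward': ['tir.matmul', 'tir.relu'], 'helper': ['add']}
```

Because each pass returns a new module, the input is left untouched and any subset of passes can be reordered or reused in another pipeline.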
In this tutorial, we will demonstrate how to optimize a model with Apache TVM Unity.
Prepare a Relax Module
We first prepare a Relax module. The module can be imported from other frameworks, constructed with the NN module frontend, or written in TVMScript. Here we use a simple neural network model as an example.
class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


input_shape = (1, 784)
mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}})
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv
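The printed `forward` takes `x` followed by the parameters `fc1_weight`, `fc1_bias`, and `fc2_weight`, stores each weight as `(out_features, in_features)`, and therefore transposes it with `permute_dims` before each `matmul`. A NumPy reference that mirrors each binding is handy for sanity-checking compiled results. This is a sketch that assumes the parameter order and layout shown in the printed module; `reference_forward` and `PARAM_SHAPES` are hypothetical helpers, not part of TVM.

```python
import numpy as np

# Parameter shapes as they appear in the printed forward signature.
PARAM_SHAPES = {
    "fc1_weight": (256, 784),
    "fc1_bias": (256,),
    "fc2_weight": (10, 256),  # fc2 was built with bias=False
}


def reference_forward(x, fc1_weight, fc1_bias, fc2_weight):
    """NumPy mirror of the Relax forward function, binding by binding."""
    permute_dims = fc1_weight.T       # (784, 256)
    matmul = x @ permute_dims         # (1, 256)
    add = matmul + fc1_bias           # bias broadcasts over the batch axis
    relu = np.maximum(add, 0.0)
    permute_dims1 = fc2_weight.T      # (256, 10)
    return relu @ permute_dims1       # (1, 10)


rng = np.random.default_rng(0)
x = rng.random((1, 784), dtype=np.float32)
ref_params = [rng.random(shape, dtype=np.float32) for shape in PARAM_SHAPES.values()]
out = reference_forward(x, *ref_params)
print(out.shape)                                             # (1, 10)
print(sum(int(np.prod(s)) for s in PARAM_SHAPES.values()))   # 203520
```

The parameter count follows directly from the shapes: 256 × 784 + 256 + 10 × 256 = 203,520.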
Library Dispatch
Suppose we want to quickly try a library-based optimization on a particular platform (e.g., GPU). We can write a dispatching pass targeted at that platform and operator. Here we demonstrate how to dispatch certain patterns to the CUBLAS library.
This tutorial only demonstrates dispatching a single pattern to CUBLAS, highlighting the flexibility of the optimization pipeline. In real-world cases, we can import multiple patterns and dispatch them to different kernels.
# Import the CUBLAS backend module so that its patterns are registered
import tvm.relax.backend.cuda.cublas as _cublas


# Define a new pass for CUBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Check if CUBLAS is enabled
        if not tvm.get_global_func("relax.ext.cublas", True):
            raise Exception("CUBLAS is not enabled.")

        # Get the patterns of interest
        patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
        # Note: in real-world cases, we usually get all patterns
        # patterns = relax.backend.get_patterns_with_prefix("cublas")

        # Fuse ops by patterns and then run codegen
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)
        return mod


mod = CublasDispatch()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(lv, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
After the dispatching pass, we can see that the first nn.Linear and nn.ReLU are fused and rewritten into a call_dps_packed call to a function that invokes the CUBLAS library. Notably, the rest of the module is unchanged, which means we can selectively dispatch optimizations to specific parts of the computation.
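The effect of the dispatch pass can be imitated with a toy pattern matcher. This pure-Python sketch models the outcome of `FuseOpsByPattern` plus `RunCodegen`, not their implementation. It scans a straight-line list of `(output, op, inputs)` bindings for the op sequence permute_dims, matmul, add, relu, chained through their outputs, and replaces each match with one call to an external kernel. The name `cublas_fused_linear_relu` is hypothetical and stands in for the generated function. The real pass works on a dataflow graph rather than a flat list.

```python
from typing import List, Tuple

Binding = Tuple[str, str, Tuple[str, ...]]  # (output var, op name, input vars)

PATTERN = ["permute_dims", "matmul", "add", "relu"]


def chained(window: List[Binding]) -> bool:
    """Each binding in the window must consume the previous binding's output."""
    return all(prev[0] in cur[2] for prev, cur in zip(window, window[1:]))


def dispatch(bindings: List[Binding], pattern=PATTERN, target="cublas_fused_linear_relu") -> List[Binding]:
    out: List[Binding] = []
    i = 0
    while i < len(bindings):
        window = bindings[i : i + len(pattern)]
        if [b[1] for b in window] == pattern and chained(window):
            # External inputs are those not produced inside the window, kept in first-use order.
            produced = {b[0] for b in window}
            inputs: List[str] = []
            for b in window:
                for v in b[2]:
                    if v not in produced and v not in inputs:
                        inputs.append(v)
            out.append((window[-1][0], target, tuple(inputs)))
            i += len(pattern)
        else:
            out.append(bindings[i])  # unmatched bindings are left untouched
            i += 1
    return out


bindings = [
    ("permute_dims", "permute_dims", ("fc1_weight",)),
    ("matmul", "matmul", ("x", "permute_dims")),
    ("add", "add", ("matmul", "fc1_bias")),
    ("relu", "relu", ("add",)),
    ("permute_dims1", "permute_dims", ("fc2_weight",)),
    ("matmul1", "matmul", ("relu", "permute_dims1")),
]
for b in dispatch(bindings):
    print(b)
# ('relu', 'cublas_fused_linear_relu', ('fc1_weight', 'x', 'fc1_bias'))
# ('permute_dims1', 'permute_dims', ('fc2_weight',))
# ('matmul1', 'matmul', ('relu', 'permute_dims1'))
```

The collected inputs come out in the same order as the arguments of the `call_dps_packed` in the printed module. The second Linear is not followed by a ReLU, so it does not match and stays as it was.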
Auto Tuning
Continuing from the previous example, we can further optimize the rest of the computation with auto-tuning. Here we demonstrate how to use MetaSchedule to auto-tune the model.
We can use the MetaScheduleTuneTIR pass to tune the model and the MetaScheduleApplyDatabase pass to apply the best configuration found. The tuning step generates a search space and measures candidates, recording the results in the tuning database under work_dir; the apply step then rewrites the model with the best recorded configuration. Before running these passes, we need to lower Relax operators into TensorIR functions via LegalizeOps.
To save CI time and avoid flakiness, we skip the tuning process in the CI environment.
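device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
if os.getenv("CI", "") != "true":
    trials = 2000
    with target, tempfile.TemporaryDirectory() as tmp_dir:
        mod = tvm.ir.transform.Sequential(
            [
                # Lower Relax ops to TensorIR (the "zero" pipeline runs LegalizeOps among other passes)
                relax.get_pipeline("zero"),
                relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
            ]
        )(mod)
    mod.show()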
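The split between tuning and applying can be pictured with a toy search loop. This sketch is an illustration, not MetaSchedule's search algorithm. Candidate "schedules" are tile sizes for a 256-long reduction, a made-up cost function stands in for on-device measurement, every trial is recorded in a dict playing the role of the tuning database, and a separate apply step only reads the best record back. All names, including the workload key `toy_matmul` and the `max_trials` budget (loosely analogous to `max_trials_global`), are hypothetical.

```python
import random
from typing import Dict, List, Tuple

# Toy "database": workload name -> list of (measured cost, tile size) records.
Database = Dict[str, List[Tuple[float, int]]]


def measure(tile: int, extent: int = 256) -> float:
    """Made-up cost standing in for on-device timing of a reduction split by `tile`."""
    n_steps = -(-extent // tile)              # ceil(extent / tile) serial steps
    waste = n_steps * tile - extent           # padded iterations when tile does not divide extent
    return n_steps + tile / 16 + 0.5 * waste  # serial work + made-up cross-thread overhead + padding


def tune(workload: str, candidates: List[int], max_trials: int, db: Database, seed: int = 0) -> None:
    """Measure up to `max_trials` distinct candidates and record every result."""
    rng = random.Random(seed)
    for tile in rng.sample(candidates, k=min(max_trials, len(candidates))):
        db.setdefault(workload, []).append((measure(tile), tile))


def apply_database(workload: str, db: Database) -> int:
    """Return the best recorded configuration without measuring anything again."""
    return min(db[workload])[1]


db: Database = {}
candidates = [1, 2, 4, 8, 16, 32, 64, 100, 128, 256]
tune("toy_matmul", candidates, max_trials=len(candidates), db=db)
print(apply_database("toy_matmul", db))  # 64
```

Keeping measurement and application separate means the expensive step can be run once and its database reused, which is why the tutorial can skip tuning in CI.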
DLight Rules
DLight rules are a set of default rules for scheduling and optimizing kernels. They are designed for fast compilation and fair performance. In some cases, e.g. language models, DLight provides excellent performance, while for generic models it strikes a balance between performance and compilation time.
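One way to picture "default rules" is an ordered list of schedule rules in which each rule first checks whether it applies to a kernel, the first match wins, and a generic fallback comes last. The sketch below is a pure-Python illustration under that assumption. The rule names echo `dl.gpu.Matmul`, `dl.gpu.GEMV`, `dl.gpu.Reduction`, and `dl.gpu.Fallback`, but the `ToyKernel` descriptions and applicability checks are invented. They say nothing about which rule DLight actually picks for this module.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class ToyKernel:
    name: str
    kind: str                  # "reduction" or "injective" (toy classification)
    out_shape: Tuple[int, ...]


# (rule name, applicability check). Checks are invented for illustration only.
RULES: List[Tuple[str, Callable[[ToyKernel], bool]]] = [
    ("Matmul", lambda k: k.kind == "reduction" and min(k.out_shape) >= 16),
    ("GEMV", lambda k: k.kind == "reduction" and min(k.out_shape) == 1),
    ("Reduction", lambda k: k.kind == "reduction"),
    ("Fallback", lambda k: True),  # always applies, so every kernel gets a schedule
]


def apply_default_schedule(kernels: List[ToyKernel]) -> Dict[str, str]:
    """Pick, for each kernel, the first rule whose check succeeds."""
    return {k.name: next(name for name, applies in RULES if applies(k)) for k in kernels}


kernels = [
    ToyKernel("gemm", "reduction", (64, 64)),
    ToyKernel("gemv", "reduction", (1, 10)),
    ToyKernel("row_sum", "reduction", (8, 4)),
    ToyKernel("transpose", "injective", (256, 10)),
]
print(apply_default_schedule(kernels))
# {'gemm': 'Matmul', 'gemv': 'GEMV', 'row_sum': 'Reduction', 'transpose': 'Fallback'}
```

Ordering the specialized rules before the generic ones lets common kernel shapes get a tailored schedule without a search, which is where the fast compilation comes from.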
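from tvm import dlight as dl
# Apply DLight rules
with target:
    mod = tvm.ir.transform.Sequential([
        relax.get_pipeline("zero"),
        dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
            dl.gpu.Matmul(), dl.gpu.GEMV(), dl.gpu.Reduction(), dl.gpu.GeneralReduction(), dl.gpu.Fallback()
        ),
    ])(mod)
mod.show()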
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @T.prim_func(private=True)
    def matmul(lv: T.Buffer((T.int64(1), T.int64(256)), "float32"), permute_dims1: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(10)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                        T.writes(matmul_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                    for ax1_fused_0, u in T.grid(T.int64(16), 1):
                        with T.block("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(16), ax1_fused_0)
                            T.reads(matmul_rf_local[vax1_fused_1, T.int64(0), v0], lv[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_rf_local[vax1_fused_1, T.int64(0), v0] + lv[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax1_fused])
                        T.reads(matmul_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul[T.int64(0), v0])
                        with T.init():
                            matmul[T.int64(0), v0] = T.float32(0.0)
                        matmul[T.int64(0), v0] = matmul[T.int64(0), v0] + matmul_rf_local[vax1_fused_1, T.int64(0), v0]

    @T.prim_func(private=True)
    def transpose(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(3), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_transpose"):
                    v0 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(10))
                    v1 = T.axis.spatial(T.int64(10), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(10))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(2560))
                    T.reads(fc2_weight[v1, v0])
                    T.writes(T_transpose[v0, v1])
                    T_transpose[v0, v1] = fc2_weight[v1, v0]

    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.matmul, (lv, permute_dims1), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
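The scheduled `matmul` above is a two-stage ("rfactor") reduction. The 256-long reduction axis is split into 16 × 16, each of the 16 `threadIdx.y` lanes accumulates a partial sum into `matmul_rf_local`, and the final `matmul` block reduces the 16 partials. The `transpose` kernel flattens its 2560 outputs over 3 × 1024 threads and guards the unused tail with `T.where`. The NumPy sketch below replays both index computations on the CPU to show they reproduce the plain results. It illustrates the printed loop structure and is not generated code.

```python
import numpy as np

rng = np.random.default_rng(0)
lv = rng.random((1, 256), dtype=np.float32)
fc2_weight = rng.random((10, 256), dtype=np.float32)

# transpose: 3 blocks x 1024 threads cover the 2560 elements of the (256, 10) output.
T_transpose = np.zeros((256, 10), dtype=np.float32)
for block in range(3):                      # blockIdx.x
    for thread in range(1024):              # threadIdx.x
        i = block * 1024 + thread
        if i < 2560:                        # T.where guard: the last block is only partly used
            v0, v1 = i // 10, i % 10
            T_transpose[v0, v1] = fc2_weight[v1, v0]
permute_dims1 = T_transpose

# matmul stage 1 (matmul_rf_update): lane y sums elements k = k0 * 16 + y over k0 in [0, 16).
rf = np.zeros((16, 1, 10), dtype=np.float32)
for y in range(16):                         # threadIdx.y
    for v0 in range(10):                    # threadIdx.x
        for k0 in range(16):                # serial reduction loop
            k = k0 * 16 + y
            rf[y, 0, v0] += lv[0, k] * permute_dims1[k, v0]

# matmul stage 2: reduce the 16 partial sums (a cross-thread reduction on the GPU).
matmul = rf.sum(axis=0)                     # shape (1, 10)

print(np.array_equal(permute_dims1, fc2_weight.T))           # True
print(np.allclose(matmul, lv @ fc2_weight.T, rtol=1e-4))     # True
```

Splitting the reduction this way lets 16 threads work on one output element in parallel, which matters here because the output has only 10 elements.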
This tutorial focuses on demonstrating the optimization pipeline rather than pushing performance to the limit, so the resulting optimization is likely not the best possible.
Deploy the Optimized Model
We can build and deploy the optimized model to the TVM runtime.
ex = tvm.compile(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Allocate the input data and parameters on the GPU device
data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev)
gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]
gpu_out = vm["forward"](data, *gpu_params).numpy()
print(gpu_out)
[[25178.377 24039.607 26198.518 24629.71 25687.209 24261.074 25842.506
26051.04 25509.438 25442.607]]
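The printed values are large because the input and every parameter are drawn from `np.random.rand`, i.e. uniformly from [0, 1). All hidden activations are therefore non-negative, so the ReLU never changes them. By independence, each output has expected value 256 × (784 × 1/4 + 1/2) × 1/2 = 25152, which is consistent with the numbers above. The sketch below checks that arithmetic with a small Monte Carlo estimate in NumPy. It assumes the same uniform initialization as the deployment code. Actual outputs differ from run to run because no seed is set there.

```python
import numpy as np

# Analytic expectation with all entries i.i.d. uniform on [0, 1) (mean 1/2).
analytic = 256 * (784 * 0.5 * 0.5 + 0.5) * 0.5
print(analytic)  # 25152.0

# Monte Carlo estimate: average the model outputs over many random draws.
rng = np.random.default_rng(0)
estimates = []
for _ in range(50):
    x = rng.random((1, 784))
    fc1_weight, fc1_bias, fc2_weight = rng.random((256, 784)), rng.random(256), rng.random((10, 256))
    y = np.maximum(x @ fc1_weight.T + fc1_bias, 0.0) @ fc2_weight.T
    estimates.append(y.mean())
print(abs(np.mean(estimates) - analytic) / analytic < 0.02)  # True
```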
This tutorial demonstrates how to customize the optimization pipeline for ML models in Apache TVM. We can easily compose optimization passes and customize the optimization for different parts of the computation graph. The flexibility of the optimization pipeline enables us to quickly iterate on optimizations and improve model performance.