Customize Optimization

One of the main design goals of Apache TVM is to make it easy to customize the optimization pipeline, both for research and development purposes and for iterating on engineering optimizations. In this tutorial, we will review the overall compilation flow and demonstrate how to compose and customize optimization passes.

Review Overall Flow

https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg

The overall flow consists of the following steps:

  • Construct or Import a Model: Construct a neural network model or import a pre-trained model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains all the information needed for compilation, including high-level Relax functions for the computational graph and low-level TensorIR functions for the tensor programs.

  • Perform Composable Optimizations: Perform a series of optimization transformations, such as graph optimizations, tensor program optimizations, and library dispatching.

  • Build and Universal Deployment: Build the optimized model into a deployable module for the universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators. A compact sketch of these three steps is shown below.
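
The following is a minimal sketch of the three steps, assuming a pre-trained ONNX model saved as "model.onnx" (a hypothetical file name) and a CPU target; the rest of this tutorial instead constructs the module with the NN module frontend and targets CUDA.

import onnx
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

# 1. Construct or import a model into a TVM IRModule.
mod = from_onnx(onnx.load("model.onnx"))
# 2. Perform composable optimizations (here: the built-in "zero" pipeline).
mod = relax.get_pipeline("zero")(mod)
# 3. Build into a deployable module and run it on the chosen device.
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())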

import os
import tempfile
import numpy as np
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn

Composable IRModule Optimization

Apache TVM Unity provides a flexible way to optimize the IRModule. Since everything is centered around the IRModule, any optimization can be composed with the existing pipelines. Note that each optimization can focus on part of the computation graph, enabling partial lowering or partial optimization.

In this tutorial, we will demonstrate how to optimize a model with Apache TVM Unity.
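
As a quick illustration of this composability, a user-defined pass can be chained with the built-in pipelines through tvm.ir.transform.Sequential. The snippet below is a minimal sketch; IdentityPass is a hypothetical no-op pass used only to show how the pieces fit together.

import tvm
from tvm import IRModule, relax


@tvm.transform.module_pass(opt_level=0, name="IdentityPass")
class IdentityPass:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # A real pass would rewrite `mod` here; we return it unchanged.
        return mod


# Compose the custom pass with a built-in pipeline.
custom_pipeline = tvm.ir.transform.Sequential(
    [
        relax.get_pipeline("zero"),  # built-in lowering pipeline
        IdentityPass(),              # user-defined pass
    ]
)
# optimized_mod = custom_pipeline(some_irmodule)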

Prepare a Relax Module

We first prepare a Relax module. The module can be imported from other frameworks, constructed with the NN module frontend, or written directly in TVMScript. Here we use a simple neural network model as an example.
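
For instance, a PyTorch model could be imported through the FX-based frontend. The snippet below is a minimal sketch under the assumption that PyTorch and tvm.relax.frontend.torch.from_fx are available; in this tutorial we construct the model with the NN module frontend instead.

import torch
from torch import fx, nn as torch_nn
from tvm.relax.frontend.torch import from_fx

# A PyTorch model equivalent to the Relax model defined below.
torch_model = torch_nn.Sequential(
    torch_nn.Linear(784, 256),
    torch_nn.ReLU(),
    torch_nn.Linear(256, 10, bias=False),
)
# Trace the model with torch.fx and convert it to a Relax IRModule.
# input_info is a list of (shape, dtype) pairs describing the model inputs.
imported_mod = from_fx(fx.symbolic_trace(torch_model), [((1, 784), "float32")])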

class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


input_shape = (1, 784)
mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}})
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv

Library Dispatch

We often want to quickly try out a library-based optimization for a certain platform (e.g., GPU). We can write a dispatching pass for the specific platform and operator. Here we demonstrate how to dispatch certain patterns to the CUBLAS library.

Note

This tutorial only demonstrates dispatching a single pattern to CUBLAS, to highlight the flexibility of the optimization pipeline. In real-world cases, we can import multiple patterns and dispatch them to different kernels.

# Import cublas pattern
import tvm.relax.backend.contrib.cublas as _cublas


# Define a new pass for CUBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Check if CUBLAS is enabled
        if not tvm.get_global_func("relax.ext.cublas", True):
            raise Exception("CUBLAS is not enabled.")

        # Get interested patterns
        patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
        # Note in real-world cases, we usually get all patterns
        # patterns = relax.backend.get_patterns_with_prefix("cublas")

        # Fuse ops by patterns and then run codegen
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)
        return mod


mod = CublasDispatch()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(lv, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

After the dispatching pass, we can see that the first nn.Linear and nn.ReLU are fused and rewritten into a call_dps_packed function that calls the CUBLAS library. Notably, the rest of the graph is unchanged, which means we can selectively dispatch optimizations to only part of the computation.

Auto Tuning

Continuing from the previous example, we can further optimize the model by auto-tuning the rest of the computation. Here we demonstrate how to use MetaSchedule to auto-tune the model.

We can use the MetaScheduleTuneTIR pass to tune the model, and the MetaScheduleApplyDatabase pass to apply the best configurations found. The tuning pass generates the search space and tunes the model, and the following pass applies the best configurations from the database to the model. Before running these passes, we need to lower the Relax operators into TensorIR functions via LegalizeOps, which is part of the default "zero" pipeline used below.

Note

To save CI time and avoid flakiness, we skip the tuning process in the CI environment.

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
if os.getenv("CI", "") != "true":
    trials = 2000
    with target, tempfile.TemporaryDirectory() as tmp_dir:
        mod = tvm.ir.transform.Sequential(
            [
                relax.get_pipeline("zero"),
                relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
            ]
        )(mod)

    mod.show()

DLight Rules

DLight rules are a set of default rules for scheduling and optimizing kernels. They are designed for fast compilation and fair performance. In some cases, e.g. language models, DLight provides excellent performance, while for generic models it achieves a balance between performance and compilation time.

from tvm import dlight as dl

# Apply DLight rules
with target:
    mod = tvm.ir.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )(mod)

mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @T.prim_func(private=True)
    def matmul(lv: T.Buffer((T.int64(1), T.int64(256)), "float32"), permute_dims1: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(10)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                    for ax1_fused_0, u in T.grid(T.int64(16), 1):
                        with T.block("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(16), ax1_fused_0)
                            T.reads(matmul_rf_local[vax1_fused_1, T.int64(0), v0], lv[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_rf_local[vax1_fused_1, T.int64(0), v0] + lv[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax1_fused])
                        T.reads(matmul_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul[T.int64(0), v0])
                        with T.init():
                            matmul[T.int64(0), v0] = T.float32(0.0)
                        matmul[T.int64(0), v0] = matmul[T.int64(0), v0] + matmul_rf_local[vax1_fused_1, T.int64(0), v0]

    @T.prim_func(private=True)
    def transpose(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(3), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_transpose"):
                    v0 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(10))
                    v1 = T.axis.spatial(T.int64(10), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(10))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(2560))
                    T.reads(fc2_weight[v1, v0])
                    T.writes(T_transpose[v0, v1])
                    T_transpose[v0, v1] = fc2_weight[v1, v0]

    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.matmul, (lv, permute_dims1), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

Note

This tutorial focuses on demonstrating the optimization pipeline, rather than pushing performance to the limit, so the current optimization choices may not be the best.

Deploy the Optimized Model

We can build and deploy the optimized model to the TVM runtime.

ex = relax.build(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev)
gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]
gpu_out = vm["forward"](data, *gpu_params).numpy()
print(gpu_out)
[[26195.57  25223.076 24700.54  25951.34  24422.746 25216.812 24859.023
  24812.828 25638.484 24069.904]]
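
If we want to compare different optimization choices, we can also time the compiled function. The snippet below is a minimal sketch that reuses the vm, data, and gpu_params objects from above and measures the "forward" entry through the underlying runtime module's time_evaluator; the exact numbers will depend on the hardware.

# Time the compiled "forward" function on the GPU.
ftimer = vm.module.time_evaluator("forward", dev, number=10, repeat=3)
print(ftimer(data, *gpu_params))  # prints summary statistics of the measured runtimes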

Summary

This tutorial demonstrates how to customize the optimization pipeline for ML models in Apache TVM. We can easily compose optimization passes and customize the optimization for different parts of the computation graph. The flexibility of the optimization pipeline enables us to quickly iterate on optimizations and improve the performance of the model.
