End-to-End Optimize Model

This tutorial demonstrates how to optimize a machine learning model using Apache TVM. We use a pre-trained ResNet-18 model from PyTorch and optimize it end-to-end with TVM’s Relax API. Please note that the default end-to-end optimization pipeline may not suit complex models.

Preparation

First, we prepare the model and input information. We use a pre-trained ResNet-18 model from PyTorch.

import os
import numpy as np
import torch
from torch import fx
from torchvision.models.resnet import ResNet18_Weights, resnet18

torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)

Review Overall Flow

https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg

The overall flow consists of the following steps:

  • Construct or Import a Model: Construct a neural network model or import a pre-trained model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains all the information needed for compilation: high-level Relax functions for the computational graph and low-level TensorIR functions for the tensor programs (see the sketch after this list).

  • Perform Composable Optimizations: Perform a series of optimization transformations, such as graph optimizations, tensor program optimizations, and library dispatching.

  • Build and Universal Deployment: Build the optimized model into a deployable module for the universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators.
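
To make the first step concrete, here is a minimal toy IRModule written in TVMScript (a hypothetical illustration, not part of the ResNet flow): it holds a high-level Relax function for the graph and a low-level TensorIR function for the tensor program side by side in one module.

from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T

@I.ir_module
class ToyModule:
    @T.prim_func
    def add_one(x: T.Buffer((8,), "float32"), y: T.Buffer((8,), "float32")):
        # Low-level TensorIR: an explicit elementwise loop over the tensor
        for i in range(8):
            with T.block("add_one"):
                vi = T.axis.spatial(8, i)
                y[vi] = x[vi] + T.float32(1)

    @R.function
    def main(x: R.Tensor((8,), "float32")) -> R.Tensor((8,), "float32"):
        # High-level Relax: the graph-level function calls the TensorIR kernel
        cls = ToyModule
        with R.dataflow():
            y = R.call_tir(cls.add_one, (x,), out_sinfo=R.Tensor((8,), "float32"))
            R.output(y)
        return y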

Convert the model to IRModule

Next, we convert the model to an IRModule using the Relax frontend for PyTorch for further optimization. Besides the model, we also need to provide the input shape and data type.

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

# Give the input shape and data type
input_info = [((1, 3, 224, 224), "float32")]

# Convert the model to IRModule
with torch.no_grad():
    torch_fx_model = fx.symbolic_trace(torch_model)
    # keep_params_as_input=True lifts the model weights into explicit
    # function parameters instead of embedding them as constants
    mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)

# Detach the weights from the module so they can be supplied at runtime
mod, params = relax.frontend.detach_params(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 224, 224), dtype="float32"), bn1_bias: R.Tensor((64,), dtype="float32"), bn1_weight: R.Tensor((64,), dtype="float32"), conv1_weight: R.Tensor((64, 3, 7, 7), dtype="float32"), fc_bias: R.Tensor((1000,), dtype="float32"), fc_weight: R.Tensor((1000, 512), dtype="float32"), layer1_0_bn1_bias: R.Tensor((64,), dtype="float32"), layer1_0_bn1_weight: R.Tensor((64,), dtype="float32"), layer1_0_bn2_bias: R.Tensor((64,), dtype="float32"), layer1_0_bn2_weight: R.Tensor((64,), dtype="float32"), layer1_0_conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer1_0_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer1_1_bn1_bias: R.Tensor((64,), dtype="float32"), layer1_1_bn1_weight: R.Tensor((64,), dtype="float32"), layer1_1_bn2_bias: R.Tensor((64,), dtype="float32"), layer1_1_bn2_weight: R.Tensor((64,), dtype="float32"), layer1_1_conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer1_1_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer2_0_bn1_bias: R.Tensor((128,), dtype="float32"), layer2_0_bn1_weight: R.Tensor((128,), dtype="float32"), layer2_0_bn2_bias: R.Tensor((128,), dtype="float32"), layer2_0_bn2_weight: R.Tensor((128,), dtype="float32"), layer2_0_conv1_weight: R.Tensor((128, 64, 3, 3), dtype="float32"), layer2_0_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), layer2_0_downsample_0_weight: R.Tensor((128, 64, 1, 1), dtype="float32"), layer2_0_downsample_1_bias: R.Tensor((128,), dtype="float32"), layer2_0_downsample_1_weight: R.Tensor((128,), dtype="float32"), layer2_1_bn1_bias: R.Tensor((128,), dtype="float32"), layer2_1_bn1_weight: R.Tensor((128,), dtype="float32"), layer2_1_bn2_bias: R.Tensor((128,), dtype="float32"), layer2_1_bn2_weight: R.Tensor((128,), dtype="float32"), layer2_1_conv1_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), layer2_1_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), layer3_0_bn1_bias: R.Tensor((256,), dtype="float32"), layer3_0_bn1_weight: R.Tensor((256,), dtype="float32"), layer3_0_bn2_bias: R.Tensor((256,), dtype="float32"), layer3_0_bn2_weight: R.Tensor((256,), dtype="float32"), layer3_0_conv1_weight: R.Tensor((256, 128, 3, 3), dtype="float32"), layer3_0_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), layer3_0_downsample_0_weight: R.Tensor((256, 128, 1, 1), dtype="float32"), layer3_0_downsample_1_bias: R.Tensor((256,), dtype="float32"), layer3_0_downsample_1_weight: R.Tensor((256,), dtype="float32"), layer3_1_bn1_bias: R.Tensor((256,), dtype="float32"), layer3_1_bn1_weight: R.Tensor((256,), dtype="float32"), layer3_1_bn2_bias: R.Tensor((256,), dtype="float32"), layer3_1_bn2_weight: R.Tensor((256,), dtype="float32"), layer3_1_conv1_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), layer3_1_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), layer4_0_bn1_bias: R.Tensor((512,), dtype="float32"), layer4_0_bn1_weight: R.Tensor((512,), dtype="float32"), layer4_0_bn2_bias: R.Tensor((512,), dtype="float32"), layer4_0_bn2_weight: R.Tensor((512,), dtype="float32"), layer4_0_conv1_weight: R.Tensor((512, 256, 3, 3), dtype="float32"), layer4_0_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), layer4_0_downsample_0_weight: R.Tensor((512, 256, 1, 1), dtype="float32"), layer4_0_downsample_1_bias: R.Tensor((512,), dtype="float32"), layer4_0_downsample_1_weight: R.Tensor((512,), dtype="float32"), layer4_1_bn1_bias: R.Tensor((512,), dtype="float32"), layer4_1_bn1_weight: R.Tensor((512,), dtype="float32"), layer4_1_bn2_bias: R.Tensor((512,), dtype="float32"), 
layer4_1_bn2_weight: R.Tensor((512,), dtype="float32"), layer4_1_conv1_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), layer4_1_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32")) -> R.Tensor((1, 1000), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.conv2d(inp_0, conv1_weight, strides=[2, 2], padding=[3, 3, 3, 3], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tuple(R.Tensor((1, 64, 112, 112), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv, bn1_weight, bn1_bias, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = lv1[0]
            lv3: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.max_pool2d(lv3, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
            lv5: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv4, layer1_0_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv6: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv5, layer1_0_bn1_weight, layer1_0_bn1_bias, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv7: R.Tensor((1, 64, 56, 56), dtype="float32") = lv6[0]
            lv8: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv7)
            lv9: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv8, layer1_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv10: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv9, layer1_0_bn2_weight, layer1_0_bn2_bias, metadata["relax.expr.Constant"][4], metadata["relax.expr.Constant"][5], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv11: R.Tensor((1, 64, 56, 56), dtype="float32") = lv10[0]
            lv12: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv11, lv4)
            lv13: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv12)
            lv14: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv13, layer1_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv15: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv14, layer1_1_bn1_weight, layer1_1_bn1_bias, metadata["relax.expr.Constant"][6], metadata["relax.expr.Constant"][7], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv16: R.Tensor((1, 64, 56, 56), dtype="float32") = lv15[0]
            lv17: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv16)
            lv18: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv17, layer1_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv19: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv18, layer1_1_bn2_weight, layer1_1_bn2_bias, metadata["relax.expr.Constant"][8], metadata["relax.expr.Constant"][9], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv20: R.Tensor((1, 64, 56, 56), dtype="float32") = lv19[0]
            lv21: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv20, lv13)
            lv22: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv21)
            lv23: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22, layer2_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv24: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv23, layer2_0_bn1_weight, layer2_0_bn1_bias, metadata["relax.expr.Constant"][10], metadata["relax.expr.Constant"][11], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv25: R.Tensor((1, 128, 28, 28), dtype="float32") = lv24[0]
            lv26: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv25)
            lv27: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv26, layer2_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv28: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv27, layer2_0_bn2_weight, layer2_0_bn2_bias, metadata["relax.expr.Constant"][12], metadata["relax.expr.Constant"][13], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv29: R.Tensor((1, 128, 28, 28), dtype="float32") = lv28[0]
            lv30: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22, layer2_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv31: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv30, layer2_0_downsample_1_weight, layer2_0_downsample_1_bias, metadata["relax.expr.Constant"][14], metadata["relax.expr.Constant"][15], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv32: R.Tensor((1, 128, 28, 28), dtype="float32") = lv31[0]
            lv33: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv29, lv32)
            lv34: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv33)
            lv35: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv34, layer2_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv36: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv35, layer2_1_bn1_weight, layer2_1_bn1_bias, metadata["relax.expr.Constant"][16], metadata["relax.expr.Constant"][17], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv37: R.Tensor((1, 128, 28, 28), dtype="float32") = lv36[0]
            lv38: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv37)
            lv39: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv38, layer2_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv40: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv39, layer2_1_bn2_weight, layer2_1_bn2_bias, metadata["relax.expr.Constant"][18], metadata["relax.expr.Constant"][19], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv41: R.Tensor((1, 128, 28, 28), dtype="float32") = lv40[0]
            lv42: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv41, lv34)
            lv43: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv42)
            lv44: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43, layer3_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv45: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv44, layer3_0_bn1_weight, layer3_0_bn1_bias, metadata["relax.expr.Constant"][20], metadata["relax.expr.Constant"][21], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv46: R.Tensor((1, 256, 14, 14), dtype="float32") = lv45[0]
            lv47: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv46)
            lv48: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv47, layer3_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv49: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv48, layer3_0_bn2_weight, layer3_0_bn2_bias, metadata["relax.expr.Constant"][22], metadata["relax.expr.Constant"][23], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv50: R.Tensor((1, 256, 14, 14), dtype="float32") = lv49[0]
            lv51: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43, layer3_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv52: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv51, layer3_0_downsample_1_weight, layer3_0_downsample_1_bias, metadata["relax.expr.Constant"][24], metadata["relax.expr.Constant"][25], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv53: R.Tensor((1, 256, 14, 14), dtype="float32") = lv52[0]
            lv54: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv50, lv53)
            lv55: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv54)
            lv56: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv55, layer3_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv57: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv56, layer3_1_bn1_weight, layer3_1_bn1_bias, metadata["relax.expr.Constant"][26], metadata["relax.expr.Constant"][27], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv58: R.Tensor((1, 256, 14, 14), dtype="float32") = lv57[0]
            lv59: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv58)
            lv60: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv59, layer3_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv61: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv60, layer3_1_bn2_weight, layer3_1_bn2_bias, metadata["relax.expr.Constant"][28], metadata["relax.expr.Constant"][29], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv62: R.Tensor((1, 256, 14, 14), dtype="float32") = lv61[0]
            lv63: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv62, lv55)
            lv64: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv63)
            lv65: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64, layer4_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv66: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv65, layer4_0_bn1_weight, layer4_0_bn1_bias, metadata["relax.expr.Constant"][30], metadata["relax.expr.Constant"][31], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv67: R.Tensor((1, 512, 7, 7), dtype="float32") = lv66[0]
            lv68: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv67)
            lv69: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv68, layer4_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv70: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv69, layer4_0_bn2_weight, layer4_0_bn2_bias, metadata["relax.expr.Constant"][32], metadata["relax.expr.Constant"][33], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv71: R.Tensor((1, 512, 7, 7), dtype="float32") = lv70[0]
            lv72: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64, layer4_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv73: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv72, layer4_0_downsample_1_weight, layer4_0_downsample_1_bias, metadata["relax.expr.Constant"][34], metadata["relax.expr.Constant"][35], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv74: R.Tensor((1, 512, 7, 7), dtype="float32") = lv73[0]
            lv75: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv71, lv74)
            lv76: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv75)
            lv77: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv76, layer4_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv78: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv77, layer4_1_bn1_weight, layer4_1_bn1_bias, metadata["relax.expr.Constant"][36], metadata["relax.expr.Constant"][37], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv79: R.Tensor((1, 512, 7, 7), dtype="float32") = lv78[0]
            lv80: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv79)
            lv81: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv80, layer4_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv82: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv81, layer4_1_bn2_weight, layer4_1_bn2_bias, metadata["relax.expr.Constant"][38], metadata["relax.expr.Constant"][39], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv83: R.Tensor((1, 512, 7, 7), dtype="float32") = lv82[0]
            lv84: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv83, lv76)
            lv85: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv84)
            lv86: R.Tensor((1, 512, 1, 1), dtype="float32") = R.nn.adaptive_avg_pool2d(lv85, output_size=[1, 1], layout="NCHW", out_layout="NCHW")
            lv87: R.Tensor((1, 512), dtype="float32") = R.reshape(lv86, R.shape([1, 512]))
            lv88: R.Tensor((512, 1000), dtype="float32") = R.permute_dims(fc_weight, axes=None)
            lv89: R.Tensor((1, 1000), dtype="float32") = R.matmul(lv87, lv88, out_dtype="float32")
            lv90: R.Tensor((1, 1000), dtype="float32") = R.add(lv89, fc_bias)
            gv: R.Tensor((1, 1000), dtype="float32") = lv90
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
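
The detached weights are returned as a dictionary keyed by function name, with each entry holding the weight tensors in the order of the corresponding function signature. A quick sanity check (a minimal sketch, assuming the conversion above succeeded):

print(type(params))         # dict: function name -> list of tvm.nd.NDArray
print(len(params["main"]))  # number of weight tensors detached from "main"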

IRModule Optimization

Apache TVM Unity provides a flexible way to optimize the IRModule. All optimizations center on the IRModule, so they compose freely with existing pipelines: each transformation is a pass, and passes can be combined into an optimization pipeline via tvm.ir.transform.Sequential, as sketched below.
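
For example, the following is a minimal sketch of composing two Relax passes into one pipeline and applying it to the module from above (the pass selection here is illustrative, not the tutorial’s tuning pipeline):

# Compose passes into a single pipeline; applying it returns a new IRModule
pipeline = tvm.ir.transform.Sequential(
    [
        relax.transform.FoldConstant(),  # fold constant subexpressions
        relax.transform.LegalizeOps(),   # lower Relax ops to TensorIR functions
    ]
)
transformed_mod = pipeline(mod)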

In this tutorial, we focus on end-to-end optimization of the model via auto-tuning. We leverage MetaSchedule to tune the model and store the tuning logs in a database, then apply the database records to the model to reach the best performance.

TOTAL_TRIALS = 8000  # Change to 20000 for better performance if needed
target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")  # Change to your target device
work_dir = "tuning_logs"

# Skip running in CI environment
IS_IN_CI = os.getenv("CI", "") == "true"
if not IS_IN_CI:
    with target:
        mod = tvm.ir.transform.Sequential(
            [
                # Convert BatchNorm into a sequence of simpler ops for fusion
                relax.transform.DecomposeOpsForInference(),
                # Canonicalize the bindings
                relax.transform.CanonicalizeBindings(),
                # Run default optimization pipeline
                relax.get_pipeline("zero"),
                # Tune the model and store the log to database
                relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS),
                # Apply the database
                relax.transform.MetaScheduleApplyDatabase(work_dir),
            ]
        )(mod)

    # Only show the main function
    mod["main"].show()

Build and Deploy

Finally, we build the optimized model and deploy it to the target device. We skip this step in the CI environment.

if not IS_IN_CI:
    ex = relax.build(mod, target="cuda")
    dev = tvm.device("cuda", 0)
    vm = relax.VirtualMachine(ex, dev)
    # Need to allocate data and params on GPU device
    gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
    gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]
    gpu_out = vm["main"](gpu_data, *gpu_params).numpy()

    print(gpu_out.shape)
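
To verify that tuning actually paid off, it may be worth timing the compiled module. A minimal sketch, assuming the VirtualMachine vm, device dev, and inputs from above (the number/repeat parameters are illustrative):

if not IS_IN_CI:
    # Measure "main" end to end; results are averaged over repeated runs
    timer = vm.time_evaluator("main", dev, number=10, repeat=3)
    print(timer(gpu_data, *gpu_params))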
