Deploy a Framework-prequantized Model with TVM

Author: Masahiro Masuda

This is a tutorial on loading models quantized by deep learning frameworks into TVM. Pre-quantized model import is one of the quantization support options available in TVM. More details on the quantization story in TVM can be found here.

Here, we demonstrate how to load and run models quantized by PyTorch, MXNet, and TFLite. Once loaded, we can run compiled, quantized models on any hardware TVM supports.

First, the necessary imports:

from PIL import Image

import numpy as np

import torch
from torchvision.models.quantization import mobilenet as qmobilenet

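import tvm
from tvm import relay
from tvm.contrib import graph_executor  # ensures tvm.contrib.graph_executor is loaded for run_tvm_model
from tvm.contrib.download import download_testdata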

Helper functions to run the demo

def get_transform():
    import torchvision.transforms as transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )


def get_real_image(im_height, im_width):
    img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
    img_path = download_testdata(img_url, "cat.png", module="data")
    # PIL's resize expects the size as (width, height).
    return Image.open(img_path).resize((im_width, im_height))


def get_imagenet_input():
    im = get_real_image(224, 224)
    preprocess = get_transform()
    pt_tensor = preprocess(im)
    return np.expand_dims(pt_tensor.numpy(), 0)
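The preprocessing in get_transform and get_imagenet_input is delegated to torchvision. As a rough NumPy sketch of what ToTensor (scale to [0, 1], HWC to CHW) and Normalize (per-channel mean/std) do, assuming the image has already been resized and center-cropped to 224x224 (those two steps are omitted), consider the following. The function name to_nchw_normalized is hypothetical.

```python
import numpy as np

# ImageNet channel statistics, the same values passed to transforms.Normalize above.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_nchw_normalized(hwc_uint8):
    """Mimic ToTensor() + Normalize() + batching on a resized, cropped HWC uint8 image."""
    x = hwc_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    x = (x - MEAN) / STD                      # Normalize: per-channel standardization
    x = np.transpose(x, (2, 0, 1))            # HWC -> CHW
    return np.expand_dims(x, 0)               # add batch dimension -> NCHW


# Usage with a synthetic gray image standing in for the cat picture.
dummy = np.full((224, 224, 3), 128, dtype=np.uint8)
batch = to_nchw_normalized(dummy)
print(batch.shape)  # (1, 3, 224, 224)
```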


def get_synset():
    synset_url = "".join(
        [
            "https://gist.githubusercontent.com/zhreshold/",
            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
            "imagenet1000_clsid_to_human.txt",
        ]
    )
    synset_name = "imagenet1000_clsid_to_human.txt"
    synset_path = download_testdata(synset_url, synset_name, module="data")
    import ast  # literal_eval parses the dict literal without executing arbitrary code

    with open(synset_path) as f:
        return ast.literal_eval(f.read())


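def run_tvm_model(mod, params, input_name, inp, target="llvm"):
    # Compile the Relay module with optimization level 3.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)

    # Wrap the compiled library in a graph executor on the device matching the target.
    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.device(target, 0)))

    # Feed the input, run inference, and return the first output as a NumPy array.
    runtime.set_input(input_name, inp)
    runtime.run()
    return runtime.get_output(0).numpy(), runtime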

A mapping from label index to class name, used to check that the outputs of the models below are reasonable.

synset = get_synset()

Everyone’s favorite cat image for demonstration

inp = get_imagenet_input()

Deploy a quantized PyTorch Model

First, we demonstrate how to load deep learning models quantized by PyTorch, using our PyTorch frontend.

Please refer to the PyTorch static quantization tutorial below to learn about their quantization workflow. https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html

We use this function to quantize PyTorch models. In short, it takes a floating point model and converts it to an 8-bit model: with the "fbgemm" configuration, activations are quantized to uint8 and weights to int8, and the weights are quantized per channel.

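def quantize_model(model, inp):
    # Fuse conv/bn/relu modules so they are quantized as single units.
    model.fuse_model()
    # Default x86 (fbgemm) configuration: uint8 activations, per-channel int8 weights.
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    # Insert observers that record activation ranges during calibration.
    torch.quantization.prepare(model, inplace=True)
    # Dummy calibration: a single forward pass with the demo input
    model(inp)
    # Replace float modules with quantized ones using the observed ranges.
    torch.quantization.convert(model, inplace=True)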
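The quantize_model function leaves the arithmetic to PyTorch. As a simplified sketch of what "uint8 activations, per-channel int8 weights" means, the NumPy code below derives quantization parameters from observed value ranges. The helper names are hypothetical, and PyTorch's actual observers differ in details (for example reduce_range, exact integer limits, and rounding), so treat this as an illustration rather than a reproduction.

```python
import numpy as np


def affine_qparams_uint8(x):
    """Per-tensor asymmetric (affine) uint8 parameters; assumes x is not all zeros."""
    lo, hi = min(float(x.min()), 0.0), max(float(x.max()), 0.0)  # range must include 0
    scale = (hi - lo) / 255.0
    zero_point = int(round(-lo / scale))
    return scale, zero_point


def per_channel_symmetric_int8(w):
    """One symmetric int8 scale per output channel (axis 0); zero point is fixed at 0."""
    max_abs = np.abs(w.reshape(w.shape[0], -1)).max(axis=1)
    return max_abs / 127.0


def quantize(x, scale, zero_point, qmin, qmax):
    """Round to the integer grid defined by scale/zero_point and clip to [qmin, qmax]."""
    return np.clip(np.round(x / scale) + zero_point, qmin, qmax)


rng = np.random.default_rng(0)
act = rng.uniform(-1.0, 3.0, size=(1, 8)).astype(np.float32)
weight = rng.normal(size=(4, 3, 3, 3)).astype(np.float32)  # 4 output channels

# Activations: a single scale and zero point for the whole tensor.
a_scale, a_zp = affine_qparams_uint8(act)
q_act = quantize(act, a_scale, a_zp, 0, 255).astype(np.uint8)

# Weights: one scale per output channel, broadcast over the remaining axes.
w_scales = per_channel_symmetric_int8(weight)
q_w = quantize(weight, w_scales[:, None, None, None], 0, -127, 127).astype(np.int8)
```

Per-channel weight scales let channels with small weights keep more resolution than a single tensor-wide scale would allow.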

Load a quantization-ready, pretrained MobileNet v2 model from torchvision

We choose MobileNet v2 because this model was trained with quantization-aware training. Other models require full post-training calibration.

qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval()
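quantize_model calibrates with a single dummy forward pass. Full post-training calibration instead runs many representative batches through the observers so that the recorded ranges reflect real data. The sketch below illustrates the difference using a hypothetical MinMaxObserver class and random batches standing in for a calibration set; PyTorch's real observers are more sophisticated than a running min/max.

```python
import numpy as np


class MinMaxObserver:
    """Hypothetical stand-in for a PTQ observer: tracks the running min/max of activations."""

    def __init__(self):
        self.lo, self.hi = np.inf, -np.inf

    def observe(self, x):
        self.lo = min(self.lo, float(x.min()))
        self.hi = max(self.hi, float(x.max()))

    def qparams(self):
        """uint8 affine scale and zero point covering the observed range (and 0)."""
        lo, hi = min(self.lo, 0.0), max(self.hi, 0.0)
        scale = (hi - lo) / 255.0
        return scale, int(round(-lo / scale))


rng = np.random.default_rng(1)
calibration_batches = [rng.normal(size=(8, 16)).astype(np.float32) for _ in range(20)]

one_shot = MinMaxObserver()
one_shot.observe(calibration_batches[0])  # like the single "dummy calibration" pass

full = MinMaxObserver()
for batch in calibration_batches:         # full post-training calibration
    full.observe(batch)

print(one_shot.qparams(), full.qparams())
```

A single batch tends to underestimate the activation range, so values seen later at inference time get clipped.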

Quantize, trace and run the PyTorch Mobilenet v2 model

The details are out of scope for this tutorial. Please refer to the tutorials on the PyTorch website to learn about quantization and jit.

pt_inp = torch.from_numpy(inp)
quantize_model(qmodel, pt_inp)
script_module = torch.jit.trace(qmodel, pt_inp).eval()

with torch.no_grad():
    pt_result = script_module(pt_inp).numpy()

Convert quantized Mobilenet v2 to Relay-QNN using the PyTorch frontend

The PyTorch frontend supports converting a quantized PyTorch model to an equivalent Relay module enriched with quantization-aware operators. We call this representation the Relay QNN dialect.

You can print the output from the frontend to see how quantized models are represented.

You will see quantization-specific operators such as qnn.quantize, qnn.dequantize, qnn.requantize, and qnn.conv2d.

input_name = "input"  # the input name can be arbitrary for the PyTorch frontend.
input_shapes = [(input_name, (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
# print(mod)  # uncomment to see the QNN IR dump
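To make the QNN operators more concrete, here is a NumPy sketch of the integer arithmetic that a quantized layer followed by qnn.requantize and qnn.dequantize represents. It uses a small dense layer for brevity (qnn.conv2d accumulates the same way over a sliding window). The scales and zero points are hand-picked assumptions, and the code models the semantics only, not how TVM implements these operators.

```python
import numpy as np

rng = np.random.default_rng(2)

# Float reference data for a tiny dense layer: 1x16 input, 4x16 weight.
x = rng.uniform(0.0, 2.0, size=(1, 16)).astype(np.float32)
w = rng.uniform(-1.0, 1.0, size=(4, 16)).astype(np.float32)

# Hand-picked quantization parameters (assumptions for this sketch).
s_x, zp_x = 2.0 / 255, 0     # uint8 activations
s_w, zp_w = 1.0 / 127, 0     # int8 symmetric weights
s_out, zp_out = 0.25, 128    # uint8 output covering roughly [-32, 32]

# Quantize inputs and weights onto their integer grids.
q_x = np.clip(np.round(x / s_x) + zp_x, 0, 255).astype(np.uint8)
q_w = np.clip(np.round(w / s_w) + zp_w, -127, 127).astype(np.int8)

# Quantized dense/conv: integer products accumulated in int32.
# The accumulator's implied scale is s_x * s_w and its zero point is 0.
acc = (q_x.astype(np.int32) - zp_x) @ (q_w.astype(np.int32) - zp_w).T

# Requantize: rescale the int32 accumulator onto the output's uint8 grid.
q_out = np.clip(np.round(acc * (s_x * s_w / s_out)) + zp_out, 0, 255).astype(np.uint8)

# Dequantize: back to float to compare with the float computation.
y_quant = (q_out.astype(np.float32) - zp_out) * s_out
y_float = x @ w.T
print(np.abs(y_quant - y_float).max())
```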

Compile and run the Relay module

Once we have obtained the quantized Relay module, the rest of the workflow is the same as for floating point models. Please refer to other tutorials for more details.

Under the hood, quantization-specific operators are lowered to a sequence of standard Relay operators before compilation.

target = "llvm"
tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp, target=target)
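A common technique for lowering a requantize step to plain integer operations is to express the float multiplier (input scale divided by output scale) as an int32 fixed-point mantissa plus a power-of-two shift. The sketch below illustrates that technique with hypothetical helper names and a round-half-up shift; it is not claimed to match TVM's lowering, whose rounding modes and edge-case handling may differ.

```python
import math

import numpy as np


def to_fixed_point(multiplier):
    """Split a positive float into a Q31 int mantissa and a power-of-two exponent."""
    mantissa, exponent = math.frexp(multiplier)  # multiplier = mantissa * 2**exponent, 0.5 <= mantissa < 1
    q31 = int(round(mantissa * (1 << 31)))
    if q31 == (1 << 31):  # rounding pushed the mantissa up to 1.0
        q31 //= 2
        exponent += 1
    return q31, exponent


def fixed_point_multiply(acc, q31, exponent):
    """Approximate round(acc * multiplier) using only integer multiply, add, and shift.

    Assumes a multiplier well below 1, so that 1 <= 31 - exponent <= 63.
    """
    total_shift = 31 - exponent
    prod = acc.astype(np.int64) * q31
    rounding = np.int64(1) << (total_shift - 1)  # add half before shifting: round half up
    return (prod + rounding) >> total_shift       # arithmetic right shift on int64


# Hypothetical scales: s_x * s_w / s_out for a uint8 x int8 layer.
multiplier = (2.0 / 255) * (1.0 / 127) / 0.25
q31, exp = to_fixed_point(multiplier)
acc = np.array([-5000, -1, 0, 1, 1234, 40000], dtype=np.int32)
int_result = fixed_point_multiply(acc, q31, exp)
float_result = np.round(acc * multiplier)
```

Keeping the whole requantize in integers avoids float arithmetic in the inner loop, at the cost of a small, bounded rounding difference from the float reference.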

Compare the output labels

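The top-ranked labels should match. Lower-ranked labels may occasionally differ because of small numerical differences, as the third label does here.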

pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]
tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3]

print("PyTorch top3 labels:", [synset[label] for label in pt_top3_labels])
print("TVM top3 labels:", [synset[label] for label in tvm_top3_labels])
PyTorch top3 labels: ['tiger cat', 'Egyptian cat', 'lynx, catamount']
TVM top3 labels: ['tiger cat', 'Egyptian cat', 'tabby, tabby cat']

However, due to differences in numerics, the raw floating point outputs are in general not expected to be identical. Here, we print how many of the 1000 floating point outputs of MobileNet v2 are identical.

print("%d in 1000 raw floating outputs identical." % np.sum(tvm_result[0] == pt_result[0]))
122 in 1000 raw floating outputs identical.
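Because bit-exact equality is too strict, a tolerance-based comparison is usually more informative. The hypothetical helper below reports top-1 agreement, top-k overlap, the maximum absolute difference, and an np.allclose check. The default atol of 0.1 is an arbitrary assumption that should be tuned to the model's output scale. With the tensors from this tutorial it would be called as compare_outputs(pt_result[0], tvm_result[0]); the usage below uses synthetic data instead.

```python
import numpy as np


def compare_outputs(ref, test, k=3, atol=1e-1):
    """Summarize agreement between two score vectors without requiring bit-exactness."""
    top_ref = np.argsort(ref)[::-1][:k]
    top_test = np.argsort(test)[::-1][:k]
    return {
        "top1_match": bool(top_ref[0] == top_test[0]),
        "topk_overlap": len(set(top_ref.tolist()) & set(top_test.tolist())),
        "max_abs_diff": float(np.max(np.abs(ref - test))),
        "allclose": bool(np.allclose(ref, test, atol=atol)),
    }


# Synthetic example: perturb the lowest-scoring entry only.
ref = np.linspace(0.0, 1.0, 1000)
test = ref.copy()
test[0] += 0.05
report = compare_outputs(ref, test)
print(report)
```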

Measure performance

Here we give an example of how to measure performance of TVM compiled models.

n_repeat = 100  # should be bigger to make the measurement more accurate
dev = tvm.cpu(0)
print(rt_mod.benchmark(dev, number=1, repeat=n_repeat))
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
  89.1260      89.1131      90.2300      88.7746       0.2328
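rt_mod.benchmark does its timing in C++. For comparison, the sketch below shows the same general methodology (untimed warm-up calls, then repeated timed calls summarized as mean, median, max, min, and standard deviation) in pure Python. The function and its warmup default are hypothetical, and timing rt_mod.run this way would include exactly the Python call overhead that benchmark avoids.

```python
import statistics
import time


def benchmark(fn, repeat=100, warmup=5):
    """Call fn() `warmup` times untimed, then time `repeat` calls; return stats in ms."""
    for _ in range(warmup):
        fn()
    times_ms = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "mean": statistics.mean(times_ms),
        "median": statistics.median(times_ms),
        "max": max(times_ms),
        "min": min(times_ms),
        "std": statistics.pstdev(times_ms),
    }


# Usage with a cheap stand-in workload; in this tutorial fn could be rt_mod.run.
stats = benchmark(lambda: sum(range(1000)), repeat=20, warmup=2)
print(stats)
```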

Note

We recommend this method for the following reasons:

  • Measurements are done in C++, so there is no Python overhead

  • It includes several warm-up runs

  • The same method can be used to profile on remote devices (Android, etc.).

Note

Unless the hardware has special support for fast 8-bit instructions, quantized models are not expected to be any faster than FP32 models. Without fast 8-bit instructions, TVM performs quantized convolution in 16 bits, even if the model itself is 8-bit.

For x86, the best performance can be achieved on CPUs with the AVX512 instruction set. In this case, TVM utilizes the fastest available 8-bit instructions for the given target. This includes support for the VNNI 8-bit dot product instruction (Cascade Lake or newer).
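The 8-bit dot product in AVX512-VNNI is VPDPBUSD (with a saturating variant, VPDPBUSDS): each 32-bit lane multiplies four unsigned 8-bit values by four signed 8-bit values and adds the sum to an int32 accumulator. The NumPy emulation below, with a hypothetical helper name and one 512-bit register (16 lanes) assumed, shows those semantics for intuition only; it says nothing about how TVM's x86 schedules use the instruction.

```python
import numpy as np


def vpdpbusd(acc, a_u8, b_s8):
    """Emulate VPDPBUSD: each int32 lane adds four u8*s8 products (non-saturating)."""
    lanes = acc.shape[0]
    a = a_u8.astype(np.int32).reshape(lanes, 4)  # zero-extend unsigned bytes
    b = b_s8.astype(np.int32).reshape(lanes, 4)  # sign-extend signed bytes
    return acc + (a * b).sum(axis=1, dtype=np.int32)


rng = np.random.default_rng(3)
acc = np.zeros(16, dtype=np.int32)                   # 16 int32 lanes = one 512-bit register
a = rng.integers(0, 256, size=64, dtype=np.uint8)    # 64 unsigned bytes
b = rng.integers(-128, 128, size=64, dtype=np.int8)  # 64 signed bytes
out = vpdpbusd(acc, a, b)
```

One such instruction performs 64 multiply-accumulates, which is why 8-bit inference benefits so much from VNNI-capable CPUs.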

Moreover, the following general tips for CPU performance apply equally:

  • Set the environment variable TVM_NUM_THREADS to the number of physical cores

  • Choose the best target for your hardware, such as “llvm -mcpu=skylake-avx512” or “llvm -mcpu=cascadelake” (more CPUs with AVX512 are expected in the future)
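The two tips above can be combined in a few lines. This is a sketch under stated assumptions: the core count of 4 is a placeholder for your machine's physical core count, pick_llvm_target is a hypothetical helper, and falling back to plain "llvm" on CPUs without AVX512 is our assumption rather than a TVM recommendation. Setting TVM_NUM_THREADS before importing tvm is the safest choice, so the runtime sees it before creating any worker threads.

```python
import os

# Placeholder: replace with the number of physical cores on your machine.
PHYSICAL_CORES = 4
os.environ["TVM_NUM_THREADS"] = str(PHYSICAL_CORES)  # set before importing tvm


def pick_llvm_target(has_avx512, has_vnni):
    """Hypothetical helper mapping CPU features to the target strings suggested above."""
    if has_vnni:
        return "llvm -mcpu=cascadelake"     # AVX512 with VNNI 8-bit dot products
    if has_avx512:
        return "llvm -mcpu=skylake-avx512"  # AVX512 without VNNI
    return "llvm"                           # assumed generic fallback


target = pick_llvm_target(has_avx512=True, has_vnni=False)
print(target)
```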

Deploy a quantized MXNet Model

TODO

Deploy a quantized TFLite Model

TODO

Total running time of the script: ( 1 minutes 34.099 seconds)
