microTVM PyTorch Tutorial

Authors: Mehrdad Hessar

This tutorial showcases microTVM host-driven AoT compilation with a PyTorch model. It can be executed on an x86 CPU using the C runtime (CRT).

Note: This tutorial only runs on an x86 CPU using CRT. It does not run on Zephyr because the model does not fit on the currently supported Zephyr boards.

import pathlib

import torch
import torchvision
from torchvision import transforms
import numpy as np
from PIL import Image

import tvm
from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.backend import Executor

Load a pre-trained PyTorch model

To begin, load a pre-trained, quantized MobileNetV2 from torchvision. Then download a cat image and preprocess it for use as the model input.

model = torchvision.models.quantization.mobilenet_v2(weights="DEFAULT", quantize=True)
model = model.eval()

input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))

# Preprocess the image and convert to tensor
my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)

input_name = "input0"
shape_list = [(input_name, input_shape)]
relay_mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
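
As a rough illustration of what the ToTensor and Normalize steps above compute, the sketch below applies the same per-channel arithmetic to a single RGB pixel in plain Python. The helper name normalize_pixel and the sample pixel value are assumptions made for this example; only the mean and std constants come from the code above.

```python
# Per-channel constants copied from the transforms.Normalize call above.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Hypothetical helper: mimic ToTensor + Normalize for one 8-bit RGB pixel."""
    # ToTensor rescales 8-bit values from [0, 255] to [0.0, 1.0].
    scaled = [channel / 255.0 for channel in rgb]
    # Normalize subtracts the channel mean and divides by the channel std.
    return [(value - mean) / std for value, mean, std in zip(scaled, MEAN, STD)]


# Illustrative pixel value, not taken from the cat image.
pixel = (255, 128, 0)
normalized = normalize_pixel(pixel)
print(normalized)
```

The transforms pipeline applies this arithmetic to every pixel of the cropped 224x224 image, so the model receives values centered near zero rather than raw 8-bit intensities.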

Define Target, Runtime and Executor

In this tutorial we use the AOT host-driven executor. To compile the model for an emulated embedded environment on an x86 machine, we use the C runtime (CRT) and the host micro target. With this setup, TVM compiles the model for the C runtime, which runs on an x86 CPU with the same flow that would be used on a physical microcontroller.

# Simulate a microcontroller on the host machine. Uses the main() from `src/runtime/crt/host/main.cc`.
# To use physical hardware, replace "host" with a physical micro target, e.g. `nrf52840`
# or `mps2_an521`. See more target examples in the micro_train.py and micro_tflite.py tutorials.
target = tvm.target.target.micro("host")

# Use the C runtime (crt) and enable static linking by setting system-lib to True
runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})

# Use the AOT executor rather than graph or vm executors. Don't use unpacked API or C calling style.
executor = Executor("aot")

Compile the model

Now, we compile the model for the target:

with tvm.transform.PassContext(
    opt_level=3,
    config={"tir.disable_vectorize": True},
):
    module = tvm.relay.build(
        relay_mod, target=target, runtime=runtime, executor=executor, params=params
    )

Create a microTVM project

Now that we have the compiled model, we need to create a firmware project to use it with microTVM. To do this, we use the Project API.
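
The later steps rely on a project object. A minimal sketch of how it might be created with the Project API follows, assuming the CRT template project that ships with TVM and the tvm.micro.generate_project entry point. The helper name create_host_project, the option values (including the 6 MiB workspace size), and the temporary directory layout are assumptions made for this sketch.

```python
import pathlib
import tempfile

# Assumed project options for the host CRT template; the values are illustrative.
PROJECT_OPTIONS = {
    "verbose": False,
    "workspace_size_bytes": 6 * 1024 * 1024,
}


def create_host_project(module, options=None):
    """Hypothetical helper: generate a host CRT project for a compiled model."""
    # Imported lazily so the helper can be defined without TVM installed.
    import tvm.micro

    # Locate the CRT (host) template project that ships with TVM.
    template_dir = pathlib.Path(tvm.micro.get_microtvm_template_projects("crt"))
    # Use a fresh subdirectory so project generation starts from an empty location.
    project_dir = pathlib.Path(tempfile.mkdtemp()) / "project"
    return tvm.micro.generate_project(
        str(template_dir),
        module,
        str(project_dir),
        dict(PROJECT_OPTIONS if options is None else options),
    )
```

With this helper, project = create_host_project(module) would provide the object used by the build and flash calls in the next step.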

Build, flash and execute the model

Next, we build the microTVM project and flash it. The flash step is specific to physical microcontrollers; it is skipped when simulating a microcontroller via the host `main.cc` or when a Zephyr emulated board is selected as the target.

project.build()
project.flash()

input_data = {input_name: tvm.nd.array(img.astype("float32"))}
with tvm.micro.Session(project.transport()) as session:
    aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor())
    aot_executor.set_input(**input_data)
    aot_executor.run()
    result = aot_executor.get_output(0).numpy()

Look up synset name

Look up the top-1 predicted index in the 1000-class ImageNet synset list.

synset_url = (
    "https://raw.githubusercontent.com/Cadene/"
    "pretrained-models.pytorch/master/data/"
    "imagenet_synsets.txt"
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synsets = f.readlines()

synsets = [x.strip() for x in synsets]
splits = [line.split(" ") for line in synsets]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}

class_url = (
    "https://raw.githubusercontent.com/Cadene/"
    "pretrained-models.pytorch/master/data/"
    "imagenet_classes.txt"
)
class_path = download_testdata(class_url, "imagenet_classes.txt", module="data")
with open(class_path) as f:
    class_id_to_key = f.readlines()

class_id_to_key = [x.strip() for x in class_id_to_key]

# Get top-1 result for TVM
top1_tvm = np.argmax(result)
tvm_class_key = class_id_to_key[top1_tvm]

# Convert input to PyTorch variable and get PyTorch result for comparison
with torch.no_grad():
    torch_img = torch.from_numpy(img)
    output = model(torch_img)

    # Get top-1 result for PyTorch
    top1_torch = np.argmax(output.numpy())
    torch_class_key = class_id_to_key[top1_torch]

print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))
Relay top-1 id: 282, class name: tiger cat
Torch top-1 id: 282, class name: tiger cat
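
The two-step lookup above (class index to WordNet key, then key to class name) can be tried on a tiny in-memory sample. The sample lines below are assumptions that mimic the presumed layout of the two downloaded files, and the scores are made up; the indices are therefore 0 to 2 rather than the real ImageNet ids.

```python
# Assumed layout of imagenet_synsets.txt: "<WordNet key> <class name>".
synset_lines = [
    "n02123045 tabby, tabby cat",
    "n02123159 tiger cat",
    "n02123394 Persian cat",
]
# Assumed layout of imagenet_classes.txt: one WordNet key per class index.
class_lines = ["n02123045", "n02123159", "n02123394"]

# Same parsing steps as the tutorial code above.
splits = [line.strip().split(" ") for line in synset_lines]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}
class_id_to_key = [line.strip() for line in class_lines]

# Made-up scores standing in for the model output.
scores = [0.1, 2.5, 0.3]
# Index of the highest score, the pure-Python counterpart of np.argmax.
top1 = max(range(len(scores)), key=scores.__getitem__)
print(top1, key_to_classname[class_id_to_key[top1]])
```

Joining the remaining fields with a space keeps multi-word names such as "tabby, tabby cat" intact after splitting each line on spaces.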

