4. microTVM PyTorch Tutorial

Authors: Mehrdad Hessar

This tutorial showcases microTVM host-driven AoT compilation with a PyTorch model. It can be executed on an x86 CPU using the C runtime (CRT).

Note: This tutorial only runs on an x86 CPU using CRT and does not run on Zephyr, since the model would not fit on the currently supported Zephyr boards.

Install microTVM Python dependencies

TVM does not include a package for Python serial communication, so we must install one before using microTVM. We will also need TFLite to load models.

%%shell
pip install pyserial==3.5 tflite==2.1

import pathlib
import torch
import torchvision
from torchvision import transforms
import numpy as np
from PIL import Image

import tvm
from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.backend import Executor
import tvm.micro.testing

Load a pre-trained PyTorch model

To begin with, load the pre-trained, quantized MobileNetV2 model from torchvision. Then download a cat image and preprocess it to use as the model input.

model = torchvision.models.quantization.mobilenet_v2(weights="DEFAULT", quantize=True)
model = model.eval()

input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))

# Preprocess the image and convert to tensor
my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)

input_name = "input0"
shape_list = [(input_name, input_shape)]
relay_mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/ao/quantization/utils.py:376: UserWarning: must run observer before calling calculate_qparams. Returning default values.
  warnings.warn(
Downloading: "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth" to /workspace/.cache/torch/hub/checkpoints/mobilenet_v2_qnnpack_37f702c5.pth

  0%|          | 0.00/3.42M [00:00<?, ?B/s]
100%|##########| 3.42M/3.42M [00:00<00:00, 83.5MB/s]
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/ao/nn/quantized/modules/__init__.py:98: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return torch.quantize_per_tensor(X, float(self.scale),
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/ao/nn/quantized/modules/__init__.py:99: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  int(self.zero_point), self.dtype)
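The preprocessing above can be restated in plain NumPy as a rough sketch: ToTensor rescales 8-bit pixels to [0, 1] and moves the channel axis first, Normalize subtracts the per-channel mean and divides by the per-channel standard deviation, and expand_dims adds the batch axis. The helper name normalize_to_nchw and the synthetic input image are illustrative assumptions, not part of the tutorial.

```python
import numpy as np

# ImageNet statistics, as passed to transforms.Normalize in the tutorial.
MEAN = np.array([0.485, 0.456, 0.406], dtype="float32")
STD = np.array([0.229, 0.224, 0.225], dtype="float32")


def normalize_to_nchw(img_hwc_uint8):
    """Hypothetical helper: HWC uint8 image -> normalized NCHW float32 batch."""
    scaled = img_hwc_uint8.astype("float32") / 255.0  # like ToTensor: [0, 255] -> [0, 1]
    normalized = (scaled - MEAN) / STD  # like Normalize, broadcast over the channel axis
    chw = np.transpose(normalized, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(chw, 0)  # add the batch axis -> NCHW


# Synthetic stand-in for an already cropped 224x224 RGB image.
fake_img = np.full((224, 224, 3), 128, dtype=np.uint8)
batch = normalize_to_nchw(fake_img)
print(batch.shape, batch.dtype)
```

The resulting shape matches the input_shape of [1, 3, 224, 224] that the traced model and the Relay frontend expect.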

Define Target, Runtime and Executor

In this tutorial we use the AOT host-driven executor. To compile the model for an emulated embedded environment on an x86 machine, we use the C runtime (CRT) and the host micro target. With this setup, TVM compiles the model for the C runtime, which runs on an x86 CPU with the same flow that would run on a physical microcontroller. CRT uses the main() from src/runtime/crt/host/main.cc. To use physical hardware, replace board with another micro target, e.g. nrf5340dk_nrf5340_cpuapp or mps2_an521, and change the platform type to Zephyr. See more target examples in Training Vision Models for microTVM on Arduino and microTVM TFLite Tutorial.

target = tvm.micro.testing.get_target(platform="crt", board=None)

# Use the C runtime (crt) and enable static linking by setting system-lib to True
runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})

# Use the AOT executor rather than graph or vm executors. Don't use unpacked API or C calling style.
executor = Executor("aot")
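As a sketch of the hardware switch described above, only the platform and board arguments of get_target need to change. The helper names and the use_physical_hw flag are hypothetical; the board names are the ones mentioned in the text. Since this particular model is not expected to fit on the currently supported Zephyr boards, treat this as a pattern for smaller models.

```python
def choose_platform_and_board(use_physical_hw, board="nrf5340dk_nrf5340_cpuapp"):
    """Hypothetical helper: pick the (platform, board) pair passed to get_target."""
    if use_physical_hw:
        return "zephyr", board
    return "crt", None


def make_target(use_physical_hw=False, board="nrf5340dk_nrf5340_cpuapp"):
    """Build the TVM target; tvm is imported lazily so the helper above works without it."""
    import tvm.micro.testing

    platform, selected_board = choose_platform_and_board(use_physical_hw, board)
    return tvm.micro.testing.get_target(platform=platform, board=selected_board)


print(choose_platform_and_board(False))
print(choose_platform_and_board(True, "mps2_an521"))
```

With use_physical_hw left at False, make_target() requests the same ("crt", None) combination as the tutorial's own get_target call.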

Compile the model

Now, we compile the model for the target:

with tvm.transform.PassContext(
    opt_level=3,
    config={"tir.disable_vectorize": True},
):
    module = tvm.relay.build(
        relay_mod, target=target, runtime=runtime, executor=executor, params=params
    )

Create a microTVM project

Now that the model is compiled, we need to create a firmware project to use it with microTVM. To do this, we use the Project API.
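The project-generation code does not appear in this section, so the following is only a sketch of how the Project API might be used at this step, assuming the host "crt" template project that ships with TVM. The helper names, the 6 MiB workspace_size_bytes value and the verbose setting are assumptions; check them against your TVM version.

```python
import pathlib
import tempfile


def crt_project_options(workspace_mib=6, verbose=False):
    """Hypothetical helper: options for the CRT template project (values are assumptions)."""
    return {"verbose": verbose, "workspace_size_bytes": workspace_mib * 1024 * 1024}


def generate_crt_project(module, options=None):
    """Generate a microTVM project from the compiled module in a fresh temporary location."""
    import tvm.micro  # imported lazily; needs a TVM build with microTVM enabled

    template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("crt"))
    # Use a not-yet-existing subdirectory of a fresh temp dir as the project location.
    project_dir = pathlib.Path(tempfile.mkdtemp()) / "project"
    return tvm.micro.generate_project(
        str(template_project_path),
        module,
        str(project_dir),
        options or crt_project_options(),
    )


# In the flow of this tutorial: project = generate_crt_project(module)
print(crt_project_options())
```

The returned object is what the build and flash calls in the next step operate on.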

Build, flash and execute the model

Next, we build the microTVM project and flash it. The flash step is specific to physical microcontrollers; it is skipped when simulating a microcontroller via the host main.cc or when a Zephyr emulated board is selected as the target.

project.build()
project.flash()

input_data = {input_name: tvm.nd.array(img.astype("float32"))}
with tvm.micro.Session(project.transport()) as session:
    aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor())
    aot_executor.set_input(**input_data)
    aot_executor.run()
    result = aot_executor.get_output(0).numpy()

Look up synset name

Look up the top-1 predicted index in the 1000-class ImageNet synset list.

synset_url = (
    "https://raw.githubusercontent.com/Cadene/"
    "pretrained-models.pytorch/master/data/"
    "imagenet_synsets.txt"
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synsets = f.readlines()

synsets = [x.strip() for x in synsets]
splits = [line.split(" ") for line in synsets]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}

class_url = (
    "https://raw.githubusercontent.com/Cadene/"
    "pretrained-models.pytorch/master/data/"
    "imagenet_classes.txt"
)
class_path = download_testdata(class_url, "imagenet_classes.txt", module="data")
with open(class_path) as f:
    class_id_to_key = f.readlines()

class_id_to_key = [x.strip() for x in class_id_to_key]

# Get top-1 result for TVM
top1_tvm = np.argmax(result)
tvm_class_key = class_id_to_key[top1_tvm]

# Convert input to PyTorch variable and get PyTorch result for comparison
with torch.no_grad():
    torch_img = torch.from_numpy(img)
    output = model(torch_img)

    # Get top-1 result for PyTorch
    top1_torch = np.argmax(output.numpy())
    torch_class_key = class_id_to_key[top1_torch]

print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))

Relay top-1 id: 282, class name: tiger cat
Torch top-1 id: 282, class name: tiger cat
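
To make the two-step lookup (class id to synset key, then synset key to class name) concrete, here is a small self-contained sketch. It uses made-up scores and a three-line excerpt in the same "key name" format that the parsing code above assumes. The sample lines, the scores and the helper name top_k are illustrative assumptions, not data from the run above, and the class ids are indices into this excerpt only.

```python
# Illustrative lines in the "<synset key> <class name>" format (not the real files).
synset_lines = ["n02123045 tabby, tabby cat", "n02123159 tiger cat", "n02124075 Egyptian cat"]
class_id_to_key = ["n02123045", "n02123159", "n02124075"]  # list index = class id

# Same parsing as in the tutorial: the first token is the key, the rest is the name.
key_to_classname = {}
for line in synset_lines:
    parts = line.strip().split(" ")
    key_to_classname[parts[0]] = " ".join(parts[1:])


def top_k(scores, k=1):
    """Return the k class ids with the highest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


scores = [0.2, 0.7, 0.1]  # made-up model outputs
for class_id in top_k(scores, k=2):
    print(class_id, key_to_classname[class_id_to_key[class_id]])
```

With k=1 this reduces to the argmax lookup used in the tutorial; a larger k can help when comparing TVM and PyTorch outputs that differ slightly in their scores.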

Total running time of the script: ( 1 minutes 32.330 seconds)
