Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
Compile PyTorch Models¶
Author: Alex Wong
This article is an introductory tutorial on deploying PyTorch models with Relay.
To begin, PyTorch must be installed. TorchVision is also required so that we can use the model zoo. A quick solution is to install both via pip:
pip install torch
pip install torchvision
or refer to the official site: https://pytorch.org/get-started/locally/
PyTorch versions should be backwards compatible but should be used with the proper TorchVision version.
Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may be unstable.
import tvm
from tvm import relay
import numpy as np
from tvm.contrib.download import download_testdata
# PyTorch imports
import torch
import torchvision
Load a pretrained PyTorch model¶
model_name = "resnet18"
# Look the constructor up by name in torchvision.models and request pretrained weights.
model_factory = getattr(torchvision.models, model_name)
model = model_factory(pretrained=True)
model = model.eval()

# Trace the model with a random tensor of the target input shape to obtain
# the TorchScript module that is later handed to the Relay frontend.
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data)
scripted_model = scripted_model.eval()
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /workspace/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
0%| | 0.00/44.7M [00:00<?, ?B/s]
23%|##2 | 10.1M/44.7M [00:00<00:00, 66.8MB/s]
37%|###6 | 16.4M/44.7M [00:00<00:00, 66.0MB/s]
58%|#####8 | 26.1M/44.7M [00:00<00:00, 76.8MB/s]
75%|#######4 | 33.4M/44.7M [00:00<00:00, 54.9MB/s]
88%|########7 | 39.3M/44.7M [00:00<00:00, 55.7MB/s]
100%|##########| 44.7M/44.7M [00:00<00:00, 61.8MB/s]
Load a test image¶
Classic cat example!
from PIL import Image

# Download the sample cat picture, open it, and scale it to 224x224.
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path)
img = img.resize((224, 224))
# Preprocess the image and convert to tensor
from torchvision import transforms

# Per-channel statistics passed to transforms.Normalize.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
my_preprocess = transforms.Compose(preprocess_steps)

# Run the pipeline, then prepend an axis so the array has a leading
# dimension of size 1 in front of the image dimensions.
img = np.expand_dims(my_preprocess(img), axis=0)
Import the graph to Relay¶
Convert the PyTorch graph to a Relay graph. The input name can be arbitrary.
# The same name is used again later when feeding the input to the executor.
input_name = "input0"
# One (name, shape) pair describing the single model input.
input_info = (input_name, img.shape)
shape_list = [input_info]
# The frontend returns the Relay module together with its parameters.
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
Relay Build¶
Compile the graph for the llvm target with the given input specification.
# Build for an llvm target (with an llvm host) and run on CPU device 0.
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
# relay.build must run inside the PassContext so that opt_level=3 applies to it.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
Execute the portable graph on TVM¶
Now we can try deploying the compiled model on the target.
from tvm.contrib import graph_executor

dtype = "float32"
# Instantiate the library's "default" module on the chosen device and wrap it
# in a GraphModule.
graph_factory = lib["default"]
m = graph_executor.GraphModule(graph_factory(dev))
# Set inputs
input_array = tvm.nd.array(img.astype(dtype))
m.set_input(input_name, input_array)
# Execute
m.run()
# Get outputs
tvm_output = m.get_output(0)
Look up synset name¶
Look up the top-1 prediction index in the 1000-class synset.
# Location of the synset file: one line per entry, a key followed by a
# space and the class name.
synset_url = (
    "https://raw.githubusercontent.com/Cadene/"
    "pretrained-models.pytorch/master/data/"
    "imagenet_synsets.txt"
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open(synset_path, encoding="utf-8") as f:
    synsets = [line.strip() for line in f]
# The first space-separated token is the key; the rest of the line is the name.
splits = [line.split(" ") for line in synsets]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}
# Location of the class list: the entry on line i is the key for class id i.
class_url = (
    "https://raw.githubusercontent.com/Cadene/"
    "pretrained-models.pytorch/master/data/"
    "imagenet_classes.txt"
)
class_name = "imagenet_classes.txt"
class_path = download_testdata(class_url, class_name, module="data")
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open(class_path, encoding="utf-8") as f:
    class_id_to_key = [line.strip() for line in f]
# Get top-1 result for TVM
top1_tvm = np.argmax(tvm_output.numpy()[0])
tvm_class_key = class_id_to_key[top1_tvm]

# Convert input to PyTorch variable and get PyTorch result for comparison
with torch.no_grad():
    torch_img = torch.from_numpy(img)
    output = model(torch_img)

# Get top-1 result for PyTorch. Index batch entry 0 exactly as for the TVM
# output so both ids are computed the same way.
top1_torch = np.argmax(output.numpy()[0])
torch_class_key = class_id_to_key[top1_torch]

print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))
Relay top-1 id: 281, class name: tabby, tabby cat
Torch top-1 id: 281, class name: tabby, tabby cat