Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
Compile OneFlow Models¶
Author: Xiaoyu Zhang
This article is an introductory tutorial on deploying OneFlow models with Relay.
To begin, the OneFlow package must be installed.
A quick solution is to install it via pip:
pip install flowvision==0.1.0
pip install -f https://release.oneflow.info oneflow==0.7.0+cpu
Alternatively, please refer to the official site: https://github.com/Oneflow-Inc/oneflow
Currently, TVM supports OneFlow 0.7.0. Other versions may be unstable.
import os, math
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
# oneflow imports
import flowvision
import oneflow as flow
import oneflow.nn as nn
import tvm
from tvm import relay
from tvm.contrib.download import download_testdata
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/oneflow/framework/dtype.py:48: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
oneflow.bool: np.bool,
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/flowvision/transforms/functional_pil.py:193: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
def resize(img, size, interpolation=Image.BILINEAR):
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/flowvision/transforms/functional.py:65: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
Image.NEAREST: "nearest",
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/flowvision/transforms/functional.py:66: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
Image.BILINEAR: "bilinear",
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/flowvision/transforms/functional.py:67: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
Image.BICUBIC: "bicubic",
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/flowvision/transforms/functional.py:68: DeprecationWarning: BOX is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BOX instead.
Image.BOX: "box",
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/flowvision/transforms/functional.py:69: DeprecationWarning: HAMMING is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.HAMMING instead.
Image.HAMMING: "hamming",
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/flowvision/transforms/functional.py:70: DeprecationWarning: LANCZOS is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.LANCZOS instead.
Image.LANCZOS: "lanczos",
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/flowvision/data/auto_augment.py:28: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/flowvision/data/auto_augment.py:28: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
Load a pretrained OneFlow model and save model¶
# Resolve the flowvision constructor by name, load the pretrained weights,
# and put the network into evaluation mode.
model_name = "resnet18"
model_builder = getattr(flowvision.models, model_name)
model = model_builder(pretrained=True).eval()

# Save the weights only when the target path is not already present; this
# directory is passed to relay.frontend.from_oneflow later in the tutorial.
model_dir = "resnet18_model"
if not os.path.exists(model_dir):
    flow.save(model.state_dict(), model_dir)
Downloading: "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip" to /workspace/.oneflow/flowvision_cache/resnet18.zip
0%| | 0.00/41.5M [00:00<?, ?B/s]
15%|#5 | 6.33M/41.5M [00:00<00:01, 25.3MB/s]
21%|##1 | 8.74M/41.5M [00:00<00:01, 25.2MB/s]
39%|###8 | 16.0M/41.5M [00:00<00:00, 34.4MB/s]
54%|#####3 | 22.3M/41.5M [00:00<00:00, 39.2MB/s]
63%|######2 | 26.1M/41.5M [00:00<00:00, 36.1MB/s]
77%|#######7 | 32.0M/41.5M [00:00<00:00, 40.8MB/s]
87%|########6 | 36.0M/41.5M [00:01<00:00, 36.6MB/s]
96%|#########6| 40.0M/41.5M [00:01<00:00, 34.9MB/s]
100%|##########| 41.5M/41.5M [00:01<00:00, 36.0MB/s]
Load a test image¶
Classic cat example!
from PIL import Image

# Fetch the sample picture and scale it to 224x224.
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))

# Preprocess the image and convert to tensor
from flowvision import transforms

# Per-channel normalization constants used by this tutorial.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

# Run the pipeline, then insert a new leading axis so the array is a batch of one.
img = np.expand_dims(my_preprocess(img).numpy(), 0)
Import the graph to Relay¶
Convert the OneFlow graph to a Relay graph. The input name can be arbitrary.
class Graph(flow.nn.Graph):
    """Wrap a OneFlow module in a ``flow.nn.Graph`` whose build step calls it."""

    def __init__(self, module):
        super().__init__()
        # NOTE(review): the attribute name `m` may feed into the names the
        # graph assigns to its variables -- keep it unless verified otherwise.
        self.m = module

    def build(self, x):
        """Apply the wrapped module to ``x`` and return its output."""
        return self.m(x)
# Wrap the model and compile the graph once with a random input of the
# same shape as the preprocessed image.
graph = Graph(model)
example_input = flow.randn(1, 3, 224, 224)
_ = graph._compile(example_input)

# Convert the compiled graph, together with the weights saved in model_dir,
# into a Relay module and its parameter dictionary.
mod, params = relay.frontend.from_oneflow(graph, model_dir)
Relay Build¶
Compile the graph for the LLVM target with the given input specification.
# CPU device handle and an LLVM target whose host is also LLVM.
dev = tvm.cpu(0)
target = tvm.target.Target("llvm", host="llvm")

# Build the Relay module under a pass context with optimization level 3.
build_ctx = tvm.transform.PassContext(opt_level=3)
with build_ctx:
    lib = relay.build(mod, target=target, params=params)
Execute the portable graph on TVM¶
Now we can try deploying the compiled model on the target.
# NOTE(review): this step targets CUDA device 0, whereas the Relay Build step
# earlier in the tutorial targets LLVM on the CPU -- confirm this is intended
# and that a CUDA-enabled TVM build and a GPU are available.
target = "cuda"
with tvm.transform.PassContext(opt_level=10):
    intrp = relay.build_module.create_executor("graph", mod, tvm.cuda(0), target)

print(type(img))
print(img.shape)

# Obtain the callable from the executor, then feed it the float32 image
# followed by the model parameters as keyword arguments.
run_graph = intrp.evaluate()
input_nd = tvm.nd.array(img.astype("float32"))
tvm_output = run_graph(input_nd, **params)
<class 'numpy.ndarray'>
(1, 3, 224, 224)
Look up synset name¶
Look up the top-1 prediction index in the 1000-class synset.
# Download the synset table; each line is read as "<key> <class name>".
synset_url = "".join(
    [
        "https://raw.githubusercontent.com/Cadene/",
        "pretrained-models.pytorch/master/data/",
        "imagenet_synsets.txt",
    ]
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synsets = [line.strip() for line in f]

# Everything before the first space is the key; the remainder is the name.
key_to_classname = {
    key: classname
    for key, _, classname in (entry.partition(" ") for entry in synsets)
}

# Download the class list; position i in the list holds the key of class id i.
class_url = "".join(
    [
        "https://raw.githubusercontent.com/Cadene/",
        "pretrained-models.pytorch/master/data/",
        "imagenet_classes.txt",
    ]
)
class_name = "imagenet_classes.txt"
class_path = download_testdata(class_url, class_name, module="data")
with open(class_path) as f:
    class_id_to_key = [line.strip() for line in f]
# Get top-1 result for TVM
top1_tvm = np.argmax(tvm_output.numpy()[0])
tvm_class_key = class_id_to_key[top1_tvm]

# Run the original OneFlow model on the same input, with gradients disabled,
# so its prediction can be compared against the TVM result.
with flow.no_grad():
    oneflow_img = flow.from_numpy(img)
    output = model(oneflow_img)

# Get top-1 result for OneFlow
top_oneflow = np.argmax(output.numpy())
oneflow_class_key = class_id_to_key[top_oneflow]

# Report both predictions; each name is looked up right before it is printed.
tvm_class_name = key_to_classname[tvm_class_key]
print("Relay top-1 id: {}, class name: {}".format(top1_tvm, tvm_class_name))
oneflow_class_name = key_to_classname[oneflow_class_key]
print("OneFlow top-1 id: {}, class name: {}".format(top_oneflow, oneflow_class_name))
Relay top-1 id: 281, class name: tabby, tabby cat
OneFlow top-1 id: 281, class name: tabby, tabby cat