Compile OneFlow Models

Author: Xiaoyu Zhang

This article is an introductory tutorial on deploying OneFlow models with Relay.

To begin, the OneFlow package should be installed.

A quick solution is to install it via pip:

pip install flowvision==0.1.0
python3 -m pip install -f https://release.oneflow.info oneflow==0.7.0+cpu

Alternatively, please refer to the official site: https://github.com/Oneflow-Inc/oneflow

Currently, TVM supports OneFlow 0.7.0. Other versions may be unstable.

import os, math
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image

# oneflow imports
import flowvision
import oneflow as flow
import oneflow.nn as nn

import tvm
from tvm import relay
from tvm.contrib.download import download_testdata
/usr/local/lib/python3.7/dist-packages/flowvision/transforms/functional_pil.py:193: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  def resize(img, size, interpolation=Image.BILINEAR):
/usr/local/lib/python3.7/dist-packages/flowvision/transforms/functional.py:65: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  Image.NEAREST: "nearest",
/usr/local/lib/python3.7/dist-packages/flowvision/transforms/functional.py:66: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  Image.BILINEAR: "bilinear",
/usr/local/lib/python3.7/dist-packages/flowvision/transforms/functional.py:67: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  Image.BICUBIC: "bicubic",
/usr/local/lib/python3.7/dist-packages/flowvision/transforms/functional.py:68: DeprecationWarning: BOX is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BOX instead.
  Image.BOX: "box",
/usr/local/lib/python3.7/dist-packages/flowvision/transforms/functional.py:69: DeprecationWarning: HAMMING is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.HAMMING instead.
  Image.HAMMING: "hamming",
/usr/local/lib/python3.7/dist-packages/flowvision/transforms/functional.py:70: DeprecationWarning: LANCZOS is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.LANCZOS instead.
  Image.LANCZOS: "lanczos",
/usr/local/lib/python3.7/dist-packages/flowvision/data/auto_augment.py:28: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
/usr/local/lib/python3.7/dist-packages/flowvision/data/auto_augment.py:28: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)

Load a pretrained OneFlow model and save it

# Look up the model constructor on flowvision.models by name, build it with
# pretrained=True, and switch the result to evaluation mode.
model_name = "resnet18"
model = getattr(flowvision.models, model_name)(pretrained=True).eval()

# Save the model's state dict only when nothing exists at the target path yet;
# the same path is handed to the Relay frontend further down.
model_dir = "resnet18_model"
if not os.path.exists(model_dir):
    state = model.state_dict()
    flow.save(state, model_dir)
Downloading: "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip" to /workspace/.oneflow/flowvision_cache/resnet18.zip

  0%|          | 0.00/41.5M [00:00<?, ?B/s]
  0%|          | 16.0k/41.5M [00:00<08:10, 88.7kB/s]
  0%|          | 32.0k/41.5M [00:00<08:11, 88.4kB/s]
  0%|          | 48.0k/41.5M [00:00<08:12, 88.3kB/s]
  0%|          | 64.0k/41.5M [00:00<08:11, 88.3kB/s]
  0%|          | 80.0k/41.5M [00:00<08:11, 88.3kB/s]
  0%|          | 96.0k/41.5M [00:01<08:11, 88.3kB/s]
  0%|          | 112k/41.5M [00:01<08:11, 88.2kB/s]
  0%|          | 128k/41.5M [00:01<08:11, 88.2kB/s]
  0%|          | 144k/41.5M [00:01<08:11, 88.3kB/s]
  0%|          | 160k/41.5M [00:01<08:11, 88.2kB/s]
  0%|          | 176k/41.5M [00:02<08:10, 88.3kB/s]
  0%|          | 200k/41.5M [00:02<07:05, 102kB/s]
  1%|          | 224k/41.5M [00:02<06:29, 111kB/s]
  1%|          | 248k/41.5M [00:02<06:08, 117kB/s]
  1%|          | 272k/41.5M [00:02<05:54, 122kB/s]
  1%|          | 296k/41.5M [00:02<05:45, 125kB/s]
  1%|          | 320k/41.5M [00:03<05:39, 127kB/s]
  1%|          | 352k/41.5M [00:03<05:03, 142kB/s]
  1%|          | 384k/41.5M [00:03<04:43, 152kB/s]
  1%|          | 424k/41.5M [00:03<04:09, 173kB/s]
  1%|1         | 464k/41.5M [00:03<03:49, 187kB/s]
  1%|1         | 504k/41.5M [00:04<03:37, 197kB/s]
  1%|1         | 560k/41.5M [00:04<03:06, 230kB/s]
  1%|1         | 616k/41.5M [00:04<02:48, 254kB/s]
  2%|1         | 680k/41.5M [00:04<02:30, 284kB/s]
  2%|1         | 744k/41.5M [00:04<02:20, 304kB/s]
  2%|1         | 824k/41.5M [00:05<02:03, 345kB/s]
  2%|2         | 904k/41.5M [00:05<01:53, 374kB/s]
  2%|2         | 992k/41.5M [00:05<01:44, 408kB/s]
  3%|2         | 1.06M/41.5M [00:05<01:35, 444kB/s]
  3%|2         | 1.17M/41.5M [00:05<01:25, 496kB/s]
  3%|3         | 1.29M/41.5M [00:05<01:17, 546kB/s]
  3%|3         | 1.41M/41.5M [00:06<01:10, 594kB/s]
  4%|3         | 1.55M/41.5M [00:06<01:05, 641kB/s]
  4%|4         | 1.70M/41.5M [00:06<00:59, 700kB/s]
  4%|4         | 1.86M/41.5M [00:06<00:54, 768kB/s]
  5%|4         | 2.03M/41.5M [00:06<00:49, 829kB/s]
  5%|5         | 2.21M/41.5M [00:07<00:46, 885kB/s]
  6%|5         | 2.41M/41.5M [00:07<00:42, 963kB/s]
  6%|6         | 2.62M/41.5M [00:07<00:39, 1.03MB/s]
  7%|6         | 2.86M/41.5M [00:07<00:36, 1.12MB/s]
  7%|7         | 3.11M/41.5M [00:07<00:33, 1.21MB/s]
  8%|8         | 3.38M/41.5M [00:07<00:30, 1.30MB/s]
  9%|8         | 3.66M/41.5M [00:08<00:28, 1.40MB/s]
 10%|9         | 3.97M/41.5M [00:08<00:26, 1.49MB/s]
 10%|#         | 4.30M/41.5M [00:08<00:24, 1.60MB/s]
 11%|#1        | 4.64M/41.5M [00:08<00:22, 1.70MB/s]
 12%|#2        | 5.00M/41.5M [00:08<00:21, 1.80MB/s]
 13%|#2        | 5.38M/41.5M [00:09<00:19, 1.90MB/s]
 14%|#3        | 5.78M/41.5M [00:09<00:18, 2.02MB/s]
 15%|#5        | 6.23M/41.5M [00:09<00:17, 2.17MB/s]
 16%|#6        | 6.69M/41.5M [00:09<00:15, 2.30MB/s]
 17%|#7        | 7.19M/41.5M [00:09<00:14, 2.45MB/s]
 19%|#8        | 7.71M/41.5M [00:10<00:13, 2.60MB/s]
 20%|#9        | 8.27M/41.5M [00:10<00:12, 2.76MB/s]
 21%|##1       | 8.85M/41.5M [00:10<00:11, 2.93MB/s]
 23%|##2       | 9.46M/41.5M [00:10<00:09, 3.40MB/s]
 24%|##4       | 10.1M/41.5M [00:10<00:08, 3.97MB/s]
 25%|##5       | 10.5M/41.5M [00:10<00:08, 3.68MB/s]
 26%|##6       | 10.9M/41.5M [00:10<00:10, 3.17MB/s]
 28%|##7       | 11.5M/41.5M [00:11<00:08, 3.64MB/s]
 29%|##9       | 12.2M/41.5M [00:11<00:06, 4.53MB/s]
 31%|###       | 12.7M/41.5M [00:11<00:07, 4.18MB/s]
 32%|###1      | 13.1M/41.5M [00:11<00:08, 3.58MB/s]
 33%|###3      | 13.8M/41.5M [00:11<00:07, 3.79MB/s]
 35%|###5      | 14.7M/41.5M [00:11<00:06, 4.15MB/s]
 38%|###7      | 15.6M/41.5M [00:12<00:06, 4.46MB/s]
 40%|###9      | 16.5M/41.5M [00:12<00:05, 4.77MB/s]
 42%|####2     | 17.5M/41.5M [00:12<00:04, 5.45MB/s]
 45%|####4     | 18.5M/41.5M [00:12<00:04, 5.99MB/s]
 46%|####6     | 19.1M/41.5M [00:12<00:03, 5.94MB/s]
 48%|####7     | 19.7M/41.5M [00:12<00:04, 5.04MB/s]
 50%|#####     | 20.8M/41.5M [00:12<00:03, 6.16MB/s]
 53%|#####3    | 22.0M/41.5M [00:13<00:02, 6.90MB/s]
 55%|#####4    | 22.7M/41.5M [00:13<00:02, 6.84MB/s]
 56%|#####6    | 23.4M/41.5M [00:13<00:03, 5.77MB/s]
 59%|#####9    | 24.7M/41.5M [00:13<00:02, 7.10MB/s]
 63%|######2   | 26.0M/41.5M [00:13<00:01, 8.34MB/s]
 65%|######4   | 26.9M/41.5M [00:13<00:01, 7.90MB/s]
 67%|######6   | 27.6M/41.5M [00:13<00:02, 6.67MB/s]
 70%|######9   | 29.0M/41.5M [00:14<00:01, 7.93MB/s]
 73%|#######3  | 30.4M/41.5M [00:14<00:01, 9.11MB/s]
 76%|#######5  | 31.4M/41.5M [00:14<00:01, 8.62MB/s]
 78%|#######7  | 32.2M/41.5M [00:14<00:01, 7.26MB/s]
 81%|########  | 33.4M/41.5M [00:14<00:01, 8.09MB/s]
 84%|########4 | 34.9M/41.5M [00:14<00:00, 9.22MB/s]
 86%|########6 | 35.8M/41.5M [00:14<00:00, 8.71MB/s]
 88%|########8 | 36.7M/41.5M [00:15<00:00, 7.34MB/s]
 91%|#########1| 37.9M/41.5M [00:15<00:00, 8.10MB/s]
 95%|#########4| 39.3M/41.5M [00:15<00:00, 9.40MB/s]
 97%|#########7| 40.2M/41.5M [00:15<00:00, 8.73MB/s]
 99%|#########9| 41.1M/41.5M [00:15<00:00, 7.37MB/s]
100%|##########| 41.5M/41.5M [00:15<00:00, 2.79MB/s]

Load a test image

Classic cat example!

from PIL import Image
from flowvision import transforms

# Download the sample picture and scale it to 224x224.
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
pil_img = Image.open(img_path).resize((224, 224))

# Preprocess the image and convert to tensor: resize, center-crop, convert,
# then normalize with the per-channel mean / std given below.
preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
my_preprocess = transforms.Compose(preprocess_steps)

# Convert the preprocessed tensor to NumPy and add a leading axis of size 1.
img = np.expand_dims(my_preprocess(pil_img).numpy(), axis=0)

Import the graph to Relay

Convert the OneFlow graph to a Relay graph. The input name can be arbitrary.

class Graph(flow.nn.Graph):
    """A ``flow.nn.Graph`` whose build step simply calls a wrapped module."""

    def __init__(self, module):
        """Initialize the base graph and keep ``module`` as ``self.m``."""
        super().__init__()
        self.m = module

    def build(self, x):
        """Return the result of applying the wrapped module to ``x``."""
        return self.m(x)


# Wrap the model in a Graph and run its compile step once on a random
# (1, 3, 224, 224) input before passing the graph to the Relay frontend.
graph = Graph(model)
example_input = flow.randn(1, 3, 224, 224)
_ = graph._compile(example_input)

# The frontend receives the graph together with the path stored in model_dir.
mod, params = relay.frontend.from_oneflow(graph, model_dir)

Relay Build

Compile the graph for the llvm target with the given input specification.

# CPU device 0, with "llvm" as both the target and its host.
dev = tvm.cpu(0)
target = tvm.target.Target("llvm", host="llvm")

# Build the Relay module under a pass context with optimization level 3.
build_ctx = tvm.transform.PassContext(opt_level=3)
with build_ctx:
    lib = relay.build(mod, target=target, params=params)
/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "

Execute the portable graph on TVM

Now we can try deploying the compiled model on the target device.

target = "cuda"
# The graph executor compiles lazily: the Relay build is triggered by
# ``evaluate()``, not by ``create_executor``. Both calls are therefore made
# inside the PassContext so that the requested optimization level is actually
# in effect while the model is compiled. Level 3 is the same level used for
# the ``relay.build`` call earlier in this tutorial.
with tvm.transform.PassContext(opt_level=3):
    intrp = relay.build_module.create_executor("graph", mod, tvm.cuda(0), target)
    run_model = intrp.evaluate()

print(type(img))
print(img.shape)
# The image is passed positionally; the weights are supplied as keyword
# arguments taken from ``params``.
tvm_output = run_model(tvm.nd.array(img.astype("float32")), **params)
<class 'numpy.ndarray'>
(1, 3, 224, 224)

Look up synset name

Look up the top-1 prediction index in the 1000-class synset.

# Download the synset file; each line is treated as "<key> <class name>".
synset_url = (
    "https://raw.githubusercontent.com/Cadene/"
    "pretrained-models.pytorch/master/data/"
    "imagenet_synsets.txt"
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")

# Read the file line by line, dropping surrounding whitespace.
with open(synset_path) as f:
    synsets = [line.strip() for line in f]

# The first space-separated token is the key; the remaining tokens, joined
# back with single spaces, form the class name.
splits = [entry.split(" ") for entry in synsets]
key_to_classname = {key: " ".join(words) for key, *words in splits}

# Download the class list; position in the list is used as the class id and
# the stripped line content as the key.
class_url = (
    "https://raw.githubusercontent.com/Cadene/"
    "pretrained-models.pytorch/master/data/"
    "imagenet_classes.txt"
)
class_name = "imagenet_classes.txt"
class_path = download_testdata(class_url, class_name, module="data")

with open(class_path) as f:
    class_id_to_key = [line.strip() for line in f]

# Get top-1 result for TVM (index 0 selects the single image in the batch)
top1_tvm = np.argmax(tvm_output.numpy()[0])
tvm_class_key = class_id_to_key[top1_tvm]

# Convert input to a OneFlow tensor and get the OneFlow result for comparison
with flow.no_grad():
    oneflow_img = flow.from_numpy(img)
    output = model(oneflow_img)

    # Get top-1 result for OneFlow. Index batch row 0 exactly as on the TVM
    # path: an argmax over the flattened output only yields a valid class id
    # when the batch holds a single image.
    top_oneflow = np.argmax(output.numpy()[0])
    oneflow_class_key = class_id_to_key[top_oneflow]

print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print(
    "OneFlow top-1 id: {}, class name: {}".format(top_oneflow, key_to_classname[oneflow_class_key])
)
Relay top-1 id: 281, class name: tabby, tabby cat
OneFlow top-1 id: 281, class name: tabby, tabby cat

Gallery generated by Sphinx-Gallery