Compile PyTorch Object Detection Models

This article is an introductory tutorial to deploy PyTorch object detection models with Relay VM.

To begin, PyTorch must be installed. TorchVision is also required, since we will use it as our model zoo.

A quick solution is to install them via pip:

pip install torch
pip install torchvision

Alternatively, refer to the official site: https://pytorch.org/get-started/locally/

PyTorch versions should be backwards compatible, but each should be used with the matching TorchVision version.

Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may be unstable.

import tvm
from tvm import relay
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download_testdata

import numpy as np
import cv2

# PyTorch imports
import torch
import torchvision

Load a pre-trained Mask R-CNN from torchvision and trace it

# Side length, in pixels, of the square image used as the model input.
in_size = 300

# Shape of the model input: a batch of one 3-channel in_size x in_size image.
input_shape = (1, 3, in_size, in_size)


def do_trace(model, inp):
    """Trace ``model`` with ``torch.jit.trace`` and return it in eval mode.

    Parameters
    ----------
    model : torch.nn.Module
        The module to trace.
    inp : torch.Tensor
        Example input passed to the tracer.

    Returns
    -------
    The traced module, after ``eval()`` has been called on it.
    """
    # ``eval()`` returns the module it was called on, so the trace result can
    # be switched to evaluation mode and handed back in one expression.
    return torch.jit.trace(model, inp).eval()


def dict_to_tuple(out_dict):
    """Flatten a detection output dict into a tuple.

    Returns ``(boxes, scores, labels)``, with ``masks`` appended as a fourth
    element when the dict has a ``"masks"`` entry.
    """
    fields = ["boxes", "scores", "labels"]
    if "masks" in out_dict:
        fields.append("masks")
    return tuple(out_dict[name] for name in fields)


class TraceWrapper(torch.nn.Module):
    """Adapter that makes a detection model return a tuple.

    The wrapped model is called on the input, and the first element of its
    output is converted from a dict to a tuple by ``dict_to_tuple``.
    """

    def __init__(self, model):
        super().__init__()
        # Keep the wrapped model as a submodule.
        self.model = model

    def forward(self, inp):
        """Run the wrapped model and return its first result as a tuple."""
        first_result = self.model(inp)[0]
        return dict_to_tuple(first_result)


# Build torchvision's Mask R-CNN (ResNet-50 FPN) with pre-trained weights and
# wrap it so that its forward pass returns a tuple; ``eval()`` returns the
# wrapper itself, so ``model`` is the TraceWrapper in evaluation mode.
# NOTE: torchvision >= 0.13 warns that ``pretrained`` is deprecated in favour
# of ``weights``; the argument is left unchanged here.
model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True)).eval()

# Random example input with the same shape as ``input_shape``.
inp = torch.Tensor(np.random.uniform(low=0.0, high=250.0, size=input_shape))

# Run the model once eagerly, then trace it, both without gradient tracking.
with torch.no_grad():
    out = model(inp)
    script_module = do_trace(model, inp)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1`. You can also use `weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /workspace/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

  0%|          | 0.00/170M [00:00<?, ?B/s]
 10%|9         | 16.2M/170M [00:00<00:00, 170MB/s]
 24%|##3       | 39.9M/170M [00:00<00:00, 216MB/s]
 38%|###7      | 64.3M/170M [00:00<00:00, 234MB/s]
 53%|#####2    | 89.9M/170M [00:00<00:00, 247MB/s]
 69%|######8   | 117M/170M [00:00<00:00, 260MB/s]
 85%|########5 | 145M/170M [00:00<00:00, 272MB/s]
100%|##########| 170M/170M [00:00<00:00, 260MB/s]
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py:75: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for img in images:
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/transform.py:110: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  images = [img for img in images]
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/transform.py:155: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/transform.py:156: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/nn/functional.py:3912: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/nn/functional.py:3912: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/_utils.py:176: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if box_sum > 0:
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/_utils.py:216: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/_utils.py:217: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/_utils.py:179: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if box_sum > 0:
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/_utils.py:519: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/rpn.py:275: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/boxes.py:156: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/boxes.py:157: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/boxes.py:157: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/boxes.py:158: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/boxes.py:159: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/boxes.py:159: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/boxes.py:72: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/poolers.py:82: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/poolers.py:185: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  num_rois = len(rois)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/__init__.py:1209: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/transform.py:298: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(s, dtype=torch.float32, device=boxes.device)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/transform.py:298: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(s, dtype=torch.float32, device=boxes.device)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/transform.py:299: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/transform.py:299: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py:389: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py:389: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)

Download a test image and pre-process

# Fetch the sample street image and get its local path.
img_url = (
    "https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
)
img_path = download_testdata(img_url, "test_street_small.jpg", module="data")

# Read the image as float32 and resize it to in_size x in_size.
img = cv2.resize(cv2.imread(img_path).astype("float32"), (in_size, in_size))
# Convert OpenCV's BGR channel order to RGB and divide by 255.0.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
# Move channels first (HWC -> CHW), then add a leading batch axis.
img = img.transpose(2, 0, 1)[np.newaxis, ...]

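Import the graph to Relay

input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(script_module, shape_list)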

/workspace/python/tvm/relay/build_module.py:345: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)
  warnings.warn(

Compile with Relay VM

Note: Currently only the CPU target is supported. For an x86 target, it is highly recommended to build TVM with Intel MKL and Intel OpenMP to get the best performance, because the torchvision R-CNN models contain large dense operators.

# Add "-libs=mkl" to the target string to get the best performance on x86.
# For an x86 machine that supports AVX512, the complete target is
# "llvm -mcpu=skylake-avx512 -libs=mkl"
target = "llvm"

# Compile the module for the Relay VM at opt_level 3 with the FoldScaleAxis
# pass explicitly disabled.
# NOTE(review): the reason for disabling FoldScaleAxis is not stated here;
# confirm before removing it.
with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
    vm_exec = relay.vm.compile(mod, target=target, params=params)

Inference with Relay VM

# Create a virtual machine for the compiled executable on the CPU device.
dev = tvm.cpu()
vm = VirtualMachine(vm_exec, dev)
# Bind the pre-processed image to the "main" function's input, using
# input_name as the keyword.
vm.set_input("main", **{input_name: img})
# Run inference; the result is indexed below as [0] = boxes, [1] = scores.
tvm_res = vm.run()

Get boxes with score larger than 0.9

score_threshold = 0.9
boxes = tvm_res[0].numpy().tolist()
scores = tvm_res[1].numpy().tolist()

# Count the leading run of scores above the threshold: scanning stops at the
# first score that is not larger than score_threshold, so anything after that
# position is ignored even if it would pass.
num_valid = next(
    (idx for idx, score in enumerate(scores) if not score > score_threshold),
    len(scores),
)
valid_boxes = boxes[:num_valid]

print("Get {} valid boxes".format(len(valid_boxes)))
Get 9 valid boxes

Total running time of the script: ( 3 minutes 29.276 seconds)

Gallery generated by Sphinx-Gallery