Compile PyTorch Object Detection Models
This article is an introductory tutorial on deploying PyTorch object detection models with the Relay VM.
To begin, PyTorch must be installed. TorchVision is also required, since we use it as our model zoo.
A quick solution is to install both via pip:
pip install torch
pip install torchvision
or refer to the official site: https://pytorch.org/get-started/locally/
PyTorch versions should be backward compatible, but each one should be paired with its matching TorchVision version.
Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may be unstable.
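Because the supported versions matter, it can help to print what is actually installed before running the rest of the tutorial. The sketch below uses only the standard library; `installed_version`, `major_minor`, `torch_is_tested` and `TESTED_TORCH` are hypothetical names, and treating "1.4" and "1.7" as the tested set simply mirrors the sentence above. It does not know which TorchVision release pairs with which PyTorch release.

```python
# Hypothetical helpers (not part of TVM or PyTorch): report installed versions
# and flag a torch release outside the two this tutorial names.
import re

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # older interpreters: skip the lookup entirely
    PackageNotFoundError = Exception
    version = None

TESTED_TORCH = {(1, 4), (1, 7)}  # the versions named in the text above


def installed_version(dist_name):
    """Return the installed distribution's version string, or None."""
    if version is None:
        return None
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def major_minor(ver):
    """Parse strings such as '1.7.1+cu110' into (1, 7); None if unparsable."""
    match = re.match(r"(\d+)\.(\d+)", ver or "")
    return (int(match.group(1)), int(match.group(2))) if match else None


def torch_is_tested(ver):
    return major_minor(ver) in TESTED_TORCH


for name in ("torch", "torchvision"):
    print(name, installed_version(name))
torch_ver = installed_version("torch")
if torch_ver and not torch_is_tested(torch_ver):
    print("torch", torch_ver, "is not one of the versions named above; results may vary")
```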
import tvm
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download_testdata
import numpy as np
import cv2
# PyTorch imports
import torch
import torchvision
Load a pre-trained Mask R-CNN from torchvision and trace it
in_size = 300
input_shape = (1, 3, in_size, in_size)
def do_trace(model, inp):
    model_trace = torch.jit.trace(model, inp)
    model_trace.eval()
    return model_trace
def dict_to_tuple(out_dict):
    # Fixed output order: boxes, scores, labels (plus masks for Mask R-CNN).
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]
class TraceWrapper(torch.nn.Module):
    # Detection models return a list of per-image dicts; expose the first
    # image's result as a flat tuple so that it can be traced.
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, inp):
        out = self.model(inp)
        return dict_to_tuple(out[0])
model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True))
model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))
with torch.no_grad():
    out = model(inp)
    script_module = do_trace(model, inp)
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1`. You can also use `weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /workspace/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
100%|##########| 170M/170M [00:03<00:00, 45.6MB/s]
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torch/nn/functional.py:3897: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
for i in range(dim)
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/models/detection/anchor_utils.py:124: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
for g in grid_sizes
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/models/detection/rpn.py:99: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
A = Ax4 // 4
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/models/detection/rpn.py:100: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
C = AxC // A
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/ops/boxes.py:157: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/ops/boxes.py:159: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torch/__init__.py:833: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert condition, message
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/models/detection/transform.py:300: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
for s, s_orig in zip(new_size, original_size)
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/torchvision/models/detection/roi_heads.py:389: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
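The TracerWarning above ("Converting a tensor to a Python boolean might cause the trace to be incorrect") comes from how tracing works: `torch.jit.trace` runs the model once on the example input and records the operations it observes, so any Python branch that depends on tensor values is decided once and frozen into the trace. The toy below is not how PyTorch implements tracing; it is a pure-Python sketch of the same idea, and `TracedValue`, `trace` and `replay` are hypothetical names.

```python
# Toy tracer (pure Python, not PyTorch internals): record arithmetic ops
# applied to a wrapped value, then replay them on a new input.


class TracedValue:
    """Wraps a concrete number and logs the ops applied to it."""

    def __init__(self, value, log):
        self.value = value
        self.log = log

    def __add__(self, other):
        self.log.append(("add", other))
        return TracedValue(self.value + other, self.log)

    def __mul__(self, other):
        self.log.append(("mul", other))
        return TracedValue(self.value * other, self.log)

    def __gt__(self, other):
        # Comparisons are evaluated eagerly on the example value and not logged.
        return self.value > other


def model(x):
    if x > 0:  # data-dependent branch: only one side gets recorded
        return x * 2
    return x + 100


def trace(fn, example):
    log = []
    fn(TracedValue(example, log))
    return log


def replay(log, x):
    for op, arg in log:
        x = x * arg if op == "mul" else x + arg
    return x


graph = trace(model, 3)   # the example input takes the "x > 0" branch
print(graph)              # [('mul', 2)]
print(replay(graph, -5))  # -10: the branch was frozen at trace time
print(model(-5))          # 95: eager execution takes the other branch
```

For the detection model this is usually acceptable as long as real inputs share the traced input's shape (`in_size` × `in_size` here), but it is one reason a trace "might not generalize to other inputs".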
Download a test image and pre-process
img_url = (
"https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
)
img_path = download_testdata(img_url, "test_street_small.jpg", module="data")
img = cv2.imread(img_path).astype("float32")
img = cv2.resize(img, (in_size, in_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img / 255.0, [2, 0, 1])
img = np.expand_dims(img, axis=0)
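The pre-processing turns OpenCV's height × width × channel, BGR image into the 1 × 3 × H × W, RGB, [0, 1]-scaled float32 tensor that `input_shape` describes. The NumPy-only sketch below reproduces those steps on a synthetic array standing in for `cv2.imread` output (assumed HWC, BGR); channel reversal replaces `cv2.cvtColor`, and resizing is skipped.

```python
# NumPy-only sketch of the pre-processing above; a synthetic array stands in
# for cv2.imread output (assumed HWC, BGR channel order). Resizing is skipped.
import numpy as np

in_size = 300
rng = np.random.default_rng(0)
bgr = rng.integers(0, 256, size=(in_size, in_size, 3)).astype(np.uint8)

img = bgr.astype("float32")
img = img[..., ::-1]                        # BGR -> RGB, like cv2.COLOR_BGR2RGB
img = np.transpose(img / 255.0, [2, 0, 1])  # HWC -> CHW, scaled to [0, 1]
img = np.expand_dims(img, axis=0)           # add the batch dimension -> NCHW

print(img.shape, img.dtype)  # (1, 3, 300, 300) float32
```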
Import the graph to Relay
input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(script_module, shape_list)
/workspace/python/tvm/relay/frontend/pytorch_utils.py:47: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
return LooseVersion(torch_ver) > ver
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/setuptools/_distutils/version.py:346: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
other = LooseVersion(other)
/workspace/python/tvm/relay/build_module.py:348: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)
DeprecationWarning,
/workspace/python/tvm/relay/frontend/pytorch.py:451: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
inputs[3], lambda ret: ret.astype(np.int).item(0)
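`relay.frontend.from_pytorch` takes a list of `(input_name, shape)` pairs, and the name chosen there is the keyword later passed to `vm.set_input`. A mismatch in either the name or the shape only shows up at inference time, so a small pre-flight check can help. `check_input` is a hypothetical helper, not a TVM API.

```python
# Hypothetical pre-flight check (not a TVM API): confirm an array matches the
# (name, shape) pairs handed to relay.frontend.from_pytorch before set_input.
import numpy as np


def check_input(shape_list, name, array):
    expected = dict(shape_list)
    if name not in expected:
        raise KeyError(f"unknown input {name!r}; expected one of {list(expected)}")
    if tuple(array.shape) != tuple(expected[name]):
        raise ValueError(f"{name}: got shape {array.shape}, expected {expected[name]}")
    return array


input_name = "input0"
shape_list = [(input_name, (1, 3, 300, 300))]
check_input(shape_list, input_name, np.zeros((1, 3, 300, 300), dtype="float32"))
```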
Compile with Relay VM
Note: Currently, only the CPU target is supported. For x86 targets, it is highly recommended to build TVM with Intel MKL and Intel OpenMP to get the best performance, because the torchvision R-CNN models contain large dense operators.
# Add "-libs=mkl" to get the best performance on x86 targets.
# For x86 machines that support AVX512, the complete target is
# "llvm -mcpu=skylake-avx512 -libs=mkl"
target = "llvm"
with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
    vm_exec = relay.vm.compile(mod, target=target, params=params)
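The comments above describe the target string as `llvm` plus optional `-mcpu=...` and `-libs=mkl` flags. A sketch of assembling it follows; `make_llvm_target` and `has_avx512f` are hypothetical helpers, mapping the presence of the `avx512f` flag to `skylake-avx512` is a simplification (other CPUs may warrant a different `-mcpu`), and `-libs=mkl` only helps if TVM was built with MKL. On Linux the flags could come from `open("/proc/cpuinfo").read()`; a hard-coded sample is used here.

```python
# Hypothetical helpers (not TVM APIs) for assembling the LLVM target string
# described in the comments above.


def make_llvm_target(mcpu=None, use_mkl=False):
    parts = ["llvm"]
    if mcpu:
        parts.append(f"-mcpu={mcpu}")
    if use_mkl:
        parts.append("-libs=mkl")  # assumes TVM was built with MKL support
    return " ".join(parts)


def has_avx512f(cpuinfo_text):
    """Look for the avx512f flag in Linux /proc/cpuinfo-style text."""
    for line in cpuinfo_text.splitlines():
        if line.startswith("flags"):
            return "avx512f" in line.split(":", 1)[1].split()
    return False


sample = "processor\t: 0\nflags\t\t: fpu sse sse2 avx2 avx512f avx512bw\n"
mcpu = "skylake-avx512" if has_avx512f(sample) else None
print(make_llvm_target(mcpu=mcpu, use_mkl=True))
# llvm -mcpu=skylake-avx512 -libs=mkl
```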
Inference with Relay VM
dev = tvm.cpu()
vm = VirtualMachine(vm_exec, dev)
vm.set_input("main", **{input_name: img})
tvm_res = vm.run()
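Mask R-CNN's per-image dictionary holds boxes, scores, labels and masks. Because `TraceWrapper` passes it through `dict_to_tuple`, the traced graph, and therefore the VM result, should follow that tuple order, which is why `tvm_res[0]` is read as boxes and `tvm_res[1]` as scores in the next step. The snippet exercises the same ordering rule on NumPy arrays standing in for tensors; the number of detections and the `(N, 1, H, W)` mask shape are illustrative assumptions.

```python
# NumPy stand-in for one torchvision detection result; shapes are illustrative
# and not taken from a real run.
import numpy as np


def dict_to_tuple(out_dict):
    # Same ordering rule as the tutorial's helper.
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]


n = 5
fake_result = {
    "masks": np.zeros((n, 1, 300, 300), dtype=np.float32),
    "labels": np.ones(n, dtype=np.int64),
    "scores": np.linspace(0.99, 0.5, n).astype(np.float32),
    "boxes": np.zeros((n, 4), dtype=np.float32),
}

# Dict insertion order does not matter; the tuple order is fixed by the helper.
outputs = dict_to_tuple(fake_result)
boxes, scores, labels, masks = outputs
print(len(outputs), boxes.shape, scores.shape, masks.shape)
# 4 (5, 4) (5,) (5, 1, 300, 300)

no_masks = dict_to_tuple({k: v for k, v in fake_result.items() if k != "masks"})
print(len(no_masks))  # 3
```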
Get boxes with scores above 0.9
score_threshold = 0.9
boxes = tvm_res[0].numpy().tolist()
valid_boxes = []
for i, score in enumerate(tvm_res[1].numpy().tolist()):
    if score > score_threshold:
        valid_boxes.append(boxes[i])
    else:
        # Detections are sorted by descending score, so stop at the first miss.
        break
print("Get {} valid boxes".format(len(valid_boxes)))
Get 9 valid boxes
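The loop above stops at the first score below the threshold, so it relies on detections arriving sorted by descending score (torchvision's post-processing appears to return them in that order after NMS). A boolean mask gives the same result without depending on ordering. The arrays below are synthetic stand-ins for `tvm_res[0].numpy()` (boxes as x1, y1, x2, y2) and `tvm_res[1].numpy()` (scores), deliberately left unsorted to show the difference.

```python
import numpy as np

score_threshold = 0.9
# Synthetic stand-ins for tvm_res[0].numpy() and tvm_res[1].numpy().
boxes = np.array(
    [[10, 10, 50, 80], [60, 20, 90, 70], [5, 5, 20, 20], [100, 40, 150, 120]],
    dtype=np.float32,
)
scores = np.array([0.98, 0.95, 0.42, 0.91], dtype=np.float32)  # unsorted on purpose

keep = scores > score_threshold
valid_boxes = boxes[keep].tolist()
print("Got {} valid boxes".format(len(valid_boxes)))  # Got 3 valid boxes

# The early-exit loop stops at 0.42 and misses the last box.
loop_boxes = []
for box, score in zip(boxes.tolist(), scores.tolist()):
    if score > score_threshold:
        loop_boxes.append(box)
    else:
        break
print(len(loop_boxes))  # 2
```

With the sorted output the model actually produces, both approaches return the same boxes; the mask version is simply robust to ordering.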
Total running time of the script: ( 3 minutes 50.304 seconds)