.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_deploy_models_deploy_object_detection_pytorch.py:


Compile PyTorch Object Detection Models
=======================================
This article is an introductory tutorial on deploying PyTorch object detection
models with Relay VM.

To begin, PyTorch should be installed.
TorchVision is also required, since we will use it as our model zoo.

A quick solution is to install both via pip:

.. code-block:: bash

    pip install torch==1.7.0
    pip install torchvision==0.8.1

or refer to the official site:
https://pytorch.org/get-started/locally/

PyTorch versions should be backwards compatible, but each should be used
with the matching TorchVision version.

Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may
be unstable.


.. code-block:: default


    import tvm
    from tvm import relay
    from tvm.runtime.vm import VirtualMachine
    from tvm.contrib.download import download_testdata

    import numpy as np
    import cv2

    # PyTorch imports
    import torch
    import torchvision


Load pre-trained maskrcnn from torchvision and do tracing
---------------------------------------------------------


.. code-block:: default

    in_size = 300

    input_shape = (1, 3, in_size, in_size)


    def do_trace(model, inp):
        model_trace = torch.jit.trace(model, inp)
        model_trace.eval()
        return model_trace


    def dict_to_tuple(out_dict):
        # Tracing needs tensor (tuple) outputs, so unpack the detection dict.
        if "masks" in out_dict.keys():
            return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"]


    class TraceWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, inp):
            out = self.model(inp)
            # The model returns one dict per image; the batch holds a single image.
            return dict_to_tuple(out[0])


    model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
    model = TraceWrapper(model_func(pretrained=True))

    model.eval()
    # Random input, used only to drive the trace.
    inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))

    with torch.no_grad():
        out = model(inp)
        script_module = do_trace(model, inp)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /usr/local/lib/python3.6/dist-packages/torch/tensor.py:593: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
      'incorrect results).', category=RuntimeWarning)
    /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:3123: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      dtype=torch.float32)).float())) for i in range(dim)]
    /usr/local/lib/python3.6/dist-packages/torchvision/models/detection/anchor_utils.py:147: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
    /usr/local/lib/python3.6/dist-packages/torchvision/ops/boxes.py:128: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
    /usr/local/lib/python3.6/dist-packages/torchvision/ops/boxes.py:130: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
    /usr/local/lib/python3.6/dist-packages/torchvision/models/detection/transform.py:271: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      for s, s_orig in zip(new_size, original_size)
    /usr/local/lib/python3.6/dist-packages/torchvision/models/detection/roi_heads.py:372: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)


Download a test image and pre-process
-------------------------------------


.. code-block:: default

    img_url = (
        "https://raw.githubusercontent.com/dmlc/web-data/"
        "master/gluoncv/detection/street_small.jpg"
    )
    img_path = download_testdata(img_url, "test_street_small.jpg", module="data")

    img = cv2.imread(img_path).astype("float32")
    img = cv2.resize(img, (in_size, in_size))
    # OpenCV loads images as BGR; convert to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Scale to [0, 1] and reorder HWC -> CHW, then add the batch dimension.
    img = np.transpose(img / 255.0, [2, 0, 1])
    img = np.expand_dims(img, axis=0)


Import the graph to Relay
-------------------------


.. code-block:: default

    input_name = "input0"
    shape_list = [(input_name, input_shape)]
    mod, params = relay.frontend.from_pytorch(script_module, shape_list)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /workspace/python/tvm/relay/build_module.py:333: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)
      DeprecationWarning,


Compile with Relay VM
---------------------
Note: Currently only CPU targets are supported. For x86 targets, it is
highly recommended to build TVM with Intel MKL and Intel OpenMP for
best performance, because torchvision R-CNN models contain large dense
operators.


.. code-block:: default

    # Add "-libs=mkl" to get best performance on x86 targets.
    # For x86 machines that support AVX512, the complete target is
    # "llvm -mcpu=skylake-avx512 -libs=mkl"
    target = "llvm"

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
        vm_exec = relay.vm.compile(mod, target=target, params=params)


Inference with Relay VM
-----------------------


.. code-block:: default

    dev = tvm.cpu()
    vm = VirtualMachine(vm_exec, dev)
    vm.set_input("main", **{input_name: img})
    tvm_res = vm.run()


Get boxes with score larger than 0.9
------------------------------------


.. code-block:: default

    score_threshold = 0.9
    boxes = tvm_res[0].numpy().tolist()
    valid_boxes = []
    for i, score in enumerate(tvm_res[1].numpy().tolist()):
        if score > score_threshold:
            valid_boxes.append(boxes[i])
        else:
            # Stop at the first score that is not above the threshold;
            # this relies on the scores being sorted in descending order.
            break

    print("Get {} valid boxes".format(len(valid_boxes)))


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Get 9 valid boxes


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes  24.172 seconds)
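
The thresholding loop in the last step stops at the first score that is not
above the cutoff, so it only returns every qualifying box when the scores are
sorted in descending order. As a minimal sketch of an alternative that does not
depend on ordering, the same filter can be written with a NumPy boolean mask.
The helper name ``filter_boxes`` and the small arrays below are hypothetical
stand-ins for ``tvm_res[0].numpy()`` and ``tvm_res[1].numpy()``; they are not
part of the tutorial or of TVM.

.. code-block:: python

    import numpy as np


    def filter_boxes(boxes, scores, score_threshold=0.9):
        """Return the rows of ``boxes`` whose score exceeds ``score_threshold``.

        Hypothetical helper for illustration; works for scores in any order.
        """
        boxes = np.asarray(boxes, dtype="float32")
        scores = np.asarray(scores, dtype="float32")
        keep = scores > score_threshold  # boolean mask, one entry per box
        return boxes[keep]


    # Made-up detections: three boxes as (x1, y1, x2, y2), deliberately unsorted.
    boxes = np.array(
        [[10, 20, 110, 220], [30, 40, 80, 90], [5, 5, 50, 60]], dtype="float32"
    )
    scores = np.array([0.98, 0.42, 0.95], dtype="float32")

    valid_boxes = filter_boxes(boxes, scores)
    print("Get {} valid boxes".format(len(valid_boxes)))

With these stand-in values the first and third boxes are kept. An early-exit
loop would stop at the second score and miss the third box, which is the
situation the mask guards against.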