.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/deploy_models/deploy_object_detection_pytorch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_how_to_deploy_models_deploy_object_detection_pytorch.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_deploy_models_deploy_object_detection_pytorch.py:


Compile PyTorch Object Detection Models
=======================================
This article is an introductory tutorial on deploying PyTorch object
detection models with the Relay VM.

To begin, PyTorch must be installed. TorchVision is also required,
since we use it as our model zoo.

A quick solution is to install via pip:

.. code-block:: bash

    pip install torch==1.7.0
    pip install torchvision==0.8.1

or refer to the official site
https://pytorch.org/get-started/locally/

PyTorch versions should be backwards compatible, but each should be
used with the matching TorchVision version.

Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may be
unstable.

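Because the PyTorch frontend is sensitive to version mismatches, it can
be worth confirming what is installed before proceeding. A minimal
optional check (this snippet is not part of the generated script):

.. code-block:: python

    # Optional sanity check (not part of the generated script): print the
    # installed versions so they can be matched against a supported pair,
    # e.g. torch 1.7.0 with torchvision 0.8.1.
    import torch
    import torchvision

    print("torch:", torch.__version__)
    print("torchvision:", torchvision.__version__)
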
.. GENERATED FROM PYTHON SOURCE LINES 42-57

.. code-block:: default


    import tvm
    from tvm import relay
    from tvm.runtime.vm import VirtualMachine
    from tvm.contrib.download import download_testdata

    import numpy as np
    import cv2

    # PyTorch imports
    import torch
    import torchvision


.. GENERATED FROM PYTHON SOURCE LINES 63-65

Load pre-trained maskrcnn from torchvision and do tracing
----------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 65-102

.. code-block:: default


    in_size = 300

    input_shape = (1, 3, in_size, in_size)


    def do_trace(model, inp):
        model_trace = torch.jit.trace(model, inp)
        model_trace.eval()
        return model_trace


    def dict_to_tuple(out_dict):
        if "masks" in out_dict.keys():
            return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"]


    class TraceWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, inp):
            out = self.model(inp)
            return dict_to_tuple(out[0])


    model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
    model = TraceWrapper(model_func(pretrained=True))

    model.eval()
    inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))

    with torch.no_grad():
        out = model(inp)
        script_module = do_trace(model, inp)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /workspace/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
      0%|          | 0.00/170M [00:00<?, ?B/s]
    [... download progress elided ...]
    100%|##########| 170M/170M [00:04<00:00, 39.2MB/s]
    /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3878: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      for i in range(dim)
    /usr/local/lib/python3.7/dist-packages/torchvision/models/detection/anchor_utils.py:127: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
      for g in grid_sizes
    [... similar UserWarnings from torchvision anchor_utils.py, rpn.py, ops/boxes.py, transform.py, and roi_heads.py elided ...]

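``torch.jit.trace`` can only record functions whose outputs are tensors
or (nested) tuples of tensors, while torchvision detection models return
a list of dictionaries. ``TraceWrapper`` and ``dict_to_tuple`` flatten
the first image's result into a plain tuple so that tracing succeeds.
As an optional check, the traced module should produce outputs of the
same shapes as the eager model. This is a small sketch, not part of the
generated script:

.. code-block:: python

    # Optional sanity sketch (not in the generated script): the traced
    # module should return the same (boxes, scores, labels, masks) tuple
    # of tensors that the eager wrapper produced above.
    with torch.no_grad():
        traced_out = script_module(inp)

    for eager_t, traced_t in zip(out, traced_out):
        # Compare only shapes; random input gives arbitrary detections.
        print(eager_t.shape, traced_t.shape)
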
.. GENERATED FROM PYTHON SOURCE LINES 103-105

Download a test image and pre-process
----------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 105-116

.. code-block:: default


    img_url = (
        "https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
    )
    img_path = download_testdata(img_url, "test_street_small.jpg", module="data")

    img = cv2.imread(img_path).astype("float32")
    img = cv2.resize(img, (in_size, in_size))

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.transpose(img / 255.0, [2, 0, 1])
    img = np.expand_dims(img, axis=0)


.. GENERATED FROM PYTHON SOURCE LINES 117-119

Import the graph to Relay
----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 119-123

.. code-block:: default


    input_name = "input0"
    shape_list = [(input_name, input_shape)]
    mod, params = relay.frontend.from_pytorch(script_module, shape_list)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /workspace/python/tvm/relay/build_module.py:348: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)
      DeprecationWarning,
    /workspace/python/tvm/relay/frontend/pytorch.py:430: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
    Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
      inputs[3], lambda ret: ret.astype(np.int).item(0)

.. GENERATED FROM PYTHON SOURCE LINES 124-130

Compile with Relay VM
-------------------------
Note: Currently only the CPU target is supported. For x86 targets, it
is highly recommended to build TVM with Intel MKL and Intel OpenMP to
get the best performance, because torchvision R-CNN models contain
large dense operators.

.. GENERATED FROM PYTHON SOURCE LINES 130-139

.. code-block:: default


    # Add "-libs=mkl" to get the best performance on x86 targets.
    # For x86 machines that support AVX512, the complete target is
    # "llvm -mcpu=skylake-avx512 -libs=mkl"
    target = "llvm"

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
        vm_exec = relay.vm.compile(mod, target=target, params=params)


.. GENERATED FROM PYTHON SOURCE LINES 140-142

Inference with Relay VM
-------------------------

.. GENERATED FROM PYTHON SOURCE LINES 142-147

.. code-block:: default


    dev = tvm.cpu()
    vm = VirtualMachine(vm_exec, dev)
    # Feed the pre-processed image to the "main" function of the executable.
    vm.set_input("main", **{input_name: img})
    tvm_res = vm.run()


.. GENERATED FROM PYTHON SOURCE LINES 148-150

Get boxes with score larger than 0.9
----------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 150-160

.. code-block:: default


    score_threshold = 0.9
    boxes = tvm_res[0].numpy().tolist()
    valid_boxes = []
    for i, score in enumerate(tvm_res[1].numpy().tolist()):
        if score > score_threshold:
            valid_boxes.append(boxes[i])
        else:
            # Scores are returned in descending order, so we can stop at
            # the first score below the threshold.
            break
    print("Get {} valid boxes".format(len(valid_boxes)))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Get 9 valid boxes

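The boxes come back as ``(x1, y1, x2, y2)`` pixel coordinates of the
resized ``in_size x in_size`` input, so they can be drawn directly onto
the test image. Below is a small optional visualization sketch (not
part of the generated script), reusing ``img_path``, ``in_size``, and
``valid_boxes`` from the steps above:

.. code-block:: python

    # Optional visualization sketch (not in the generated script): draw
    # the boxes that passed the score threshold onto the resized image.
    vis = cv2.imread(img_path)
    vis = cv2.resize(vis, (in_size, in_size))

    for box in valid_boxes:
        # Each box is (x1, y1, x2, y2) in pixels of the resized image.
        x1, y1, x2, y2 = [int(v) for v in box]
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)

    cv2.imwrite("detections.jpg", vis)
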
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes 56.672 seconds)


.. _sphx_glr_download_how_to_deploy_models_deploy_object_detection_pytorch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: deploy_object_detection_pytorch.py <deploy_object_detection_pytorch.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: deploy_object_detection_pytorch.ipynb <deploy_object_detection_pytorch.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_