.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/deploy_models/deploy_object_detection_pytorch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_how_to_deploy_models_deploy_object_detection_pytorch.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_deploy_models_deploy_object_detection_pytorch.py:


Compile PyTorch Object Detection Models
=======================================
This article is an introductory tutorial on deploying PyTorch object
detection models with the Relay VM.

To begin, PyTorch should be installed. TorchVision is also required,
since we will use it as our model zoo.

A quick solution is to install both via pip:

.. code-block:: bash

    pip install torch==1.7.0
    pip install torchvision==0.8.1

or refer to the official site:
https://pytorch.org/get-started/locally/

PyTorch versions should be backwards compatible, but each release should
be paired with its matching TorchVision version.

Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may
be unstable.
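
To confirm that the installed versions form a supported pairing, a quick
check (a minimal sketch; the supported combinations are those listed
above) is:

.. code-block:: python

    import torch
    import torchvision

    # TVM's PyTorch frontend is validated against torch 1.7 / torchvision 0.8
    print(torch.__version__, torchvision.__version__)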

.. GENERATED FROM PYTHON SOURCE LINES 42-57

.. code-block:: default



    import tvm
    from tvm import relay
    from tvm.runtime.vm import VirtualMachine
    from tvm.contrib.download import download_testdata

    import numpy as np
    import cv2

    # PyTorch imports
    import torch
    import torchvision








.. GENERATED FROM PYTHON SOURCE LINES 63-65

Load a pre-trained Mask R-CNN from torchvision and trace it
------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 65-102

.. code-block:: default

    in_size = 300

    input_shape = (1, 3, in_size, in_size)


    def do_trace(model, inp):
        # Trace the model with a sample input, then put the traced
        # module in inference mode
        model_trace = torch.jit.trace(model, inp)
        model_trace.eval()
        return model_trace


    def dict_to_tuple(out_dict):
        # torchvision detection models return one dict per image; flatten it
        # into a tuple so that the outputs can be traced
        if "masks" in out_dict.keys():
            return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"]


    class TraceWrapper(torch.nn.Module):
        """Wrap the detection model so its forward returns a flat tuple."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, inp):
            out = self.model(inp)
            return dict_to_tuple(out[0])


    model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
    model = TraceWrapper(model_func(pretrained=True))

    model.eval()
    # Any input of the right shape works for tracing
    inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))

    with torch.no_grad():
        out = model(inp)
        script_module = do_trace(model, inp)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /workspace/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

      0%|          | 0.00/170M [00:00<?, ?B/s]
    100%|##########| 170M/170M [00:04<00:00, 39.2MB/s]
    /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3878: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      for i in range(dim)
    /usr/local/lib/python3.7/dist-packages/torchvision/models/detection/anchor_utils.py:127: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
      for g in grid_sizes
    /usr/local/lib/python3.7/dist-packages/torchvision/models/detection/anchor_utils.py:127: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      for g in grid_sizes
    /usr/local/lib/python3.7/dist-packages/torchvision/models/detection/rpn.py:73: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
      A = Ax4 // 4
    /usr/local/lib/python3.7/dist-packages/torchvision/models/detection/rpn.py:74: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
      C = AxC // A
    /usr/local/lib/python3.7/dist-packages/torchvision/ops/boxes.py:156: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
    /usr/local/lib/python3.7/dist-packages/torchvision/ops/boxes.py:158: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
    /usr/local/lib/python3.7/dist-packages/torchvision/models/detection/transform.py:293: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      for s, s_orig in zip(new_size, original_size)
    /usr/local/lib/python3.7/dist-packages/torchvision/models/detection/roi_heads.py:387: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
      return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
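
Before handing the traced module to TVM, it is worth a quick check that
tracing preserved the model's behavior. A minimal sketch, reusing ``inp``
and ``out`` from above (the tolerances are illustrative):

.. code-block:: python

    # Run the traced module and compare its boxes against the eager output
    with torch.no_grad():
        traced_out = script_module(inp)

    np.testing.assert_allclose(out[0].numpy(), traced_out[0].numpy(), rtol=1e-3, atol=1e-3)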




.. GENERATED FROM PYTHON SOURCE LINES 103-105

Download a test image and pre-process it
----------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 105-116

.. code-block:: default

    img_url = (
        "https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
    )
    img_path = download_testdata(img_url, "test_street_small.jpg", module="data")

    img = cv2.imread(img_path).astype("float32")
    img = cv2.resize(img, (in_size, in_size))
    # OpenCV loads images as BGR; convert to the RGB layout the model expects
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Scale to [0, 1], transpose HWC -> CHW, and add a batch dimension
    img = np.transpose(img / 255.0, [2, 0, 1])
    img = np.expand_dims(img, axis=0)








.. GENERATED FROM PYTHON SOURCE LINES 117-119

Import the graph to Relay
-------------------------

.. GENERATED FROM PYTHON SOURCE LINES 119-123

.. code-block:: default

    # The input name is arbitrary, but it must match the name used when
    # setting the input on the VM below
    input_name = "input0"
    shape_list = [(input_name, input_shape)]
    mod, params = relay.frontend.from_pytorch(script_module, shape_list)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /workspace/python/tvm/relay/build_module.py:348: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)
      DeprecationWarning,
    /workspace/python/tvm/relay/frontend/pytorch.py:430: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
    Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
      inputs[3], lambda ret: ret.astype(np.int).item(0)
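
At this point, ``mod`` is a ``tvm.IRModule`` holding the imported model, and
``params`` holds the converted weights. If you want to inspect the imported
Relay IR (it is quite long for rcnn models), you can print the main function:

.. code-block:: python

    # Print the Relay IR of the imported model (verbose for rcnn models)
    print(mod["main"])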




.. GENERATED FROM PYTHON SOURCE LINES 124-130

Compile with Relay VM
---------------------
Note: Currently only the CPU target is supported. For the x86 target, it is
highly recommended to build TVM with Intel MKL and Intel OpenMP to get the
best performance, due to the large dense operators in torchvision rcnn
models.

.. GENERATED FROM PYTHON SOURCE LINES 130-139

.. code-block:: default


    # Add "-libs=mkl" to get best performance on x86 target.
    # For x86 machine supports AVX512, the complete target is
    # "llvm -mcpu=skylake-avx512 -libs=mkl"
    target = "llvm"

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
        vm_exec = relay.vm.compile(mod, target=target, params=params)
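
If you want to avoid recompiling on every run, the VM executable can be
serialized and reloaded later. This is a minimal sketch following the
serialization pattern from the TVM Relay VM documentation; the file names
are arbitrary:

.. code-block:: python

    from tvm.contrib import utils

    # Serialize the executable into VM bytecode plus a compiled kernel library
    code, lib = vm_exec.save()

    tmp = utils.tempdir()
    path_lib = tmp.relpath("vm_lib.so")
    lib.export_library(path_lib)
    with open(tmp.relpath("vm_code.ro"), "wb") as fo:
        fo.write(code)

    # Reload and reconstruct the executable, e.g. in a fresh process
    loaded_lib = tvm.runtime.load_module(path_lib)
    loaded_code = bytearray(open(tmp.relpath("vm_code.ro"), "rb").read())
    vm_exec = tvm.runtime.vm.Executable.load_exec(loaded_code, loaded_lib)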








.. GENERATED FROM PYTHON SOURCE LINES 140-142

Inference with Relay VM
-----------------------

.. GENERATED FROM PYTHON SOURCE LINES 142-147

.. code-block:: default

    dev = tvm.cpu()
    vm = VirtualMachine(vm_exec, dev)
    # Feed the pre-processed image to the "main" function and run the VM
    vm.set_input("main", **{input_name: img})
    tvm_res = vm.run()








.. GENERATED FROM PYTHON SOURCE LINES 148-150

Get boxes with score larger than 0.9
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 150-160

.. code-block:: default

    score_threshold = 0.9
    # tvm_res holds the traced tuple: (boxes, scores, labels, masks)
    boxes = tvm_res[0].numpy().tolist()
    valid_boxes = []
    for i, score in enumerate(tvm_res[1].numpy().tolist()):
        if score > score_threshold:
            valid_boxes.append(boxes[i])
        else:
            # Scores are sorted in descending order, so we can stop early
            break

    print("Get {} valid boxes".format(len(valid_boxes)))




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Get 9 valid boxes
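
To sanity-check the detections visually, the surviving boxes can be drawn
back onto the resized test image. A minimal sketch using OpenCV, assuming
each box is in ``(xmin, ymin, xmax, ymax)`` pixel coordinates on the resized
``in_size x in_size`` image (the output file name is arbitrary):

.. code-block:: python

    # Draw each valid box on the resized image and save the visualization
    vis = cv2.imread(img_path)
    vis = cv2.resize(vis, (in_size, in_size))
    for box in valid_boxes:
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
    cv2.imwrite("detections.jpg", vis)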





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes  56.672 seconds)


.. _sphx_glr_download_how_to_deploy_models_deploy_object_detection_pytorch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: deploy_object_detection_pytorch.py <deploy_object_detection_pytorch.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: deploy_object_detection_pytorch.ipynb <deploy_object_detection_pytorch.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_