.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/deploy_models/deploy_model_on_rasp.py"

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        This tutorial can be used interactively with Google Colab! You can also click
        :ref:`here <sphx_glr_download_how_to_deploy_models_deploy_model_on_rasp.py>` to run the Jupyter notebook locally.

        .. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg
            :align: center
            :target: https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/7c392f39b90d93406ef30c6185c5686c/deploy_model_on_rasp.ipynb
            :width: 300px

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_deploy_models_deploy_model_on_rasp.py:


.. _tutorial-deploy-model-on-rasp:

Deploy the Pretrained Model on Raspberry Pi
===========================================
**Author**: Ziheng Jiang, Hiroyuki Makino

This is an example of using Relay to compile a ResNet model and deploy
it on a Raspberry Pi.

.. GENERATED FROM PYTHON SOURCE LINES 28-36

.. code-block:: default


    import tvm
    from tvm import te
    import tvm.relay as relay
    from tvm import rpc
    from tvm.contrib import utils, graph_executor as runtime
    from tvm.contrib.download import download_testdata

.. GENERATED FROM PYTHON SOURCE LINES 37-74

.. _build-tvm-runtime-on-device:

Build TVM Runtime on Device
---------------------------

The first step is to build the TVM runtime on the remote device.

.. note::

    All instructions in both this section and the next section should be
    executed on the target device, e.g. a Raspberry Pi. We assume the
    device runs Linux.

Because compilation happens on the local (host) machine, the remote device
is only used to run the generated code. Therefore, only the TVM runtime
needs to be built on the remote device.

.. code-block:: bash

    git clone --recursive https://github.com/apache/tvm tvm
    cd tvm
    mkdir build
    cp cmake/config.cmake build
    cd build
    cmake ..
    make runtime -j4

After the runtime builds successfully, set the environment variables in
the :code:`~/.bashrc` file. Open it with :code:`vi ~/.bashrc` and add the
line below (assuming your TVM directory is :code:`~/tvm`):

.. code-block:: bash

    export PYTHONPATH=$PYTHONPATH:~/tvm/python

To apply the updated environment variables, execute :code:`source ~/.bashrc`.

.. GENERATED FROM PYTHON SOURCE LINES 76-92

Set Up RPC Server on Device
---------------------------
To start an RPC server, run the following command on your remote device
(a Raspberry Pi in our example).

.. code-block:: bash

    python -m tvm.exec.rpc_server --host 0.0.0.0 --port=9090

If you see the line below, the RPC server has started successfully on
your device.

.. code-block:: none

    INFO:root:RPCServer: bind to 0.0.0.0:9090

.. GENERATED FROM PYTHON SOURCE LINES 94-99

Prepare the Pre-trained Model
-----------------------------
Back on the host machine, which should have a full TVM installation
(built with LLVM), we use a pre-trained model from torchvision.

.. GENERATED FROM PYTHON SOURCE LINES 99-115

.. code-block:: default


    import torch
    import torchvision
    from PIL import Image
    import numpy as np

    # Load the pre-trained model in one line and switch it to inference mode
    model_name = "resnet18"
    model = getattr(torchvision.models, model_name)(pretrained=True)
    model = model.eval()

    # We grab the TorchScripted model via tracing
    input_shape = [1, 3, 224, 224]
    input_data = torch.randn(input_shape)
    scripted_model = torch.jit.trace(model, input_data).eval()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
      warnings.warn(
    /venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
      warnings.warn(msg)

.. GENERATED FROM PYTHON SOURCE LINES 116-118

To test the model, we download an image of a cat and convert it to the
input format the model expects.

.. GENERATED FROM PYTHON SOURCE LINES 118-134

.. code-block:: default

    img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
    img_name = "cat.png"
    img_path = download_testdata(img_url, img_name, module="data")
    image = Image.open(img_path).resize((224, 224))


    def transform_image(image):
        # Subtract the per-channel mean and divide by the per-channel standard deviation
        image = np.array(image) - np.array([123.0, 117.0, 104.0])
        image /= np.array([58.395, 57.12, 57.375])
        # HWC -> CHW, then add a batch dimension to get NCHW
        image = image.transpose((2, 0, 1))
        image = image[np.newaxis, :]
        return image


    x = transform_image(image)

.. GENERATED FROM PYTHON SOURCE LINES 135-137

The synset maps each ImageNet class index to a human-readable label.

.. GENERATED FROM PYTHON SOURCE LINES 137-150

.. code-block:: default

    import ast

    synset_url = "".join(
        [
            "https://gist.githubusercontent.com/zhreshold/",
            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
            "imagenet1000_clsid_to_human.txt",
        ]
    )
    synset_name = "imagenet1000_clsid_to_human.txt"
    synset_path = download_testdata(synset_url, synset_name, module="data")
    with open(synset_path) as f:
        # The file holds a Python dict literal; literal_eval parses it without executing code
        synset = ast.literal_eval(f.read())

.. GENERATED FROM PYTHON SOURCE LINES 151-153

Next, we convert the PyTorch model into a portable Relay computational
graph. This takes only a few lines.

.. GENERATED FROM PYTHON SOURCE LINES 153-161

.. code-block:: default

    input_name = "input0"
    shape_list = [(input_name, x.shape)]
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

    # We want probabilities, so append a softmax operator to the graph output
    func = mod["main"]
    func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /workspace/python/tvm/relay/frontend/pytorch_utils.py:47: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
      return LooseVersion(torch_ver) > ver
    /venv/apache-tvm-py3.8/lib/python3.8/site-packages/setuptools/_distutils/version.py:346: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
      other = LooseVersion(other)

.. GENERATED FROM PYTHON SOURCE LINES 162-163

Here are some basic data workload configurations.

.. GENERATED FROM PYTHON SOURCE LINES 163-168

.. code-block:: default

    batch_size = 1
    num_classes = 1000
    image_shape = (3, 224, 224)
    data_shape = (batch_size,) + image_shape

.. GENERATED FROM PYTHON SOURCE LINES 169-178

Compile the Graph
-----------------
To compile the graph, we call the :py:func:`relay.build` function with
the graph configuration and parameters. However, an x86 program cannot
run on a device with an ARM instruction set. Therefore, in addition to
the arguments :code:`func` and :code:`params` that specify the deep
learning workload, Relay also needs to know the compilation target of
the device. The target option matters: different options can lead to
very different performance.

.. GENERATED FROM PYTHON SOURCE LINES 180-184

If we run the example on our x86 server for demonstration, we can simply
set the target to :code:`llvm`. If running it on the Raspberry Pi, we
need to specify its instruction set. Set :code:`local_demo` to False if
you want to run this tutorial with a real device.

.. GENERATED FROM PYTHON SOURCE LINES 184-206

.. code-block:: default

    local_demo = True

    if local_demo:
        target = tvm.target.Target("llvm")
    else:
        target = tvm.target.arm_cpu("rasp3b")
        # The above line is a simple form of
        # target = tvm.target.Target('llvm -device=arm_cpu -model=bcm2837 -mtriple=armv7l-linux-gnueabihf -mattr=+neon')

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(func, target, params=params)

    # `relay.build` returns a single factory module that bundles the graph JSON,
    # the compiled operator library, and the parameters. Optimizations may change
    # the parameters while keeping the model's results the same.

    # Save the library in a local temporary directory.
    tmp = utils.tempdir()
    lib_fname = tmp.relpath("net.tar")
    lib.export_library(lib_fname)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /workspace/python/tvm/relay/build_module.py:345: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)
      warnings.warn(

.. GENERATED FROM PYTHON SOURCE LINES 207-211

Deploy the Model Remotely by RPC
--------------------------------
With RPC, you can deploy the model from your host machine to the remote
device.

.. GENERATED FROM PYTHON SOURCE LINES 211-237

.. code-block:: default


    # Obtain an RPC session from the remote device.
    if local_demo:
        remote = rpc.LocalSession()
    else:
        # Replace these with the IP address and port of your target device
        host = "10.77.1.162"
        port = 9090
        remote = rpc.connect(host, port)

    # Upload the library to the remote device and load it
    remote.upload(lib_fname)
    rlib = remote.load_module("net.tar")

    # Create the remote runtime module
    dev = remote.cpu(0)
    module = runtime.GraphModule(rlib["default"](dev))
    # Set input data
    module.set_input(input_name, tvm.nd.array(x.astype("float32")))
    # Run
    module.run()
    # Get output
    out = module.get_output(0)
    # Get the top-1 result
    top1 = np.argmax(out.numpy())
    print("TVM prediction top-1: {}".format(synset[top1]))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TVM prediction top-1: tabby, tabby cat




.. _sphx_glr_download_how_to_deploy_models_deploy_model_on_rasp.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: deploy_model_on_rasp.py <deploy_model_on_rasp.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: deploy_model_on_rasp.ipynb <deploy_model_on_rasp.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
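
Checking Preprocessing and Decoding Offline
-------------------------------------------
This is a minimal NumPy sketch that assumes only NumPy is installed. It reproduces the two data-handling steps this tutorial wraps around the model: the normalization and NCHW reshaping done by :code:`transform_image`, and the softmax plus top-1 lookup applied to the model output. Because it needs neither TVM nor a Raspberry Pi, you can use it to check the data layout on its own. The synthetic image, the logits, and :code:`toy_synset` are hypothetical stand-ins for the real cat image, the model output, and the 1000-class ImageNet synset.

.. code-block:: python

    import numpy as np

    # Per-channel mean and standard deviation, as used by transform_image in this tutorial
    MEAN = np.array([123.0, 117.0, 104.0])
    STD = np.array([58.395, 57.12, 57.375])


    def transform_image(image):
        """Normalize an HxWx3 RGB array and reshape it to NCHW with a batch of 1."""
        image = (np.asarray(image, dtype="float64") - MEAN) / STD
        image = image.transpose((2, 0, 1))  # HWC -> CHW
        return image[np.newaxis, :]  # CHW -> NCHW


    def softmax(logits):
        """Numerically stable softmax over the last axis."""
        shifted = logits - logits.max(axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=-1, keepdims=True)


    # Hypothetical 224x224 RGB image in which every pixel equals the channel means
    fake_image = np.tile(MEAN, (224, 224, 1)).astype("uint8")
    x = transform_image(fake_image)
    print(x.shape, float(np.abs(x).max()))

    # Hypothetical logits and toy synset standing in for the 1000-class ImageNet output
    toy_synset = {0: "tench", 1: "goldfish", 2: "tabby, tabby cat"}
    probs = softmax(np.array([[0.5, 1.0, 3.0]]))
    top1 = int(np.argmax(probs))
    print(toy_synset[top1], round(float(probs.sum()), 6))

Every pixel of the synthetic image equals the channel mean, so the normalized tensor is all zeros. This makes the mean subtraction easy to verify. The resulting shape matches the :code:`input_shape` of ``[1, 3, 224, 224]`` used to trace the model.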