Auto-tuning a Convolutional Network for Mobile GPU

Author: Lianmin Zheng, Eddie Yan

Auto-tuning for a specific device is critical to getting the best performance. This tutorial shows how to tune a whole convolutional network for a mobile GPU.

The operator implementation for Mobile GPU in TVM is written in template form. The template has many tunable knobs (tile factors, vectorization, unrolling, etc.). We will tune all convolution, depthwise convolution, and dense operators in the neural network. After tuning, we produce a log file that stores the best knob values for all required operators. When the TVM compiler compiles these operators, it queries this log file to get the best knob values.
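
For example, the tuned records are applied at compile time like this (a minimal sketch; mod, params, and target are defined later in this tutorial, and the full, runnable flow appears in tune_and_evaluate below):

with autotvm.apply_history_best("rk3399.resnet-18.log"):
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build_module.build(mod, target=target, params=params)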

We also released pre-tuned parameters for some ARM devices. You can go to Mobile GPU Benchmark to see the results.

Note that this tutorial will not run on Windows or recent versions of macOS. To get it to run, you will need to wrap the body of this tutorial in an if __name__ == "__main__": block.
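
A minimal sketch of that wrapping (tune_and_evaluate and tuning_option are defined later in this tutorial):

def main():
    tune_and_evaluate(tuning_option)

if __name__ == "__main__":
    main()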

Install dependencies

To use the autotvm package in tvm, we need to install some extra dependencies (change "3" to "2" if you use Python 2):

pip3 install --user psutil xgboost tornado cloudpickle

To make TVM run faster during tuning, it is recommended to use Cython as the FFI of TVM. In the root directory of TVM, execute (change "3" to "2" if you use Python 2):

pip3 install --user cython
sudo make cython3

Now return to the Python code and import the packages.

import os

import numpy as np

import tvm
from tvm import relay, autotvm
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.utils import tempdir
import tvm.contrib.graph_executor as runtime

Define network

First we need to define the network with the Relay frontend API. We can load some pre-defined networks from relay.testing. We can also load models from MXNet, ONNX, and TensorFlow.
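
For example, a hedged sketch of loading an ONNX model instead (the file name "model.onnx" and the input name "data" are assumptions for illustration):

import onnx

onnx_model = onnx.load("model.onnx")  # hypothetical model file
shape_dict = {"data": (1, 3, 224, 224)}  # input name and shape are assumptions
mod, params = relay.frontend.from_onnx(onnx_model, shape=shape_dict)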

def get_network(name, batch_size):
    """Get the symbol definition and random weight of a network"""
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)

    if "resnet" in name:
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer, batch_size=batch_size, dtype=dtype
        )
    elif "vgg" in name:
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.vgg.get_workload(
            num_layers=n_layer, batch_size=batch_size, dtype=dtype
        )
    elif name == "mobilenet":
        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == "squeezenet_v1.1":
        mod, params = relay.testing.squeezenet.get_workload(
            batch_size=batch_size, version="1.1", dtype=dtype
        )
    elif name == "inception_v3":
        input_shape = (batch_size, 3, 299, 299)
        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == "mxnet":
        # an example for mxnet model
        from mxnet.gluon.model_zoo.vision import get_model

        block = get_model("resnet18_v1", pretrained=True)
        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
        net = mod["main"]
        net = relay.Function(
            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
        )
        mod = tvm.IRModule.from_expr(net)
    else:
        raise ValueError("Unsupported network: " + name)

    return mod, params, input_shape, output_shape
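
As a quick sanity check, the helper can be called on its own. Note that it reads the module-level dtype set in the tuning options below, so define it first:

dtype = "float32"  # normally set in the tuning options section below
mod, params, input_shape, output_shape = get_network("resnet-18", batch_size=1)
print(input_shape, output_shape)  # (1, 3, 224, 224) (1, 1000)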

Start RPC Tracker

TVM uses an RPC session to communicate with ARM boards. During tuning, the tuner sends the generated code to the board and measures its speed on the board.

To scale up the tuning, TVM uses an RPC Tracker to manage distributed devices. The RPC Tracker is a centralized controller node. We can register all devices to the tracker. For example, if we have 10 phones, we can register all of them to the tracker and run 10 measurements in parallel, accelerating the tuning process.

To start an RPC tracker, run this command on the host machine. The tracker is required during the whole tuning process, so we need to open a new terminal for this command:

python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190

The expected output is

INFO:RPCTracker:bind to 0.0.0.0:9190

Register Devices to RPC Tracker

Now we can register our devices to the tracker. The first step is to build the TVM runtime for the ARM devices.

  • For Linux: Follow this section Build TVM Runtime on Device to build the TVM runtime on the device. Then register the device to the tracker by

    python -m tvm.exec.rpc_server --tracker=[HOST_IP]:9190 --key=rk3399
    

    (replace [HOST_IP] with the IP address of your host machine)

  • For Android: Follow this readme page to install the TVM RPC APK on the Android device. Make sure you can pass the Android RPC test. Then you have already registered your device. During tuning, you have to go to developer options, enable "Keep screen awake during charging", and charge your phone to keep it stable.

After registering devices, we can confirm the registration by querying the rpc_tracker:

python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190

For example, if we have 2 Huawei Mate 10 Pro phones, 11 Raspberry Pi 3B boards, and 2 RK3399 boards, the output can be

Queue Status
----------------------------------
key          total  free  pending
----------------------------------
mate10pro    2      2     0
rk3399       2      2     0
rpi3b        11     11    0
----------------------------------

You can register multiple devices to the tracker to accelerate measurement during tuning.
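
If several identical devices are registered under one key, measurements can run in parallel. A hedged sketch of the relevant runner setting (the tuning options below leave n_parallel at its default):

runner = autotvm.RPCRunner(
    "rk3399",      # device key in the tracker
    host="127.0.0.1",
    port=9190,
    n_parallel=2,  # measure on up to 2 registered boards at once
    number=10,
    timeout=5,
)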

Set Tuning Options

Before tuning, we should apply some configurations. Here I use an RK3399 board as the example. In your setting, you should modify the target and device_key accordingly. Set use_android to True if you use an Android phone.

#### DEVICE CONFIG ####
# Replace "aarch64-linux-gnu" with the correct target of your board.
# This target host is used for cross compilation. You can query it by running gcc -v on your device.
target = tvm.target.Target("opencl -device=mali", host="llvm -mtriple=aarch64-linux-gnu")

# Also replace this with the device key in your tracker
device_key = "rk3399"

# Set this to True if you use android phone
use_android = False

#### TUNING OPTION ####
network = "resnet-18"
log_file = "%s.%s.log" % (device_key, network)
dtype = "float32"

tuning_option = {
    "log_filename": log_file,
    "tuner": "xgb",
    "n_trial": 1000,
    "early_stopping": 450,
    "measure_option": autotvm.measure_option(
        builder=autotvm.LocalBuilder(build_func="ndk" if use_android else "default"),
        runner=autotvm.RPCRunner(
            device_key,
            host="127.0.0.1",
            port=9190,
            number=10,
            timeout=5,
        ),
    ),
}

Note

How to set tuning options

In general, the default values provided here work well. If you have a large enough time budget, you can set n_trial and early_stopping to larger values, which makes the tuning run longer. If your device runs very slowly or your conv2d operators have many GFLOPs, consider setting timeout larger.
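
For instance, a looser budget for a slow device might look like this (a sketch with illustrative values, not recommendations):

tuning_option["n_trial"] = 2000
tuning_option["early_stopping"] = 800
tuning_option["measure_option"] = autotvm.measure_option(
    builder=autotvm.LocalBuilder(build_func="ndk" if use_android else "default"),
    runner=autotvm.RPCRunner(device_key, host="127.0.0.1", port=9190, number=10, timeout=20),
)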

Begin Tuning

Now we can extract tuning tasks from the network and begin tuning. Here, we provide a simple utility function to tune a list of tasks. This function is just an initial implementation which tunes them in sequential order. We will introduce a more sophisticated tuning scheduler in the future.

# You can skip the implementation of this function for this tutorial.
def tune_tasks(
    tasks,
    measure_option,
    tuner="xgb",
    n_trial=1000,
    early_stopping=None,
    log_filename="tuning.log",
    use_transfer_learning=True,
):
    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # create tuner
        if tuner == "xgb":
            tuner_obj = XGBTuner(tsk, loss_type="reg")
        elif tuner == "xgb_knob":
            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="knob")
        elif tuner == "xgb_itervar":
            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="itervar")
        elif tuner == "xgb_curve":
            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="curve")
        elif tuner == "xgb_rank":
            tuner_obj = XGBTuner(tsk, loss_type="rank")
        elif tuner == "xgb_rank_knob":
            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob")
        elif tuner == "xgb_rank_itervar":
            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="itervar")
        elif tuner == "xgb_rank_curve":
            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="curve")
        elif tuner == "xgb_rank_binary":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary")
        elif tuner == "xgb_rank_binary_knob":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="knob")
        elif tuner == "xgb_rank_binary_itervar":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="itervar")
        elif tuner == "xgb_rank_binary_curve":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="curve")
        elif tuner == "ga":
            tuner_obj = GATuner(tsk, pop_size=50)
        elif tuner == "random":
            tuner_obj = RandomTuner(tsk)
        elif tuner == "gridsearch":
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        if use_transfer_learning:
            if os.path.isfile(tmp_log_file):
                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        tsk_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(
            n_trial=tsk_trial,
            early_stopping=early_stopping,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                autotvm.callback.log_to_file(tmp_log_file),
            ],
        )

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)

Finally, we launch tuning jobs and evaluate the end-to-end performance.

def tune_and_evaluate(tuning_opt):
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, input_shape, _ = get_network(network, batch_size=1)
    tasks = autotvm.task.extract_from_program(
        mod["main"],
        target=target,
        params=params,
        ops=(relay.op.get("nn.conv2d"),),
    )

    # run tuning tasks
    print("Tuning...")
    tune_tasks(tasks, **tuning_opt)

    # compile kernels with history best records
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build_module.build(mod, target=target, params=params)
        # export library
        tmp = tempdir()
        if use_android:
            from tvm.contrib import ndk

            filename = "net.so"
            lib.export_library(tmp.relpath(filename), ndk.create_shared)
        else:
            filename = "net.tar"
            lib.export_library(tmp.relpath(filename))

        # upload module to device
        print("Upload...")
        remote = autotvm.measure.request_remote(device_key, "127.0.0.1", 9190, timeout=10000)
        remote.upload(tmp.relpath(filename))
        rlib = remote.load_module(filename)

        # upload parameters to device
        dev = remote.device(str(target), 0)
        module = runtime.GraphModule(rlib["default"](dev))
        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
        module.set_input("data", data_tvm)

        # evaluate
        print("Evaluate inference time cost...")
        print(module.benchmark(dev, number=1, repeat=30))


# We do not run the tuning on our web server since it takes too long.
# Uncomment the following line to run it by yourself.

# tune_and_evaluate(tuning_option)

Sample Output

The tuning needs to compile many programs and extract features from them, so a high-performance CPU is recommended. One sample output is listed below. It took about 3 hours on a 32-thread AMD Ryzen Threadripper.

Extract tasks...
Tuning...
[Task  1/17]  Current/Best:   25.30/  39.12 GFLOPS | Progress: (992/1000) | 751.22 s Done.
[Task  2/17]  Current/Best:   40.70/  45.50 GFLOPS | Progress: (736/1000) | 545.46 s Done.
[Task  3/17]  Current/Best:   38.83/  42.35 GFLOPS | Progress: (992/1000) | 1549.85 s Done.
[Task  4/17]  Current/Best:   23.31/  31.02 GFLOPS | Progress: (640/1000) | 1059.31 s Done.
[Task  5/17]  Current/Best:    0.06/   2.34 GFLOPS | Progress: (544/1000) | 305.45 s Done.
[Task  6/17]  Current/Best:   10.97/  17.20 GFLOPS | Progress: (992/1000) | 1050.00 s Done.
[Task  7/17]  Current/Best:    8.98/  10.94 GFLOPS | Progress: (928/1000) | 421.36 s Done.
[Task  8/17]  Current/Best:    4.48/  14.86 GFLOPS | Progress: (704/1000) | 582.60 s Done.
[Task  9/17]  Current/Best:   10.30/  25.99 GFLOPS | Progress: (864/1000) | 899.85 s Done.
[Task 10/17]  Current/Best:   11.73/  12.52 GFLOPS | Progress: (608/1000) | 304.85 s Done.
[Task 11/17]  Current/Best:   15.26/  18.68 GFLOPS | Progress: (800/1000) | 747.52 s Done.
[Task 12/17]  Current/Best:   17.48/  26.71 GFLOPS | Progress: (1000/1000) | 1166.40 s Done.
[Task 13/17]  Current/Best:    0.96/  11.43 GFLOPS | Progress: (960/1000) | 611.65 s Done.
[Task 14/17]  Current/Best:   17.88/  20.22 GFLOPS | Progress: (672/1000) | 670.29 s Done.
[Task 15/17]  Current/Best:   11.62/  13.98 GFLOPS | Progress: (736/1000) | 449.25 s Done.
[Task 16/17]  Current/Best:   19.90/  23.83 GFLOPS | Progress: (608/1000) | 708.64 s Done.
[Task 17/17]  Current/Best:   17.98/  22.75 GFLOPS | Progress: (736/1000) | 1122.60 s Done.
Compile...
Upload...
Evaluate inference time cost...
Mean inference time (std dev): 128.05 ms (7.74 ms)

Note

Experiencing Difficulties?

The auto-tuning module is error-prone. If you always see "0.00/ 0.00 GFLOPS", then there must be something wrong.

First, make sure you set the correct configuration of your device. Then, you can print debug information by adding these lines at the beginning of the script. This will print every measurement result, where you can find useful error messages.

import logging
logging.getLogger('autotvm').setLevel(logging.DEBUG)

Finally, always feel free to ask our community for help on https://discuss.tvm.apache.org
