Auto-scheduling a convolution layer for GPU

Authors: Lianmin Zheng, Chengfan Jia

Different from the existing autotvm, which relies on manual templates to define the search space, the auto-scheduler does not require any templates. Users only need to write the computation declaration, without any schedule commands or templates. The auto-scheduler can automatically generate a large search space and find a good schedule in that space.

We use a convolution layer as an example in this tutorial.

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
from tvm.topi.testing import conv2d_nchw_python

Define the computation

To begin with, let us define the computation of a convolution layer. The function should return the list of input/output tensors. From these tensors, the auto-scheduler can get the whole computational graph.

@auto_scheduler.register_workload
def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    bias = te.placeholder((1, CO, 1, 1), name="bias")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    out = topi.nn.relu(conv + bias)
    return [data, kernel, bias, out]

Create the search task

We then create a search task for the last convolution layer of ResNet-50.

target = tvm.target.Target("cuda")

# Use the last layer in ResNet-50
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = auto_scheduler.create_task(conv2d_layer, (N, H, W, CO, CI, KH, KW, strides, padding), target)

# Inspect the computational graph
print(task.compute_dag)

Out:

data = PLACEHOLDER [1, 512, 7, 7]
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 1) && (i2 < 8)) && (i3 >= 1)) && (i3 < 8)), data[i0, i1, (i2 - 1), (i3 - 1)], 0f)
kernel = PLACEHOLDER [512, 512, 3, 3]
compute(nn, ff, yy, xx) += (pad_temp[nn, rc, (yy + ry), (xx + rx)]*kernel[ff, rc, ry, rx])
bias = PLACEHOLDER [1, 512, 1, 1]
T_add(ax0, ax1, ax2, ax3) = (compute[ax0, ax1, ax2, ax3] + bias[ax0, ax1, 0, 0])
compute(i0, i1, i2, i3) = max(T_add[i0, i1, i2, i3], 0f)

Next, we set parameters for the auto-scheduler. These parameters mainly specify how we do the measurement during the search and auto-tuning.

  • measure_ctx launches a separate process for measurement. This provides isolation: it protects the master process from GPU crashes that happen during measurement and avoids other runtime conflicts.

  • min_repeat_ms defines the minimum duration of one “repeat” in every measurement. This warms up the GPU, which is necessary to get accurate measurement results. Typically, we recommend a value larger than 300 ms.

  • num_measure_trials is the number of measurement trials we can use during the search. We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a good value for the search to converge. You can do more trials according to your time budget.

  • In addition, we use RecordToFile to dump measurement records into a file conv2d.json. The measurement records can be used to query the history best, resume the search, and do more analyses later.

  • See auto_scheduler.TuningOptions and auto_scheduler.LocalRPCMeasureContext for more parameters.

measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile("conv2d.json")],
)

Out:

Get devices for measurement successfully!
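
Now all inputs are ready and we can launch the search. The call below is a minimal sketch that uses the same auto_scheduler.auto_schedule API as the resume example later in this tutorial; after the measurement trials finish, it returns the best schedule found, which provides the sch and args used in the next section.

# Run the auto-tuning (search). auto_schedule returns the best schedule it
# found, together with the tensor arguments needed to build it.
sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)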

Check correctness and evaluate performance

We build the binary and check its correctness and performance.

func = tvm.build(sch, args, target)

# Check correctness
data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
out_np = np.maximum(conv_np + bias_np, 0.0)

ctx = tvm.gpu()
data_tvm = tvm.nd.array(data_np, ctx=ctx)
weight_tvm = tvm.nd.array(weight_np, ctx=ctx)
bias_tvm = tvm.nd.array(bias_np, ctx=ctx)
out_tvm = tvm.nd.empty(out_np.shape, ctx=ctx)
func(data_tvm, weight_tvm, bias_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.asnumpy(), rtol=1e-3)

# Evaluate execution time
evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000)
)

Out:

Execution time of this operator: 0.354 ms

Using the record file

During the search, all measurement records are dumped into the record file “conv2d.json”. The measurement records can be used to re-apply search results, resume the search, and perform other analyses.

Here is an example where we load the best schedule from a file, print the equivalent python schedule API, and build the binary again.

# Load the measurement record for the best schedule
inp, res = auto_scheduler.load_best("conv2d.json", task.workload_key)

# Print equivalent python schedule API. This can be used for debugging and
# learning the behavior of the auto-scheduler.
print("Equivalent python schedule:")
print(task.compute_dag.print_python_code_from_state(inp.state))

# Rebuild the binary. This shows how you can apply the best schedule from a
# log file without rerunning the search.
sch, args = task.compute_dag.apply_steps_from_state(inp.state)
func = tvm.build(sch, args, target)

Out:

Equivalent python schedule:
pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)
compute_nn, compute_ff, compute_yy, compute_xx, compute_rc, compute_ry, compute_rx = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
T_add_ax0, T_add_ax1, T_add_ax2, T_add_ax3 = tuple(T_add.op.axis) + tuple(T_add.op.reduce_axis)
compute_i0, compute_i1, compute_i2, compute_i3 = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
s[T_add].compute_inline()
compute_nn_o_i, compute_nn_i = s[compute].split(compute_nn, factor=1)
compute_nn_o_o_i, compute_nn_o_i = s[compute].split(compute_nn_o_i, factor=1)
compute_nn_o_o_o_i, compute_nn_o_o_i = s[compute].split(compute_nn_o_o_i, factor=1)
compute_nn_o_o_o_o, compute_nn_o_o_o_i = s[compute].split(compute_nn_o_o_o_i, factor=1)
compute_ff_o_i, compute_ff_i = s[compute].split(compute_ff, factor=1)
compute_ff_o_o_i, compute_ff_o_i = s[compute].split(compute_ff_o_i, factor=1)
compute_ff_o_o_o_i, compute_ff_o_o_i = s[compute].split(compute_ff_o_o_i, factor=16)
compute_ff_o_o_o_o, compute_ff_o_o_o_i = s[compute].split(compute_ff_o_o_o_i, factor=1)
compute_yy_o_i, compute_yy_i = s[compute].split(compute_yy, factor=1)
compute_yy_o_o_i, compute_yy_o_i = s[compute].split(compute_yy_o_i, factor=1)
compute_yy_o_o_o_i, compute_yy_o_o_i = s[compute].split(compute_yy_o_o_i, factor=7)
compute_yy_o_o_o_o, compute_yy_o_o_o_i = s[compute].split(compute_yy_o_o_o_i, factor=1)
compute_xx_o_i, compute_xx_i = s[compute].split(compute_xx, factor=1)
compute_xx_o_o_i, compute_xx_o_i = s[compute].split(compute_xx_o_i, factor=7)
compute_xx_o_o_o_i, compute_xx_o_o_i = s[compute].split(compute_xx_o_o_i, factor=1)
compute_xx_o_o_o_o, compute_xx_o_o_o_i = s[compute].split(compute_xx_o_o_o_i, factor=1)
compute_rc_o_i, compute_rc_i = s[compute].split(compute_rc, factor=8)
compute_rc_o_o, compute_rc_o_i = s[compute].split(compute_rc_o_i, factor=2)
compute_ry_o_i, compute_ry_i = s[compute].split(compute_ry, factor=3)
compute_ry_o_o, compute_ry_o_i = s[compute].split(compute_ry_o_i, factor=1)
compute_rx_o_i, compute_rx_i = s[compute].split(compute_rx, factor=1)
compute_rx_o_o, compute_rx_o_i = s[compute].split(compute_rx_o_i, factor=3)
s[compute].reorder(compute_nn_o_o_o_o, compute_ff_o_o_o_o, compute_yy_o_o_o_o, compute_xx_o_o_o_o, compute_nn_o_o_o_i, compute_ff_o_o_o_i, compute_yy_o_o_o_i, compute_xx_o_o_o_i, compute_nn_o_o_i, compute_ff_o_o_i, compute_yy_o_o_i, compute_xx_o_o_i, compute_rc_o_o, compute_ry_o_o, compute_rx_o_o, compute_rc_o_i, compute_ry_o_i, compute_rx_o_i, compute_nn_o_i, compute_ff_o_i, compute_yy_o_i, compute_xx_o_i, compute_rc_i, compute_ry_i, compute_rx_i, compute_nn_i, compute_ff_i, compute_yy_i, compute_xx_i)
compute_i0_o_i, compute_i0_i = s[compute].split(compute_i0, factor=1)
compute_i0_o_o_i, compute_i0_o_i = s[compute].split(compute_i0_o_i, factor=1)
compute_i0_o_o_o, compute_i0_o_o_i = s[compute].split(compute_i0_o_o_i, factor=1)
compute_i1_o_i, compute_i1_i = s[compute].split(compute_i1, factor=1)
compute_i1_o_o_i, compute_i1_o_i = s[compute].split(compute_i1_o_i, factor=16)
compute_i1_o_o_o, compute_i1_o_o_i = s[compute].split(compute_i1_o_o_i, factor=1)
compute_i2_o_i, compute_i2_i = s[compute].split(compute_i2, factor=1)
compute_i2_o_o_i, compute_i2_o_i = s[compute].split(compute_i2_o_i, factor=7)
compute_i2_o_o_o, compute_i2_o_o_i = s[compute].split(compute_i2_o_o_i, factor=1)
compute_i3_o_i, compute_i3_i = s[compute].split(compute_i3, factor=7)
compute_i3_o_o_i, compute_i3_o_i = s[compute].split(compute_i3_o_i, factor=1)
compute_i3_o_o_o, compute_i3_o_o_i = s[compute].split(compute_i3_o_o_i, factor=1)
s[compute].reorder(compute_i0_o_o_o, compute_i1_o_o_o, compute_i2_o_o_o, compute_i3_o_o_o, compute_i0_o_o_i, compute_i1_o_o_i, compute_i2_o_o_i, compute_i3_o_o_i, compute_i0_o_i, compute_i1_o_i, compute_i2_o_i, compute_i3_o_i, compute_i0_i, compute_i1_i, compute_i2_i, compute_i3_i)
s[compute].compute_at(s[compute], compute_i3_o_i)
kernel_shared = s.cache_read(kernel, "shared", [compute])
kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3 = tuple(kernel_shared.op.axis)
s[kernel_shared].compute_at(s[compute], compute_rx_o_o)
pad_temp_shared = s.cache_read(pad_temp, "shared", [compute])
pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3 = tuple(pad_temp_shared.op.axis)
s[pad_temp_shared].compute_at(s[compute], compute_rx_o_o)
s[pad_temp].compute_inline()
compute_i0_o_o_o_i1_o_o_o_fused_i2_o_o_o_fused_i3_o_o_o_fused = s[compute].fuse(compute_i0_o_o_o, compute_i1_o_o_o, compute_i2_o_o_o, compute_i3_o_o_o)
s[compute].bind(compute_i0_o_o_o_i1_o_o_o_fused_i2_o_o_o_fused_i3_o_o_o_fused, te.thread_axis("blockIdx.x"))
compute_i0_o_o_i_i1_o_o_i_fused_i2_o_o_i_fused_i3_o_o_i_fused = s[compute].fuse(compute_i0_o_o_i, compute_i1_o_o_i, compute_i2_o_o_i, compute_i3_o_o_i)
s[compute].bind(compute_i0_o_o_i_i1_o_o_i_fused_i2_o_o_i_fused_i3_o_o_i_fused, te.thread_axis("vthread"))
compute_i0_o_i_i1_o_i_fused_i2_o_i_fused_i3_o_i_fused = s[compute].fuse(compute_i0_o_i, compute_i1_o_i, compute_i2_o_i, compute_i3_o_i)
s[compute].bind(compute_i0_o_i_i1_o_i_fused_i2_o_i_fused_i3_o_i_fused, te.thread_axis("threadIdx.x"))
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[kernel_shared].fuse(kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=2)
s[kernel_shared].vectorize(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=112)
s[kernel_shared].bind(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[pad_temp_shared].fuse(pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=9)
s[pad_temp_shared].vectorize(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=112)
s[pad_temp_shared].bind(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
s[compute].pragma(compute_nn_o_o_o_o, "auto_unroll_max_step", 1024)
s[compute].pragma(compute_nn_o_o_o_o, "unroll_explicit", True)

A more complicated example is to resume the search. In this case, we need to create the search policy and cost model ourselves and restore their status from the log file. In the example below, we resume the status and do 5 more trials.

log_file = "conv2d.json"
cost_model = auto_scheduler.XGBModel()
cost_model.update_from_file(log_file)
search_policy = auto_scheduler.SketchPolicy(
    task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
)
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=5,
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options=tune_option)

# Kill the measurement process
del measure_ctx

Out:

Get devices for measurement successfully!
