Auto-scheduling a Convolution Layer for GPU

Author: Lianmin Zheng, Chengfan Jia

This is a tutorial on how to use the auto-scheduler for GPUs.

Unlike the template-based autotvm, which relies on manual templates to define the search space, the auto-scheduler does not require any templates. Users only need to write the computation declaration, without any schedule commands or templates. The auto-scheduler then automatically generates a large search space and finds a good schedule in it.

We use a convolution layer as an example in this tutorial.

Note that this tutorial will not run as written on Windows or recent versions of macOS. To get it to run, you will need to wrap the body of this tutorial in an if __name__ == "__main__": block.
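A minimal sketch of that wrapping is shown here. It assumes the tutorial body is moved into a function; the name main is our own choice and not part of TVM. The usual reason for such a guard is that, on platforms whose default multiprocessing start method is "spawn", worker processes re-import the main module, so unguarded top-level code would run again in every worker.

```python
import multiprocessing


def main():
    # Place the tutorial body here: define the workload, create the
    # SearchTask, tune, and evaluate. This stub only reports the start
    # method that has been selected so far (None if none has been set).
    return multiprocessing.get_start_method(allow_none=True)


if __name__ == "__main__":
    # Only the process that runs the script directly enters this branch;
    # worker processes that re-import the module skip it.
    print("start method:", main())
```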

import os

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
from tvm.topi.testing import conv2d_nchw_python

Define the computation

To begin with, let us define the computation of a convolution layer. The function should return the list of input/output tensors. From these tensors, the auto-scheduler can get the whole computational graph.

@auto_scheduler.register_workload
def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    bias = te.placeholder((1, CO, 1, 1), name="bias")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    out = topi.nn.relu(conv + bias)
    return [data, kernel, bias, out]

Create the search task

We then create a search task for the last convolution layer in ResNet-50.

target = tvm.target.Target("cuda")

# Use the last layer in ResNet-50
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = auto_scheduler.SearchTask(
    func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
Computational DAG:
data = PLACEHOLDER [1, 512, 7, 7]
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 1) && (i2 < 8)) && (i3 >= 1)) && (i3 < 8)), data[i0, i1, (i2 - 1), (i3 - 1)], 0f)
kernel = PLACEHOLDER [512, 512, 3, 3]
conv2d_nchw(nn, ff, yy, xx) += (pad_temp[nn, rc, (yy + ry), (xx + rx)]*kernel[ff, rc, ry, rx])
bias = PLACEHOLDER [1, 512, 1, 1]
T_add(ax0, ax1, ax2, ax3) = (conv2d_nchw[ax0, ax1, ax2, ax3] + bias[ax0, ax1, 0, 0])
compute(i0, i1, i2, i3) = max(T_add[i0, i1, i2, i3], 0f)
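
The DAG printed above has four computed stages: zero padding (pad_temp), the convolution reduction (conv2d_nchw), the bias addition (T_add), and the ReLU (compute). The following pure-Python sketch mirrors those stages on nested lists so the index arithmetic can be checked by hand. The helper names conv2d_out_size and conv2d_bias_relu are ours, and this is only an illustrative reference, not how TVM evaluates the graph.

```python
def conv2d_out_size(size, kernel, stride, pad):
    # Standard output-size formula for a convolution without dilation.
    return (size + 2 * pad - kernel) // stride + 1


def conv2d_bias_relu(data, kernel, bias, stride=1, pad=1):
    """Reference for the four DAG stages on nested lists in NCHW layout."""
    N, CI, H, W = len(data), len(data[0]), len(data[0][0]), len(data[0][0][0])
    CO, KH, KW = len(kernel), len(kernel[0][0]), len(kernel[0][0][0])
    OH = conv2d_out_size(H, KH, stride, pad)
    OW = conv2d_out_size(W, KW, stride, pad)

    def pad_temp(n, c, y, x):
        # Stage 1: zero padding, like the if_then_else in the DAG.
        inside = pad <= y < H + pad and pad <= x < W + pad
        return data[n][c][y - pad][x - pad] if inside else 0.0

    out = [[[[0.0] * OW for _ in range(OH)] for _ in range(CO)] for _ in range(N)]
    for n in range(N):
        for f in range(CO):
            for y in range(OH):
                for x in range(OW):
                    acc = 0.0
                    # Stage 2: reduce over input channels and the kernel window.
                    for rc in range(CI):
                        for ry in range(KH):
                            for rx in range(KW):
                                value = pad_temp(n, rc, y * stride + ry, x * stride + rx)
                                acc += value * kernel[f][rc][ry][rx]
                    # Stages 3 and 4: add the per-channel bias, then apply ReLU.
                    out[n][f][y][x] = max(acc + bias[0][f][0][0], 0.0)
    return out


# The tutorial's layer: 7x7 input, 3x3 kernel, stride 1, padding 1.
print(conv2d_out_size(7, 3, 1, 1))
```

With a 3x3 kernel, stride 1, and padding 1, the spatial size is preserved, which is why the output of this layer is again 7x7.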

Next, we set parameters for the auto-scheduler. These parameters mainly specify how we do the measurement during the search.

  • measure_ctx launches a separate process for measurement to provide isolation. It protects the main process from GPU crashes during measurement and avoids other runtime conflicts.

  • min_repeat_ms defines the minimum duration of one “repeat” in every measurement. This warms up the GPU, which is necessary to get accurate measurement results. Typically, we recommend a value >= 300 ms.

  • num_measure_trials is the number of measurement trials we can use during the search. We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a good value for the search to converge. You can do more trials according to your time budget.

  • In addition, we use RecordToFile to dump measurement records into a file conv2d.json. The measurement records can be used to query the best result found so far, resume the search, and do more analyses later.

  • See auto_scheduler.TuningOptions and auto_scheduler.LocalRPCMeasureContext for more parameters.

log_file = "conv2d.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)
Get devices for measurement successfully!
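
To make the idea behind min_repeat_ms concrete, here is a small sketch of a timer that keeps increasing the number of calls per repeat until one repeat lasts at least the requested duration. The function time_with_min_repeat is hypothetical and written only for illustration; TVM's actual measurement logic lives in its runtime and may differ in detail.

```python
import time


def time_with_min_repeat(fn, min_repeat_ms=300, repeat=3):
    """Return the mean seconds per call for each of `repeat` repeats."""
    results = []
    for _ in range(repeat):
        number = 1
        while True:
            start = time.perf_counter()
            for _call in range(number):
                fn()
            elapsed = time.perf_counter() - start
            if elapsed * 1000.0 >= min_repeat_ms:
                break
            # Too short to be reliable: run more calls in the next attempt.
            number *= 2
        results.append(elapsed / number)
    return results


# Usage with a cheap stand-in workload and a short 5 ms window.
costs = time_with_min_repeat(lambda: sum(range(1000)), min_repeat_ms=5, repeat=3)
print(["%.3f us" % (c * 1e6) for c in costs])
```

A longer window averages over more calls, which smooths out launch overhead and clock ramp-up at the price of a slower search.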

Check correctness and evaluate performance

We build the binary and check its correctness and performance.

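# Run the search, then load the best schedule found from the log file.
# sch and args must be defined before they can be passed to tvm.build.
task.tune(tune_option)
sch, args = task.apply_best(log_file)

# Kill the measurement process
del measure_ctx

func = tvm.build(sch, args, target)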

# Check correctness
data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
out_np = np.maximum(conv_np + bias_np, 0.0)

dev = tvm.cuda()
data_tvm = tvm.nd.array(data_np, device=dev)
weight_tvm = tvm.nd.array(weight_np, device=dev)
bias_tvm = tvm.nd.array(bias_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(data_tvm, weight_tvm, bias_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate execution time
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000)
)
Execution time of this operator: 0.388 ms
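
An execution time is easier to compare across layers and devices when converted to throughput. The sketch below counts the floating-point operations of the convolution alone (bias and ReLU are ignored, and each multiply-add is counted as 2 FLOPs) and divides by the time reported above. The helper conv2d_flops is our own, and the 0.388 ms figure is specific to the run shown here.

```python
def conv2d_flops(N, CO, OH, OW, CI, KH, KW):
    # Each output element needs CI*KH*KW multiply-adds, counted as 2 FLOPs each.
    return 2 * N * CO * OH * OW * CI * KH * KW


flops = conv2d_flops(N=1, CO=512, OH=7, OW=7, CI=512, KH=3, KW=3)
time_ms = 0.388  # execution time reported above; varies by GPU and run
gflops = flops / (time_ms * 1e-3) / 1e9
print("%d FLOPs, %.1f GFLOP/s" % (flops, gflops))
```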

Using the record file

During the search, all measurement records are dumped into the record file “conv2d.json”. The measurement records can be used to re-apply search results, resume the search, and perform other analyses.

Here is an example where we load the best schedule from the log file and print the equivalent Python schedule API calls and the CUDA source code. Both are useful for debugging and for learning how the auto-scheduler behaves.

print("Equivalent python schedule:")
print(task.print_best(log_file, print_mode="schedule"))

print("CUDA source code:")
print(task.print_best(log_file, print_mode="cuda"))
Equivalent python schedule:
pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)
conv2d_nchw_nn, conv2d_nchw_ff, conv2d_nchw_yy, conv2d_nchw_xx, conv2d_nchw_rc, conv2d_nchw_ry, conv2d_nchw_rx = tuple(conv2d_nchw.op.axis) + tuple(conv2d_nchw.op.reduce_axis)
T_add_ax0, T_add_ax1, T_add_ax2, T_add_ax3 = tuple(T_add.op.axis) + tuple(T_add.op.reduce_axis)
compute_i0, compute_i1, compute_i2, compute_i3 = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
s[T_add].compute_inline()
conv2d_nchw_nn_o_i, conv2d_nchw_nn_i = s[conv2d_nchw].split(conv2d_nchw_nn, factor=1)
conv2d_nchw_nn_o_o_i, conv2d_nchw_nn_o_i = s[conv2d_nchw].split(conv2d_nchw_nn_o_i, factor=1)
conv2d_nchw_nn_o_o_o_i, conv2d_nchw_nn_o_o_i = s[conv2d_nchw].split(conv2d_nchw_nn_o_o_i, factor=1)
conv2d_nchw_nn_o_o_o_o, conv2d_nchw_nn_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_nn_o_o_o_i, factor=1)
conv2d_nchw_ff_o_i, conv2d_nchw_ff_i = s[conv2d_nchw].split(conv2d_nchw_ff, factor=1)
conv2d_nchw_ff_o_o_i, conv2d_nchw_ff_o_i = s[conv2d_nchw].split(conv2d_nchw_ff_o_i, factor=1)
conv2d_nchw_ff_o_o_o_i, conv2d_nchw_ff_o_o_i = s[conv2d_nchw].split(conv2d_nchw_ff_o_o_i, factor=8)
conv2d_nchw_ff_o_o_o_o, conv2d_nchw_ff_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_ff_o_o_o_i, factor=1)
conv2d_nchw_yy_o_i, conv2d_nchw_yy_i = s[conv2d_nchw].split(conv2d_nchw_yy, factor=7)
conv2d_nchw_yy_o_o_i, conv2d_nchw_yy_o_i = s[conv2d_nchw].split(conv2d_nchw_yy_o_i, factor=1)
conv2d_nchw_yy_o_o_o_i, conv2d_nchw_yy_o_o_i = s[conv2d_nchw].split(conv2d_nchw_yy_o_o_i, factor=1)
conv2d_nchw_yy_o_o_o_o, conv2d_nchw_yy_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_yy_o_o_o_i, factor=1)
conv2d_nchw_xx_o_i, conv2d_nchw_xx_i = s[conv2d_nchw].split(conv2d_nchw_xx, factor=1)
conv2d_nchw_xx_o_o_i, conv2d_nchw_xx_o_i = s[conv2d_nchw].split(conv2d_nchw_xx_o_i, factor=1)
conv2d_nchw_xx_o_o_o_i, conv2d_nchw_xx_o_o_i = s[conv2d_nchw].split(conv2d_nchw_xx_o_o_i, factor=7)
conv2d_nchw_xx_o_o_o_o, conv2d_nchw_xx_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_xx_o_o_o_i, factor=1)
conv2d_nchw_rc_o_i, conv2d_nchw_rc_i = s[conv2d_nchw].split(conv2d_nchw_rc, factor=4)
conv2d_nchw_rc_o_o, conv2d_nchw_rc_o_i = s[conv2d_nchw].split(conv2d_nchw_rc_o_i, factor=2)
conv2d_nchw_ry_o_i, conv2d_nchw_ry_i = s[conv2d_nchw].split(conv2d_nchw_ry, factor=1)
conv2d_nchw_ry_o_o, conv2d_nchw_ry_o_i = s[conv2d_nchw].split(conv2d_nchw_ry_o_i, factor=3)
conv2d_nchw_rx_o_i, conv2d_nchw_rx_i = s[conv2d_nchw].split(conv2d_nchw_rx, factor=3)
conv2d_nchw_rx_o_o, conv2d_nchw_rx_o_i = s[conv2d_nchw].split(conv2d_nchw_rx_o_i, factor=1)
s[conv2d_nchw].reorder(conv2d_nchw_nn_o_o_o_o, conv2d_nchw_ff_o_o_o_o, conv2d_nchw_yy_o_o_o_o, conv2d_nchw_xx_o_o_o_o, conv2d_nchw_nn_o_o_o_i, conv2d_nchw_ff_o_o_o_i, conv2d_nchw_yy_o_o_o_i, conv2d_nchw_xx_o_o_o_i, conv2d_nchw_nn_o_o_i, conv2d_nchw_ff_o_o_i, conv2d_nchw_yy_o_o_i, conv2d_nchw_xx_o_o_i, conv2d_nchw_rc_o_o, conv2d_nchw_ry_o_o, conv2d_nchw_rx_o_o, conv2d_nchw_rc_o_i, conv2d_nchw_ry_o_i, conv2d_nchw_rx_o_i, conv2d_nchw_nn_o_i, conv2d_nchw_ff_o_i, conv2d_nchw_yy_o_i, conv2d_nchw_xx_o_i, conv2d_nchw_rc_i, conv2d_nchw_ry_i, conv2d_nchw_rx_i, conv2d_nchw_nn_i, conv2d_nchw_ff_i, conv2d_nchw_yy_i, conv2d_nchw_xx_i)
compute_i0_o_i, compute_i0_i = s[compute].split(compute_i0, factor=1)
compute_i0_o_o_i, compute_i0_o_i = s[compute].split(compute_i0_o_i, factor=1)
compute_i0_o_o_o, compute_i0_o_o_i = s[compute].split(compute_i0_o_o_i, factor=1)
compute_i1_o_i, compute_i1_i = s[compute].split(compute_i1, factor=1)
compute_i1_o_o_i, compute_i1_o_i = s[compute].split(compute_i1_o_i, factor=8)
compute_i1_o_o_o, compute_i1_o_o_i = s[compute].split(compute_i1_o_o_i, factor=1)
compute_i2_o_i, compute_i2_i = s[compute].split(compute_i2, factor=7)
compute_i2_o_o_i, compute_i2_o_i = s[compute].split(compute_i2_o_i, factor=1)
compute_i2_o_o_o, compute_i2_o_o_i = s[compute].split(compute_i2_o_o_i, factor=1)
compute_i3_o_i, compute_i3_i = s[compute].split(compute_i3, factor=1)
compute_i3_o_o_i, compute_i3_o_i = s[compute].split(compute_i3_o_i, factor=7)
compute_i3_o_o_o, compute_i3_o_o_i = s[compute].split(compute_i3_o_o_i, factor=1)
s[compute].reorder(compute_i0_o_o_o, compute_i1_o_o_o, compute_i2_o_o_o, compute_i3_o_o_o, compute_i0_o_o_i, compute_i1_o_o_i, compute_i2_o_o_i, compute_i3_o_o_i, compute_i0_o_i, compute_i1_o_i, compute_i2_o_i, compute_i3_o_i, compute_i0_i, compute_i1_i, compute_i2_i, compute_i3_i)
s[conv2d_nchw].compute_at(s[compute], compute_i3_o_i)
kernel_shared = s.cache_read(kernel, "shared", [conv2d_nchw])
kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3 = tuple(kernel_shared.op.axis)
s[kernel_shared].compute_at(s[conv2d_nchw], conv2d_nchw_rx_o_o)
pad_temp_shared = s.cache_read(pad_temp, "shared", [conv2d_nchw])
pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3 = tuple(pad_temp_shared.op.axis)
s[pad_temp_shared].compute_at(s[conv2d_nchw], conv2d_nchw_rx_o_o)
s[pad_temp].compute_inline()
compute_i0_o_o_o_i1_o_o_o_fused_i2_o_o_o_fused_i3_o_o_o_fused = s[compute].fuse(compute_i0_o_o_o, compute_i1_o_o_o, compute_i2_o_o_o, compute_i3_o_o_o)
s[compute].bind(compute_i0_o_o_o_i1_o_o_o_fused_i2_o_o_o_fused_i3_o_o_o_fused, te.thread_axis("blockIdx.x"))
compute_i0_o_o_i_i1_o_o_i_fused_i2_o_o_i_fused_i3_o_o_i_fused = s[compute].fuse(compute_i0_o_o_i, compute_i1_o_o_i, compute_i2_o_o_i, compute_i3_o_o_i)
s[compute].bind(compute_i0_o_o_i_i1_o_o_i_fused_i2_o_o_i_fused_i3_o_o_i_fused, te.thread_axis("vthread"))
compute_i0_o_i_i1_o_i_fused_i2_o_i_fused_i3_o_i_fused = s[compute].fuse(compute_i0_o_i, compute_i1_o_i, compute_i2_o_i, compute_i3_o_i)
s[compute].bind(compute_i0_o_i_i1_o_i_fused_i2_o_i_fused_i3_o_i_fused, te.thread_axis("threadIdx.x"))
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[kernel_shared].fuse(kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=1)
s[kernel_shared].vectorize(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=56)
s[kernel_shared].bind(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[pad_temp_shared].fuse(pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=1)
s[pad_temp_shared].vectorize(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=56)
s[pad_temp_shared].bind(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
s[conv2d_nchw].pragma(conv2d_nchw_nn_o_o_o_o, "auto_unroll_max_step", 16)
s[conv2d_nchw].pragma(conv2d_nchw_nn_o_o_o_o, "unroll_explicit", True)

CUDA source code:

#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(56) default_function_kernel0(float* __restrict__ data, float* __restrict__ kernel, float* __restrict__ compute, float* __restrict__ bias) {
  float conv2d_nchw[7];
  __shared__ float pad_temp_shared[648];
  __shared__ float kernel_shared[576];
  conv2d_nchw[0] = 0.000000e+00f;
  conv2d_nchw[1] = 0.000000e+00f;
  conv2d_nchw[2] = 0.000000e+00f;
  conv2d_nchw[3] = 0.000000e+00f;
  conv2d_nchw[4] = 0.000000e+00f;
  conv2d_nchw[5] = 0.000000e+00f;
  conv2d_nchw[6] = 0.000000e+00f;
  for (int rc_outer_outer = 0; rc_outer_outer < 64; ++rc_outer_outer) {
    __syncthreads();
    pad_temp_shared[((int)threadIdx.x)] = ((((9 <= ((int)threadIdx.x)) && (1 <= (((int)threadIdx.x) % 9))) && ((((int)threadIdx.x) % 9) < 8)) ? data[((((rc_outer_outer * 392) + ((((int)threadIdx.x) / 9) * 7)) + (((int)threadIdx.x) % 9)) - 8)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 56)] = (((((9 <= ((((int)threadIdx.x) + 56) % 81)) && (((((int)threadIdx.x) + 56) % 81) < 72)) && (1 <= ((((int)threadIdx.x) + 2) % 9))) && (((((int)threadIdx.x) + 2) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 56) / 81) * 49)) + ((((((int)threadIdx.x) + 56) % 81) / 9) * 7)) + ((((int)threadIdx.x) + 2) % 9)) - 8)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 112)] = (((((9 <= ((((int)threadIdx.x) + 31) % 81)) && (((((int)threadIdx.x) + 31) % 81) < 72)) && (1 <= ((((int)threadIdx.x) + 4) % 9))) && (((((int)threadIdx.x) + 4) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 112) / 81) * 49)) + ((((((int)threadIdx.x) + 31) % 81) / 9) * 7)) + ((((int)threadIdx.x) + 4) % 9)) - 8)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 168)] = ((((3 <= ((int)threadIdx.x)) && (1 <= ((((int)threadIdx.x) + 6) % 9))) && (((((int)threadIdx.x) + 6) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 168) / 81) * 49)) + (((((int)threadIdx.x) + 6) / 9) * 7)) + ((((int)threadIdx.x) + 6) % 9)) - 8)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 224)] = (((((9 <= ((((int)threadIdx.x) + 62) % 81)) && (((((int)threadIdx.x) + 62) % 81) < 72)) && (1 <= ((((int)threadIdx.x) + 8) % 9))) && (((((int)threadIdx.x) + 8) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 224) / 81) * 49)) + ((((((int)threadIdx.x) + 62) % 81) / 9) * 7)) + ((((int)threadIdx.x) + 8) % 9)) - 8)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 280)] = (((((9 <= ((((int)threadIdx.x) + 37) % 81)) && (((((int)threadIdx.x) + 37) % 81) < 72)) && (1 <= ((((int)threadIdx.x) + 1) % 9))) && (((((int)threadIdx.x) + 1) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 280) / 81) * 49)) + ((((((int)threadIdx.x) + 37) % 81) / 9) * 7)) + ((((int)threadIdx.x) + 1) % 9)) - 8)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 336)] = (((1 <= ((((int)threadIdx.x) + 3) % 9)) && (((((int)threadIdx.x) + 3) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 336) / 81) * 49)) + (((((int)threadIdx.x) + 12) / 9) * 7)) + ((((int)threadIdx.x) + 3) % 9)) - 8)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 392)] = (((((9 <= ((((int)threadIdx.x) + 68) % 81)) && (((((int)threadIdx.x) + 68) % 81) < 72)) && (1 <= ((((int)threadIdx.x) + 5) % 9))) && (((((int)threadIdx.x) + 5) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 392) / 81) * 49)) + ((((((int)threadIdx.x) + 68) % 81) / 9) * 7)) + ((((int)threadIdx.x) + 5) % 9)) - 8)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 448)] = (((((9 <= ((((int)threadIdx.x) + 43) % 81)) && (((((int)threadIdx.x) + 43) % 81) < 72)) && (1 <= ((((int)threadIdx.x) + 7) % 9))) && (((((int)threadIdx.x) + 7) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 448) / 81) * 49)) + ((((((int)threadIdx.x) + 43) % 81) / 9) * 7)) + ((((int)threadIdx.x) + 7) % 9)) - 8)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 504)] = ((((((int)threadIdx.x) < 54) && (1 <= (((int)threadIdx.x) % 9))) && ((((int)threadIdx.x) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 504) / 81) * 49)) + ((((int)threadIdx.x) / 9) * 7)) + (((int)threadIdx.x) % 9)) + 6)] : 0.000000e+00f);
    pad_temp_shared[(((int)threadIdx.x) + 560)] = (((((9 <= ((((int)threadIdx.x) + 74) % 81)) && (((((int)threadIdx.x) + 74) % 81) < 72)) && (1 <= ((((int)threadIdx.x) + 2) % 9))) && (((((int)threadIdx.x) + 2) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 560) / 81) * 49)) + ((((((int)threadIdx.x) + 74) % 81) / 9) * 7)) + ((((int)threadIdx.x) + 2) % 9)) - 8)] : 0.000000e+00f);
    if (((int)threadIdx.x) < 32) {
      pad_temp_shared[(((int)threadIdx.x) + 616)] = ((((((int)threadIdx.x) < 23) && (1 <= ((((int)threadIdx.x) + 4) % 9))) && (((((int)threadIdx.x) + 4) % 9) < 8)) ? data[(((((rc_outer_outer * 392) + (((((int)threadIdx.x) + 616) / 81) * 49)) + (((((int)threadIdx.x) + 49) / 9) * 7)) + ((((int)threadIdx.x) + 4) % 9)) - 8)] : 0.000000e+00f);
    }
    kernel_shared[((int)threadIdx.x)] = kernel[(((((int)blockIdx.x) * 36864) + (rc_outer_outer * 72)) + ((int)threadIdx.x))];
    kernel_shared[(((int)threadIdx.x) + 56)] = kernel[(((((((int)blockIdx.x) * 36864) + (((((int)threadIdx.x) + 56) / 72) * 4608)) + (rc_outer_outer * 72)) + ((((((int)threadIdx.x) + 56) % 72) / 3) * 3)) + ((((int)threadIdx.x) + 2) % 3))];
    kernel_shared[(((int)threadIdx.x) + 112)] = kernel[(((((((int)blockIdx.x) * 36864) + (((((int)threadIdx.x) + 112) / 72) * 4608)) + (rc_outer_outer * 72)) + ((((((int)threadIdx.x) + 40) % 72) / 3) * 3)) + ((((int)threadIdx.x) + 1) % 3))];
    kernel_shared[(((int)threadIdx.x) + 168)] = kernel[(((((((int)blockIdx.x) * 36864) + (((((int)threadIdx.x) + 168) / 72) * 4608)) + (rc_outer_outer * 72)) + ((((((int)threadIdx.x) / 3) + 8) % 24) * 3)) + (((int)threadIdx.x) % 3))];
    kernel_shared[(((int)threadIdx.x) + 224)] = kernel[(((((((int)blockIdx.x) * 36864) + (((((int)threadIdx.x) + 224) / 72) * 4608)) + (rc_outer_outer * 72)) + (((((int)threadIdx.x) + 8) / 3) * 3)) + ((((int)threadIdx.x) + 2) % 3))];
    kernel_shared[(((int)threadIdx.x) + 280)] = kernel[(((((((int)blockIdx.x) * 36864) + (((((int)threadIdx.x) + 280) / 72) * 4608)) + (rc_outer_outer * 72)) + ((((((int)threadIdx.x) + 64) % 72) / 3) * 3)) + ((((int)threadIdx.x) + 1) % 3))];
    kernel_shared[(((int)threadIdx.x) + 336)] = kernel[(((((((int)blockIdx.x) * 36864) + (((((int)threadIdx.x) + 336) / 72) * 4608)) + (rc_outer_outer * 72)) + ((((((int)threadIdx.x) / 3) + 16) % 24) * 3)) + (((int)threadIdx.x) % 3))];
    kernel_shared[(((int)threadIdx.x) + 392)] = kernel[(((((((int)blockIdx.x) * 36864) + (((((int)threadIdx.x) + 392) / 72) * 4608)) + (rc_outer_outer * 72)) + ((((((int)threadIdx.x) + 32) % 72) / 3) * 3)) + ((((int)threadIdx.x) + 2) % 3))];
    kernel_shared[(((int)threadIdx.x) + 448)] = kernel[(((((((int)blockIdx.x) * 36864) + (((((int)threadIdx.x) + 448) / 72) * 4608)) + (rc_outer_outer * 72)) + (((((int)threadIdx.x) + 16) / 3) * 3)) + ((((int)threadIdx.x) + 1) % 3))];
    kernel_shared[(((int)threadIdx.x) + 504)] = kernel[((((((int)blockIdx.x) * 36864) + (rc_outer_outer * 72)) + ((int)threadIdx.x)) + 32256)];
    if (((int)threadIdx.x) < 16) {
      kernel_shared[(((int)threadIdx.x) + 560)] = kernel[(((((((int)blockIdx.x) * 36864) + (((((int)threadIdx.x) + 560) / 72) * 4608)) + (rc_outer_outer * 72)) + (((((int)threadIdx.x) + 56) / 3) * 3)) + ((((int)threadIdx.x) + 2) % 3))];
    }
    __syncthreads();
    for (int rc_outer_inner = 0; rc_outer_inner < 2; ++rc_outer_inner) {
      for (int ry_outer_inner = 0; ry_outer_inner < 3; ++ry_outer_inner) {
        for (int rc_inner = 0; rc_inner < 4; ++rc_inner) {
          for (int rx_inner = 0; rx_inner < 3; ++rx_inner) {
            conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp_shared[(((((rc_outer_inner * 324) + (rc_inner * 81)) + (ry_outer_inner * 9)) + rx_inner) + (((int)threadIdx.x) % 7))] * kernel_shared[((((((((int)threadIdx.x) / 7) * 72) + (rc_outer_inner * 36)) + (rc_inner * 9)) + (ry_outer_inner * 3)) + rx_inner)]));
            conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp_shared[((((((rc_outer_inner * 324) + (rc_inner * 81)) + (ry_outer_inner * 9)) + rx_inner) + (((int)threadIdx.x) % 7)) + 9)] * kernel_shared[((((((((int)threadIdx.x) / 7) * 72) + (rc_outer_inner * 36)) + (rc_inner * 9)) + (ry_outer_inner * 3)) + rx_inner)]));
            conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp_shared[((((((rc_outer_inner * 324) + (rc_inner * 81)) + (ry_outer_inner * 9)) + rx_inner) + (((int)threadIdx.x) % 7)) + 18)] * kernel_shared[((((((((int)threadIdx.x) / 7) * 72) + (rc_outer_inner * 36)) + (rc_inner * 9)) + (ry_outer_inner * 3)) + rx_inner)]));
            conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp_shared[((((((rc_outer_inner * 324) + (rc_inner * 81)) + (ry_outer_inner * 9)) + rx_inner) + (((int)threadIdx.x) % 7)) + 27)] * kernel_shared[((((((((int)threadIdx.x) / 7) * 72) + (rc_outer_inner * 36)) + (rc_inner * 9)) + (ry_outer_inner * 3)) + rx_inner)]));
            conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp_shared[((((((rc_outer_inner * 324) + (rc_inner * 81)) + (ry_outer_inner * 9)) + rx_inner) + (((int)threadIdx.x) % 7)) + 36)] * kernel_shared[((((((((int)threadIdx.x) / 7) * 72) + (rc_outer_inner * 36)) + (rc_inner * 9)) + (ry_outer_inner * 3)) + rx_inner)]));
            conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp_shared[((((((rc_outer_inner * 324) + (rc_inner * 81)) + (ry_outer_inner * 9)) + rx_inner) + (((int)threadIdx.x) % 7)) + 45)] * kernel_shared[((((((((int)threadIdx.x) / 7) * 72) + (rc_outer_inner * 36)) + (rc_inner * 9)) + (ry_outer_inner * 3)) + rx_inner)]));
            conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp_shared[((((((rc_outer_inner * 324) + (rc_inner * 81)) + (ry_outer_inner * 9)) + rx_inner) + (((int)threadIdx.x) % 7)) + 54)] * kernel_shared[((((((((int)threadIdx.x) / 7) * 72) + (rc_outer_inner * 36)) + (rc_inner * 9)) + (ry_outer_inner * 3)) + rx_inner)]));
          }
        }
      }
    }
  }
  for (int i2_inner = 0; i2_inner < 7; ++i2_inner) {
    compute[((((((int)blockIdx.x) * 392) + ((((int)threadIdx.x) / 7) * 49)) + (i2_inner * 7)) + (((int)threadIdx.x) % 7))] = max((conv2d_nchw[i2_inner] + bias[((((int)blockIdx.x) * 8) + (((int)threadIdx.x) / 7))]), 0.000000e+00f);
  }
}
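
The constants in the generated kernel follow from the split factors in the printed schedule. The short calculation below is a sanity check we added, not something the auto-scheduler emits. It recovers the thread count in __launch_bounds__, the trip count of the rc_outer_outer loop, and the two shared-memory buffer sizes. The variable names are ours, and the factors belong to this particular search result, so another run may pick different ones.

```python
# Layer shape from the tutorial and split factors read from the printed schedule.
CO, CI, H, W, KH, KW, PAD = 512, 512, 7, 7, 3, 3, 1
ff_thread = 8    # compute_i1_o_i: output channels mapped to threadIdx.x
xx_thread = 7    # compute_i3_o_i: output columns mapped to threadIdx.x
yy_inner = 7     # compute_i2_i: output rows computed serially by each thread
rc_tile = 4 * 2  # conv2d_nchw_rc_i * conv2d_nchw_rc_o_i: input channels per shared tile

threads_per_block = ff_thread * xx_thread                  # __launch_bounds__ value
blocks = CO // ff_thread                                   # extent of blockIdx.x
rc_outer_outer = CI // rc_tile                             # outer reduction loop trips
pad_temp_shared = rc_tile * (H + 2 * PAD) * (W + 2 * PAD)  # padded input tile, in floats
kernel_shared = ff_thread * rc_tile * KH * KW              # weight tile, in floats

print(threads_per_block, blocks, yy_inner, rc_outer_outer, pad_temp_shared, kernel_shared)
```

Each thread keeps yy_inner = 7 accumulators, which matches the float conv2d_nchw[7] register array at the top of the kernel.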

A more complicated example is resuming the search. In this case, we need to create the search policy and cost model ourselves and restore their state from the log file. In the example below, we restore the state and run 5 more trials.

def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    task.tune(tune_option, search_policy=search_policy)

    # Kill the measurement process
    del measure_ctx


resume_search(task, log_file)
Resume search:
/usr/local/lib/python3.7/dist-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated.  See: https://xgboost.readthedocs.io/en/latest/python/callbacks.html
  warnings.warn(f'Old style callback is deprecated.  See: {link}', UserWarning)
Get devices for measurement successfully!

Total running time of the script: ( 3 minutes 18.124 seconds)
