Optimizing Operators with Auto-scheduling

Authors: Lianmin Zheng, Chengfan Jia

In this tutorial, we show how TVM’s auto-scheduling feature can find good schedules without requiring you to write a custom template.

Unlike the template-based AutoTVM, which relies on manual templates to define the search space, the auto-scheduler does not require any templates. Users only need to write the computation declaration, without any schedule commands or templates. The auto-scheduler then automatically generates a large search space and finds a good schedule within it.

We use matrix multiplication as an example in this tutorial.

Note

This tutorial will not run as-is on Windows or recent versions of macOS. To get it to run, wrap the body of this tutorial in an if __name__ == "__main__": block.
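A minimal sketch of that wrapping is shown here. The name run_tutorial is a hypothetical placeholder for the tutorial body, not part of TVM. The guard is presumably needed because these platforms start worker processes by re-importing the main module, so unguarded top-level code would run again in every child.

```python
def run_tutorial():
    # Hypothetical placeholder: the tutorial body (task creation, tuning,
    # evaluation) would go here instead of at module level.
    return "tutorial finished"


if __name__ == "__main__":
    # Only the process started as a script reaches this point;
    # a process that merely imports this module skips it.
    print(run_tutorial())
```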

import numpy as np
import tvm
from tvm import te, auto_scheduler

Defining the Matrix Multiplication

To start, we define a matrix multiplication with a bias addition. Note that this uses standard operations available in TVM's Tensor Expression language. The major difference from a plain Tensor Expression definition is the register_workload decorator at the top of the function definition. The function should return a list of input/output tensors. From these tensors, the auto-scheduler can derive the whole computational graph.

@auto_scheduler.register_workload  # Note the auto_scheduler decorator
def matmul_add(N, L, M, dtype):
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    C = te.placeholder((N, M), name="C", dtype=dtype)

    k = te.reduce_axis((0, L), name="k")
    matmul = te.compute(
        (N, M),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="matmul",
        attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B
    )
    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")

    return [A, B, C, out]
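For readers new to Tensor Expressions, the computation declared above can be sketched in plain Python as follows. This is only a reference for what the declaration means, written on nested lists; the name matmul_add_reference is ours and is not part of TVM.

```python
def matmul_add_reference(A, B, C):
    """Compute out[i][j] = sum_k A[i][k] * B[k][j] + C[i][j] on nested lists."""
    N, L, M = len(A), len(B), len(B[0])
    out = [[0.0] * M for _ in range(N)]
    for i in range(N):
        for j in range(M):
            acc = 0.0
            for k in range(L):  # k plays the role of the reduce axis
                acc += A[i][k] * B[k][j]
            out[i][j] = acc + C[i][j]  # bias addition, as in the "out" stage
    return out


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[1.0, 1.0], [1.0, 1.0]]
result = matmul_add_reference(A, B, C)
```

The TE version only declares this computation; the loop order, tiling, and vectorization are left for the auto-scheduler to decide.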

Create the search task

With the function defined, we can now create the task for the auto_scheduler to search against. We specify the particular parameters for this matrix multiplication, in this case a multiplication of two square matrices of size 1024x1024. We then create a search task with N=L=M=1024 and dtype="float32".

Improve performance with custom targets

In order for TVM to take full advantage of specific hardware platforms, you will want to manually specify your CPU capabilities. For example:

  • replace llvm below with llvm -mcpu=core-avx2 to enable AVX2

  • replace llvm below with llvm -mcpu=skylake-avx512 to enable AVX-512

target = tvm.target.Target("llvm")
N = L = M = 1024
task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])

Set Parameters for Auto-Scheduler

Next, we set parameters for the auto-scheduler.

  • num_measure_trials is the number of measurement trials we can use during the search. We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a good value for the search to converge. You can do more trials according to your time budget.

  • In addition, we use RecordToFile to log measurement records into a file matmul.json. The measurement records can be used to query the best result found so far, resume the search, and do more analyses later.

  • See auto_scheduler.TuningOptions for more parameters.

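log_file = "matmul.json"
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

# Run the search, then apply the best schedule recorded in the log file
task.tune(tune_option)
sch, args = task.apply_best(log_file)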

Inspecting the Optimized Schedule

We can lower the schedule to see the IR after auto-scheduling. The auto-scheduler correctly performs optimizations including multi-level tiling, layout transformation, parallelization, vectorization, unrolling, and operator fusion.

print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))
Lowered TIR:
@main = primfn(A_1: handle, B_1: handle, C_1: handle, out_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], []),
             out: Buffer(out_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C, out_1: out}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], []), out_1: out_3: Buffer(out_2, float32, [1024, 1024], [])} {
  allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [1048576]), storage_scope = global {
    for (ax0.ax1.fused.ax2.fused: int32, 0, 128) "parallel" {
      for (ax4: int32, 0, 256) {
        for (ax6: int32, 0, 4) {
          for (ax7: int32, 0, 8) {
            auto_scheduler_layout_transform_1: Buffer(auto_scheduler_layout_transform, float32, [1048576], [])[((((ax0.ax1.fused.ax2.fused*8192) + (ax4*32)) + (ax6*8)) + ax7)] = B[((((ax4*4096) + (ax6*1024)) + (ax0.ax1.fused.ax2.fused*8)) + ax7)]
          }
        }
      }
    }
    for (i.outer.outer.j.outer.outer.fused: int32, 0, 16384) "parallel" {
      allocate(matmul: Pointer(global float32x8), float32x8, [4]), storage_scope = global;
      for (i.outer.inner: int32, 0, 2) {
        matmul_1: Buffer(matmul, float32x8, [4], [])[0] = broadcast(0f32, 8)
        matmul_1[1] = broadcast(0f32, 8)
        matmul_1[2] = broadcast(0f32, 8)
        matmul_1[3] = broadcast(0f32, 8)
        for (k.outer: int32, 0, 256) {
          for (k.inner: int32, 0, 4) {
            let cse_var_2: int32 = (((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8))
            let cse_var_1: int32 = ((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner)
             {
              matmul_1[0] = (matmul_1[0] + (broadcast(A[cse_var_1], 8)*auto_scheduler_layout_transform_1[ramp(cse_var_2, 1, 8)]))
              matmul_1[1] = (matmul_1[1] + (broadcast(A[(cse_var_1 + 1024)], 8)*auto_scheduler_layout_transform_1[ramp(cse_var_2, 1, 8)]))
              matmul_1[2] = (matmul_1[2] + (broadcast(A[(cse_var_1 + 2048)], 8)*auto_scheduler_layout_transform_1[ramp(cse_var_2, 1, 8)]))
              matmul_1[3] = (matmul_1[3] + (broadcast(A[(cse_var_1 + 3072)], 8)*auto_scheduler_layout_transform_1[ramp(cse_var_2, 1, 8)]))
            }
          }
        }
        for (i.inner: int32, 0, 4) {
          let cse_var_3: int32 = ((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (i.inner*1024)) + (floormod(i.outer.outer.j.outer.outer.fused, 128)*8))
          out[ramp(cse_var_3, 1, 8)] = (matmul_1[i.inner] + C[ramp(cse_var_3, 1, 8)])
        }
      }
    }
  }
}
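The first loop nest in the lowered TIR is the layout transformation of B. Reading the index expressions, the row-major element B[k, j] appears to be copied into a buffer ordered as [j_outer][k_outer][k_inner][j_inner], so that the 8 values loaded by each vectorized read are contiguous. The following is a small plain-Python sketch of that repacking. The function name pack_b and the tile parameters are ours, and the sketch assumes the tile sizes divide the matrix dimensions.

```python
def pack_b(B_flat, K, J, k_tile=4, j_tile=8):
    """Repack a row-major K x J matrix into [J/j_tile][K/k_tile][k_tile][j_tile] order."""
    assert K % k_tile == 0 and J % j_tile == 0
    packed = [0.0] * (K * J)
    for jo in range(J // j_tile):
        for ko in range(K // k_tile):
            for ki in range(k_tile):
                for ji in range(j_tile):
                    # With K = J = 1024 this is jo*8192 + ko*32 + ki*8 + ji
                    dst = ((jo * (K // k_tile) + ko) * k_tile + ki) * j_tile + ji
                    # With K = J = 1024 this is ko*4096 + ki*1024 + jo*8 + ji
                    src = (ko * k_tile + ki) * J + jo * j_tile + ji
                    packed[dst] = B_flat[src]
    return packed


# Small example: an 8 x 16 matrix whose entries equal their row-major index
packed = pack_b(list(range(8 * 16)), K=8, J=16)
```

In the packed buffer, each group of 8 consecutive entries comes from a single row of B, which is what the ramp(cse_var_2, 1, 8) loads in the matmul loop rely on.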

Check correctness and evaluate performance

We build the binary and check its correctness and performance.

func = tvm.build(sch, args, target)
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_np

dev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)
Execution time of this operator: 97.366 ms
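To put the measured time in context, it can be converted into an approximate throughput. The sketch below assumes the usual convention of counting each multiply-add in the reduction as two floating-point operations, plus one addition per output element for the bias. The helper name matmul_add_gflops is ours.

```python
def matmul_add_gflops(n, l, m, time_ms):
    """Approximate throughput in GFLOP/s for an (n x l) @ (l x m) matmul plus bias add."""
    # One multiply and one add per (i, j, k), plus one bias add per (i, j)
    flops = 2 * n * l * m + n * m
    return flops / (time_ms * 1e-3) / 1e9


throughput = matmul_add_gflops(1024, 1024, 1024, 97.366)
print("%.1f GFLOP/s" % throughput)
```

With only 10 measurement trials this figure is expected to be far from what the hardware can deliver; more trials and a CPU-specific target should improve it.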

Using the record file

During the search, all measurement records are logged into the record file matmul.json. The measurement records can be used to re-apply search results, resume the search, and perform other analyses.

Here is an example where we load the best schedule from a file and print the equivalent Python schedule API. This can be used for debugging and for learning the behavior of the auto-scheduler.

print("Equivalent python schedule:")
print(task.print_best(log_file))
Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=4)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=1)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=2)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=8)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=1)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=1)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=4)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
out_i_o_i, out_i_i = s[out].split(out_i, factor=4)
out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=2)
out_j_o_i, out_j_i = s[out].split(out_j, factor=8)
out_j_o_o, out_j_o_i = s[out].split(out_j_o_i, factor=1)
s[out].reorder(out_i_o_o, out_j_o_o, out_i_o_i, out_j_o_i, out_i_i, out_j_i)
s[matmul].compute_at(s[out], out_j_o_i)
out_i_o_o_j_o_o_fused = s[out].fuse(out_i_o_o, out_j_o_o)
s[out].parallel(out_i_o_o_j_o_o_fused)
s[matmul].pragma(matmul_i_o_o_o, "auto_unroll_max_step", 8)
s[matmul].pragma(matmul_i_o_o_o, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)
s[out].vectorize(out_j_i)
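The chain of split calls above determines the loop extents seen in the lowered TIR. As a rough sketch, each split with a given factor produces an inner loop of that extent and an outer loop covering the rest. The helper split_extents is ours and assumes every factor divides the remaining extent.

```python
def split_extents(extent, factors):
    """Apply successive splits to the outer loop; return extents from outermost to innermost."""
    inner = []
    for f in factors:
        assert extent % f == 0, "sketch assumes the factor divides the extent"
        inner.insert(0, f)  # the new inner loop sits just outside earlier inner loops
        extent //= f        # the remaining outer loop is split next
    return [extent] + inner


i_extents = split_extents(1024, [4, 1, 2])  # matmul_i: factors 4, 1, 2
j_extents = split_extents(1024, [8, 1, 1])  # matmul_j: factors 8, 1, 1
k_extents = split_extents(1024, [4])        # matmul_k: factor 4
fused_parallel = i_extents[0] * j_extents[0]
```

The product of the two outermost extents matches the 16384 iterations of the fused parallel loop, and the innermost j extent of 8 matches the float32x8 vectors in the lowered TIR.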

A more complicated example is to resume the search. In this case, we need to create the search policy and cost model ourselves and restore their state from the log file. In the example below, we restore the state and run 5 more trials.

def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
    )
    task.tune(tune_option, search_policy=search_policy)


resume_search(task, log_file)
Resume search:
/usr/local/lib/python3.7/dist-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated.  See: https://xgboost.readthedocs.io/en/latest/python/callbacks.html
  warnings.warn(f'Old style callback is deprecated.  See: {link}', UserWarning)

Final Notes and Summary

In this tutorial, we have shown how to use the TVM Auto-Scheduler to automatically optimize a matrix multiplication, without the need to specify a search template. This concludes a series of examples, starting from the Tensor Expression (TE) language, that demonstrate how TVM can optimize computational operations.
