Note
Click here to download the full example code
Auto-scheduling Matrix Multiplication for CPU¶
Author: Lianmin Zheng, Chengfan Jia
This is a tutorial on how to use the auto-scheduler for CPUs.
Unlike the template-based autotvm, which relies on manual templates to define the search space, the auto-scheduler does not require any templates. Users only need to write the computation declaration, without any schedule commands or templates. The auto-scheduler can automatically generate a large search space and find a good schedule in that space.
We use matrix multiplication as an example in this tutorial.
Note that this tutorial will not run on Windows or recent versions of macOS. To
get it to run, you will need to wrap the body of this tutorial in an
`if __name__ == "__main__":`
block.
import os
import numpy as np
import tvm
from tvm import te, auto_scheduler
Define the computation¶
To begin with, let us define the computation of a matmul with bias add. The function should return the list of input/output tensors. From these tensors, the auto-scheduler can get the whole computational graph.
@auto_scheduler.register_workload
def matmul_add(N, L, M, dtype):
    """Declare ``out = matmul(A, B) + C`` as a tensor-expression workload.

    Parameters
    ----------
    N, L, M :
        Extents used for the tensor shapes: A is (N, L), B is (L, M), and both
        C and the result are (N, M).
    dtype :
        Element type passed to all three placeholders.

    Returns
    -------
    list
        ``[A, B, C, out]`` -- the three input placeholders followed by the
        output tensor.
    """
    lhs = te.placeholder((N, L), name="A", dtype=dtype)
    rhs = te.placeholder((L, M), name="B", dtype=dtype)
    bias = te.placeholder((N, M), name="C", dtype=dtype)

    # Reduction axis running over the shared extent L.
    red = te.reduce_axis((0, L), name="k")

    out_shape = (N, M)
    product = te.compute(
        out_shape,
        lambda i, j: te.sum(lhs[i, red] * rhs[red, j], axis=red),
        name="matmul",
        # enable automatic layout transform for tensor B
        attrs={"layout_free_placeholders": [rhs]},
    )
    result = te.compute(out_shape, lambda i, j: product[i, j] + bias[i, j], name="out")
    return [lhs, rhs, bias, result]
Create the search task¶
We then create a search task with N=L=M=1024 and dtype="float32". If your machine supports AVX instructions, you can
- replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2
- replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512
# Target string "llvm"; a -mcpu option can be appended to enable AVX2 / AVX-512.
target = tvm.target.Target("llvm")

# All three matrix extents share the same size.
N, L, M = 1024, 1024, 1024

# Bind the registered workload, its arguments and the target into a search task.
task = auto_scheduler.SearchTask(
    func=matmul_add,
    args=(N, L, M, "float32"),
    target=target,
)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
Out:
Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])
Next, we set parameters for the auto-scheduler.
num_measure_trials
is the number of measurement trials we can use during the search. We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a good value for the search to converge. You can do more trials according to your time budget. In addition, we use
RecordToFile
to dump measurement records into a file matmul.json. The measurement records can be used to query the best result found so far, resume the search, and do more analyses later. See
auto_scheduler.TuningOptions
for more parameters.
# File that receives the measurement records produced during tuning.
log_file = "matmul.json"

tune_option = auto_scheduler.TuningOptions(
    verbose=2,
    # Only 10 trials, to keep this demonstration fast.
    num_measure_trials=10,
    # Dump measurement records into log_file.
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
Run the search¶
Now we get all inputs ready. Pretty simple, isn’t it? We can kick off the search and let the auto-scheduler do its magic. After some measurement trials, we can load the best schedule from the log file and apply it.
# Run auto-tuning (search) with the options in tune_option; its RecordToFile
# callback dumps the measurement records into log_file.
task.tune(tune_option)
# Apply the best schedule found in log_file. The returned schedule and tensor
# arguments are reused below for lowering and building.
sch, args = task.apply_best(log_file)
Out:
We can lower the schedule to see the IR after auto-scheduling. The auto-scheduler correctly performs optimizations including multi-level tiling, layout transformation, parallelization, vectorization, unrolling, and operator fusion.
# Lower the tuned schedule and print the resulting IR for inspection.
print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))
Out:
Lowered TIR:
primfn(A_1: handle, B_1: handle, C_1: handle, out_1: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {out: Buffer(out_2: Pointer(float32), float32, [1024, 1024], []),
C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], []),
A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], [])}
buffer_map = {A_1: A, B_1: B, C_1: C, out_1: out} {
attr [auto_scheduler_layout_transform: Pointer(float32)] "storage_scope" = "global";
allocate(auto_scheduler_layout_transform, float32, [1048576]) {
for (ax0.ax1.fused.ax2.fused: int32, 0, 128) "parallel" {
for (ax4: int32, 0, 256) {
for (ax6: int32, 0, 4) {
for (ax7: int32, 0, 8) {
auto_scheduler_layout_transform[((((ax0.ax1.fused.ax2.fused*8192) + (ax4*32)) + (ax6*8)) + ax7)] = (float32*)B_2[((((ax4*4096) + (ax6*1024)) + (ax0.ax1.fused.ax2.fused*8)) + ax7)]
}
}
}
}
for (i.outer.outer.j.outer.outer.fused: int32, 0, 16384) "parallel" {
attr [matmul: Pointer(float32x8)] "storage_scope" = "global";
allocate(matmul, float32x8, [4]);
for (i.outer.inner: int32, 0, 2) {
matmul[ramp(0, 1, 8)] = broadcast(0f32, 8)
matmul[ramp(8, 1, 8)] = broadcast(0f32, 8)
matmul[ramp(16, 1, 8)] = broadcast(0f32, 8)
matmul[ramp(24, 1, 8)] = broadcast(0f32, 8)
for (k.outer: int32, 0, 256) {
for (k.inner: int32, 0, 4) {
matmul[ramp(0, 1, 8)] = ((float32x8*)matmul[ramp(0, 1, 8)] + (broadcast((float32*)A_2[((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))
matmul[ramp(8, 1, 8)] = ((float32x8*)matmul[ramp(8, 1, 8)] + (broadcast((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner) + 1024)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))
matmul[ramp(16, 1, 8)] = ((float32x8*)matmul[ramp(16, 1, 8)] + (broadcast((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner) + 2048)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))
matmul[ramp(24, 1, 8)] = ((float32x8*)matmul[ramp(24, 1, 8)] + (broadcast((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner) + 3072)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))
}
}
for (i.inner: int32, 0, 4) {
out_2[ramp(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (i.inner*1024)) + (floormod(i.outer.outer.j.outer.outer.fused, 128)*8)), 1, 8)] = ((float32x8*)matmul[ramp((i.inner*8), 1, 8)] + (float32x8*)C_2[ramp(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (i.inner*1024)) + (floormod(i.outer.outer.j.outer.outer.fused, 128)*8)), 1, 8)])
}
}
}
}
}
Check correctness and evaluate performance¶
We build the binary and check its correctness and performance.
# Compile the tuned schedule into a callable for the chosen target.
func = tvm.build(sch, args, target)

# Random float32 inputs, drawn in the order A, B, C.
a_np, b_np, c_np = (
    np.random.uniform(size=shape).astype(np.float32)
    for shape in ((N, L), (L, M), (N, M))
)
# NumPy reference result used for the correctness check.
out_np = np.dot(a_np, b_np) + c_np

# Copy the inputs to the CPU device and allocate the output buffer there.
ctx = tvm.cpu()
a_tvm, b_tvm, c_tvm = (tvm.nd.array(host, ctx=ctx) for host in (a_np, b_np, c_np))
out_tvm = tvm.nd.empty(out_np.shape, ctx=ctx)
func(a_tvm, b_tvm, c_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.asnumpy(), rtol=1e-3)

# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)
Out:
Execution time of this operator: 43.375 ms
Using the record file¶
During the search, all measurement records are dumped into the record file “matmul.json”. The measurement records can be used to re-apply search results, resume the search, and perform other analyses.
Here is an example where we load the best schedule from a file, and print the equivalent python schedule API. This can be used for debugging and learning the behavior of the auto-scheduler.
# Load the best schedule recorded in log_file and print it as the equivalent
# Python schedule API calls.
print("Equivalent python schedule:")
print(task.print_best(log_file))
Out:
Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=4)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=1)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=2)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=8)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=1)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=1)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=4)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
out_i_o_i, out_i_i = s[out].split(out_i, factor=4)
out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=2)
out_j_o_i, out_j_i = s[out].split(out_j, factor=8)
out_j_o_o, out_j_o_i = s[out].split(out_j_o_i, factor=1)
s[out].reorder(out_i_o_o, out_j_o_o, out_i_o_i, out_j_o_i, out_i_i, out_j_i)
s[matmul].compute_at(s[out], out_j_o_i)
out_i_o_o_j_o_o_fused = s[out].fuse(out_i_o_o, out_j_o_o)
s[out].parallel(out_i_o_o_j_o_o_fused)
s[matmul].pragma(matmul_i_o_o_o, "auto_unroll_max_step", 8)
s[matmul].pragma(matmul_i_o_o_o, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)
s[out].vectorize(out_j_i)
A more complicated example is to resume the search. In this case, we need to create the search policy and cost model ourselves and restore their state from the log file. In the example below, we restore the state and run 5 more trials.
def resume_search(task, log_file):
    """Resume tuning ``task`` from the measurement records in ``log_file``.

    A new XGBModel cost model is updated from the log, a SketchPolicy is
    created with a PreloadMeasuredStates callback for the same log, and 5
    further measurement trials are run with a RecordToFile callback on
    ``log_file``.

    Parameters
    ----------
    task :
        The search task to keep tuning.
    log_file :
        Path of the record file to read from and record to.
    """
    print("Resume search:")

    # Rebuild the cost model state from the existing records.
    model = auto_scheduler.XGBModel()
    model.update_from_file(log_file)

    # The policy gets a callback that preloads the measured states from the log.
    preload = auto_scheduler.PreloadMeasuredStates(log_file)
    policy = auto_scheduler.SketchPolicy(task, model, init_search_callbacks=[preload])

    # Five more trials, recorded into the same file.
    recorder = auto_scheduler.RecordToFile(log_file)
    options = auto_scheduler.TuningOptions(num_measure_trials=5, measure_callbacks=[recorder])

    task.tune(options, search_policy=policy)
# Continue tuning the same task from the records already stored in log_file.
resume_search(task, log_file)
Out:
Resume search: