Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
Schedule Primitives in TVM
Author: Ziheng Jiang
TVM is a domain-specific language for efficient kernel construction.
In this tutorial, we will show you how to schedule the computation using various primitives provided by TVM.
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
There are often several ways to compute the same result; however, different methods result in different locality and performance. So TVM asks the user to specify how to execute the computation; this specification is called a Schedule.
A Schedule is a set of transformations that rewrite the loops of the computations in the program.
A schedule can be created from a list of ops. By default, the schedule computes each tensor serially, in row-major order.
# declare the symbolic extents m and n; they appear in the lowered
# output below as `m, n = T.int32(), T.int32()`
n = te.var("n")
m = te.var("m")

# declare a matrix element-wise multiply
A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

s = te.create_schedule([C.op])
# lower transforms the computation from its definition into a real
# callable function. With `simple_mode=True` it returns a readable
# (TVMScript) form of the lowered program, which we print here to show
# the result of the schedule.
print(tvm.lower(s, [A, B, C], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
# Default schedule for C = A * B: a single serial loop nest over (i, j),
# with i outermost (row-major order).
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i, j in T.grid(m, n):
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
            C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A_2[i * A_1.strides[0] + j * A_1.strides[1]] * B_2[i * B_1.strides[0] + j * B_1.strides[1]]
A schedule is composed of multiple stages, and one Stage represents the schedule for one operation. We provide various methods to schedule every stage.
split
`split` can split a specified axis into two axes by `factor`.
# from tvm.script import ir as I
# from tvm.script import tir as T
# After splitting by a factor of 32: the axis becomes i_outer
# (ceil(m / 32) iterations) x i_inner (32 iterations); T.likely guards
# the partial last chunk.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (m,), strides=("stride",), buffer_type="auto")
        for i_outer, i_inner in T.grid((m + 31) // 32, 32):
            if T.likely(i_outer * 32 + i_inner < m):
                B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
                A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
                cse_var_1: T.int32 = i_outer * 32 + i_inner
                B_2[cse_var_1 * B_1.strides[0]] = A_2[cse_var_1 * A_1.strides[0]] * T.float32(2)
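You can also split an axis by `nparts`, which splits the axis the opposite way from `factor`: `nparts` fixes the extent of the outer axis (32 in the output below), while the inner axis covers the remaining iterations.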
# from tvm.script import ir as I
# from tvm.script import tir as T
# After splitting into 32 parts: i_outer has a fixed extent of 32 and
# i_inner covers ceil(m / 32) iterations; T.likely guards indices past m.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (m,), strides=("stride",), buffer_type="auto")
        for i_outer, i_inner in T.grid(32, (m + 31) // 32):
            if T.likely(i_inner + i_outer * ((m + 31) // 32) < m):
                B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
                A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
                B_2[(i_inner + i_outer * ((m + 31) // 32)) * B_1.strides[0]] = A_2[(i_inner + i_outer * ((m + 31) // 32)) * A_1.strides[0]]
tile
`tile` helps you execute the computation tile by tile over two axes.
# from tvm.script import ir as I
# from tvm.script import tir as T
# After tiling with 10 x 5 tiles: loops are (i_outer, j_outer, i_inner,
# j_inner) with inner extents 10 and 5; each T.likely guards a partial tile.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i_outer, j_outer, i_inner in T.grid((m + 9) // 10, (n + 4) // 5, 10):
            if T.likely(i_outer * 10 + i_inner < m):
                for j_inner in range(5):
                    if T.likely(j_outer * 5 + j_inner < n):
                        cse_var_2: T.int32 = j_outer * 5 + j_inner
                        cse_var_1: T.int32 = i_outer * 10 + i_inner
                        B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
                        A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
                        B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]]
fuse
`fuse` can fuse two consecutive axes of one computation into a single axis.
# declare the symbolic extents here as well, so this snippet does not
# depend on variables defined elsewhere
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")

s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(xi, yi)
print(tvm.lower(s, [A, B], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
# After fusing i_inner and j_inner: a single loop of 10 * 5 = 50
# iterations; the original inner indices are recovered with // 5 and % 5.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i_outer, j_outer, i_inner_j_inner_fused in T.grid((m + 9) // 10, (n + 4) // 5, 50):
            if T.likely(i_outer * 10 + i_inner_j_inner_fused // 5 < m):
                if T.likely(j_outer * 5 + i_inner_j_inner_fused % 5 < n):
                    cse_var_2: T.int32 = j_outer * 5 + i_inner_j_inner_fused % 5
                    cse_var_1: T.int32 = i_outer * 10 + i_inner_j_inner_fused // 5
                    B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
                    A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
                    B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]]
reorder
`reorder` can reorder the axes into the specified order.
# declare the symbolic extents here as well, so this snippet does not
# depend on variables defined elsewhere
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")

s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then reorder the axes: (i.inner, j.outer, i.outer, j.inner)
s[B].reorder(xi, yo, xo, yi)
print(tvm.lower(s, [A, B], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
# After reordering: the loop nest runs i_inner, j_outer, i_outer, j_inner
# from outermost to innermost.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i_inner, j_outer, i_outer in T.grid(10, (n + 4) // 5, (m + 9) // 10):
            if T.likely(i_outer * 10 + i_inner < m):
                for j_inner in range(5):
                    if T.likely(j_outer * 5 + j_inner < n):
                        cse_var_2: T.int32 = j_outer * 5 + j_inner
                        cse_var_1: T.int32 = i_outer * 10 + i_inner
                        B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
                        A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
                        B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]]
bind
`bind` can bind a specified axis to a thread axis, which is often used in GPU programming.
# from tvm.script import ir as I
# from tvm.script import tir as T
# After binding: the outer axis maps to blockIdx.x and the inner axis
# (extent 64) to threadIdx.x; T.likely guards threads whose index is >= n.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        n = T.int32()
        A_1 = T.match_buffer(A, (n,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (n,), strides=("stride",), buffer_type="auto")
        blockIdx_x = T.launch_thread("blockIdx.x", (n + 63) // 64)
        threadIdx_x = T.launch_thread("threadIdx.x", 64)
        if T.likely(blockIdx_x * 64 + threadIdx_x < n):
            B_2 = T.Buffer((B_1.strides[0] * n,), data=B_1.data, buffer_type="auto")
            A_2 = T.Buffer((A_1.strides[0] * n,), data=A_1.data, buffer_type="auto")
            B_2[(blockIdx_x * 64 + threadIdx_x) * B_1.strides[0]] = A_2[(blockIdx_x * 64 + threadIdx_x) * A_1.strides[0]] * T.float32(2)
compute_at
For a schedule that consists of multiple operators, TVM computes each tensor separately at the root by default.
# declare the symbolic extent here as well, so this snippet does not
# depend on variables defined elsewhere
m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
# Default: B and C are each computed at the root, in two separate loops;
# the second loop reads the B values written by the first.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (m,), strides=("stride",), buffer_type="auto")
        C_1 = T.match_buffer(C, (m,), strides=("stride",), buffer_type="auto")
        B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
        for i in range(m):
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            B_2[i * B_1.strides[0]] = A_2[i * A_1.strides[0]] + T.float32(1)
        for i in range(m):
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            C_2[i * C_1.strides[0]] = B_2[i * B_1.strides[0]] * T.float32(2)
`compute_at` can move the computation of B into the first axis of the computation of C.
# from tvm.script import ir as I
# from tvm.script import tir as T
# After compute_at: B[i] is computed inside C's loop, right before C[i]
# reads it, so there is a single loop instead of two.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (m,), strides=("stride",), buffer_type="auto")
        C_1 = T.match_buffer(C, (m,), strides=("stride",), buffer_type="auto")
        for i in range(m):
            B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            B_2[i * B_1.strides[0]] = A_2[i * A_1.strides[0]] + T.float32(1)
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            C_2[i * C_1.strides[0]] = B_2[i * B_1.strides[0]] * T.float32(2)
compute_inline
`compute_inline` can mark one stage as inline; the body of the computation is then expanded and inserted at the place where the tensor is required.
# from tvm.script import ir as I
# from tvm.script import tir as T
# After compute_inline: B's expression (A[i] + 1) is substituted directly
# into C, so the loop never writes to B.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (m,), strides=("stride",), buffer_type="auto")
        C_1 = T.match_buffer(C, (m,), strides=("stride",), buffer_type="auto")
        for i in range(m):
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            C_2[i * C_1.strides[0]] = (A_2[i * A_1.strides[0]] + T.float32(1)) * T.float32(2)
compute_root
`compute_root` can move the computation of one stage to the root.
# from tvm.script import ir as I
# from tvm.script import tir as T
# After compute_root: B is computed in its own loop at the root, before
# the separate loop that computes C.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (m,), strides=("stride",), buffer_type="auto")
        C_1 = T.match_buffer(C, (m,), strides=("stride",), buffer_type="auto")
        B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
        for i in range(m):
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            B_2[i * B_1.strides[0]] = A_2[i * A_1.strides[0]] + T.float32(1)
        for i in range(m):
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            C_2[i * C_1.strides[0]] = B_2[i * B_1.strides[0]] * T.float32(2)
Summary
This tutorial provides an introduction to the schedule primitives in TVM, which permit users to schedule the computation easily and flexibly.
To get a well-performing kernel implementation, the general workflow is often:
Describe your computation via a series of operations.
Try to schedule the computation with primitives.
Compile and run to see the performance difference.
Adjust your schedule according to the running results.