Scan and Recurrent Kernel
Author: Tianqi Chen
This is an introductory tutorial on how to do recurrent computing in TVM. Recurrent computing is a typical pattern in neural networks.
from __future__ import absolute_import, print_function
import tvm
import tvm.testing
from tvm import te
import numpy as np
TVM supports a scan operator to describe symbolic loops. The following scan op computes the cumulative sum over the columns of X.
The scan is carried over the highest dimension of the tensor. s_state is a placeholder that describes the transition state of the scan. s_init describes how we can initialize the first k timesteps. Here, since s_init's first dimension is 1, it describes how we initialize the state at the first timestep.
s_update describes how to update the value at timestep t. The update value can refer back to the values of the previous timestep via the state placeholder. Note that it is invalid to refer to s_state at the current or a later timestep.
The scan takes in the state placeholder, the initial value, and the update description. It is also recommended (although not necessary) to list the inputs to the scan cell. The result of the scan is a tensor, giving the result of s_state after the update over the time domain.
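Concretely, the cumsum cell can be declared as follows. This reconstruction follows the same pattern as the multi-stage example later in the tutorial, and defines the s_state, s_init, s_update, and s_scan names that the schedule below refers to.
m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
# The state placeholder carries the value of the recurrence at every timestep.
s_state = te.placeholder((m, n))
# Initialize the first timestep with the first row of X.
s_init = te.compute((1, n), lambda _, i: X[0, i])
# At timestep t, add X[t] to the state of the previous timestep.
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])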
Schedule the Scan Cell
We can schedule the body of the scan by scheduling the update and init parts separately. Note that it is invalid to schedule the first iteration dimension of the update part. To split on the time iteration, users can schedule on scan_op.scan_axis instead.
s = te.create_schedule(s_scan.op)
num_thread = 256
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis("threadIdx.x")
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_x)
s[s_init].bind(xi, thread_x)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
print(tvm.lower(s, [X, s_scan], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(X: T.handle, scan: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        X_1 = T.match_buffer(X, (m, n), strides=("stride", "stride"), buffer_type="auto")
        scan_1 = T.match_buffer(scan, (m, n), strides=("stride", "stride"), buffer_type="auto")
        blockIdx_x = T.env_thread("blockIdx.x")
        threadIdx_x = T.env_thread("threadIdx.x")
        scan_2 = T.Buffer((scan_1.strides[0] * m,), data=scan_1.data, buffer_type="auto")
        X_2 = T.Buffer((X_1.strides[0] * m,), data=X_1.data, buffer_type="auto")
        with T.launch_thread(blockIdx_x, (n + 255) // 256):
            T.launch_thread(threadIdx_x, 256)
            if T.likely(blockIdx_x * 256 + threadIdx_x < n):
                scan_2[(blockIdx_x * 256 + threadIdx_x) * scan_1.strides[1]] = X_2[(blockIdx_x * 256 + threadIdx_x) * X_1.strides[1]]
        for scan_idx in range(m - 1):
            T.launch_thread(blockIdx_x, (n + 255) // 256)
            T.launch_thread(threadIdx_x, 256)
            if T.likely(blockIdx_x * 256 + threadIdx_x < n):
                cse_var_1: T.int32 = scan_idx + 1
                scan_2[cse_var_1 * scan_1.strides[0] + (blockIdx_x * 256 + threadIdx_x) * scan_1.strides[1]] = scan_2[scan_idx * scan_1.strides[0] + (blockIdx_x * 256 + threadIdx_x) * scan_1.strides[1]] + X_2[cse_var_1 * X_1.strides[0] + (blockIdx_x * 256 + threadIdx_x) * X_1.strides[1]]
Build and Verify
We can build the scan kernel like other TVM kernels. Here we use numpy to verify the correctness of the result.
fscan = tvm.build(s, [X, s_scan], "cuda", name="myscan")
dev = tvm.cuda(0)
n = 1024
m = 10
a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros((m, n), dtype=s_scan.dtype), dev)
fscan(a, b)
tvm.testing.assert_allclose(b.numpy(), np.cumsum(a_np, axis=0))
Multi-Stage Scan Cell
In the above example we described the scan cell using a single Tensor computation stage in s_update. It is possible to use multiple Tensor stages in the scan cell.
The following lines demonstrate a scan with two stages in the scan cell.
m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update_s1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] * 2, name="s1")
s_update_s2 = te.compute((m, n), lambda t, i: s_update_s1[t, i] + X[t, i], name="s2")
s_scan = tvm.te.scan(s_init, s_update_s2, s_state, inputs=[X])
These intermediate tensors can also be scheduled normally. To ensure correctness, TVM creates a group constraint that forbids the body of the scan from being computed (via compute_at) at locations outside the scan loop.
s = te.create_schedule(s_scan.op)
xo, xi = s[s_update_s2].split(s_update_s2.op.axis[1], factor=32)
s[s_update_s1].compute_at(s[s_update_s2], xo)
print(tvm.lower(s, [X, s_scan], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(X: T.handle, scan: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        X_1 = T.match_buffer(X, (m, n), strides=("stride", "stride"), buffer_type="auto")
        scan_1 = T.match_buffer(scan, (m, n), strides=("stride", "stride"), buffer_type="auto")
        s1 = T.allocate([32], "float32", "global")
        scan_2 = T.Buffer((scan_1.strides[0] * m,), data=scan_1.data, buffer_type="auto")
        X_2 = T.Buffer((X_1.strides[0] * m,), data=X_1.data, buffer_type="auto")
        for i in range(n):
            scan_2[i * scan_1.strides[1]] = X_2[i * X_1.strides[1]]
        for scan_idx, i_outer in T.grid(m - 1, (n + 31) // 32):
            s1_1 = T.Buffer((32,), data=s1)
            for i in range(32):
                if T.likely(i_outer * 32 + i < n):
                    s1_1[i] = scan_2[scan_idx * scan_1.strides[0] + (i_outer * 32 + i) * scan_1.strides[1]] * T.float32(2.0)
            for i_inner in range(32):
                if T.likely(i_outer * 32 + i_inner < n):
                    cse_var_2: T.int32 = scan_idx + 1
                    cse_var_1: T.int32 = i_outer * 32 + i_inner
                    scan_2[cse_var_2 * scan_1.strides[0] + cse_var_1 * scan_1.strides[1]] = s1_1[i_inner] + X_2[cse_var_2 * X_1.strides[0] + cse_var_1 * X_1.strides[1]]
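Though the tutorial only lowers this schedule, the multi-stage cell can be built and checked the same way as the single-stage kernel. The snippet below is a minimal sketch, assuming an LLVM target since this schedule has no GPU thread bindings; the names fscan_ms, m_val, and n_val are illustrative.
fscan_ms = tvm.build(s, [X, s_scan], "llvm", name="myscan_ms")
dev = tvm.cpu(0)
m_val, n_val = 10, 128
a_np = np.random.uniform(size=(m_val, n_val)).astype(s_scan.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros((m_val, n_val), dtype=s_scan.dtype), dev)
fscan_ms(a, b)
# Reference recurrence: state[0] = X[0]; state[t] = state[t - 1] * 2 + X[t]
ref = np.zeros_like(a_np)
ref[0] = a_np[0]
for t in range(1, m_val):
    ref[t] = ref[t - 1] * 2 + a_np[t]
tvm.testing.assert_allclose(b.numpy(), ref)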
Multiple States
For complicated applications like RNNs, we might need more than one recurrent state. Scan supports multiple recurrent states. The following example demonstrates how we can build a recurrence with two states.
m = te.var("m")
n = te.var("n")
l = te.var("l")
X = te.placeholder((m, n), name="X")
s_state1 = te.placeholder((m, n))
s_state2 = te.placeholder((m, l))
s_init1 = te.compute((1, n), lambda _, i: X[0, i])
s_init2 = te.compute((1, l), lambda _, i: 0.0)
s_update1 = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + X[t, i])
s_update2 = te.compute((m, l), lambda t, i: s_state2[t - 1, i] + s_state1[t - 1, 0])
s_scan1, s_scan2 = tvm.te.scan(
[s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X]
)
s = te.create_schedule(s_scan1.op)
print(tvm.lower(s, [X, s_scan1, s_scan2], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(X: T.handle, scan: T.handle, scan_1: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        X_1 = T.match_buffer(X, (m, n), strides=("stride", "stride"), buffer_type="auto")
        scan_2 = T.match_buffer(scan, (m, n), strides=("stride", "stride"), buffer_type="auto")
        l = T.int32()
        scan_3 = T.match_buffer(scan_1, (m, l), strides=("stride", "stride"), buffer_type="auto")
        scan_4 = T.Buffer((scan_2.strides[0] * m,), data=scan_2.data, buffer_type="auto")
        X_2 = T.Buffer((X_1.strides[0] * m,), data=X_1.data, buffer_type="auto")
        for i in range(n):
            scan_4[i * scan_2.strides[1]] = X_2[i * X_1.strides[1]]
        scan_5 = T.Buffer((scan_3.strides[0] * m,), data=scan_3.data, buffer_type="auto")
        for i in range(l):
            scan_5[i * scan_3.strides[1]] = T.float32(0.0)
        for scan_idx in range(m - 1):
            for i in range(n):
                cse_var_1: T.int32 = scan_idx + 1
                scan_4[cse_var_1 * scan_2.strides[0] + i * scan_2.strides[1]] = scan_4[scan_idx * scan_2.strides[0] + i * scan_2.strides[1]] + X_2[cse_var_1 * X_1.strides[0] + i * X_1.strides[1]]
            for i in range(l):
                scan_5[(scan_idx + 1) * scan_3.strides[0] + i * scan_3.strides[1]] = scan_5[scan_idx * scan_3.strides[0] + i * scan_3.strides[1]] + scan_4[scan_idx * scan_2.strides[0]]
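The lowered code above interleaves the updates of both states within each timestep. As before, the two-state scan can be built and verified against a direct numpy recurrence; the sketch below assumes an LLVM target (the schedule has no thread bindings), and the value names m_val, n_val, and l_val are illustrative.
fscan2 = tvm.build(s, [X, s_scan1, s_scan2], "llvm", name="myscan2")
dev = tvm.cpu(0)
m_val, n_val, l_val = 10, 128, 32
a_np = np.random.uniform(size=(m_val, n_val)).astype("float32")
a = tvm.nd.array(a_np, dev)
b1 = tvm.nd.array(np.zeros((m_val, n_val), dtype="float32"), dev)
b2 = tvm.nd.array(np.zeros((m_val, l_val), dtype="float32"), dev)
fscan2(a, b1, b2)
# Reference: state1 is the column-wise cumsum of X; state2 adds
# state1[t - 1, 0] to its previous value at every timestep.
c1 = np.cumsum(a_np, axis=0)
c2 = np.zeros((m_val, l_val), dtype="float32")
for t in range(1, m_val):
    c2[t, :] = c2[t - 1, :] + c1[t - 1, 0]
tvm.testing.assert_allclose(b1.numpy(), c1)
tvm.testing.assert_allclose(b2.numpy(), c2)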
Summary
This tutorial provides a walk-through of the scan primitive.
Describe scan with init and update.
Schedule the scan cells as you would a normal schedule.
For complicated workloads, use multiple states and stages in the scan cell.