Scan and Recurrent Kernel

Author: Tianqi Chen

This is introductory material on how to do recurrent computing in TVM. Recurrent computing is a typical pattern in neural networks.

from __future__ import absolute_import, print_function


import tvm
import tvm.testing
from tvm import te
import numpy as np

TVM supports a scan operator to describe symbolic loops. The following scan op computes the cumulative sum of X along its first axis, i.e. down each column (the same result as np.cumsum(X, axis=0)).

The scan is carried over the highest (first) dimension of the tensor. s_state is a placeholder that describes the transition state of the scan. s_init describes how to initialize the first k timesteps. Here, since s_init's first dimension is 1, it describes how to initialize the state at the first timestep.

s_update describes how to update the value at timestep t. The update can refer back to values from previous timesteps via the state placeholder. Note that it is invalid to refer to s_state at the current or any later timestep.

The scan takes the state placeholder, the initial value, and the update description. It is also recommended (although not necessary) to list the inputs to the scan cell. The result of the scan is a tensor giving the value of s_state after the update over the whole time domain.
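The semantics described above can be modeled in plain Python. The sketch below is a hypothetical reference model (`reference_scan` is not a TVM API): the rows returned by the init function seed the state, and each later row is produced by an update function that may only read rows before `t`.

```python
def reference_scan(X, init, update):
    """Hypothetical plain-Python model of te.scan semantics (not a TVM API).

    init(X) returns the first k rows of the state (k = 1 here, like s_init);
    update(state, X, t, i) computes state[t][i] and may only read rows < t.
    """
    n = len(X[0])
    state = [list(row) for row in init(X)]  # rows 0..k-1 come from the init stage
    for t in range(len(state), len(X)):     # remaining timesteps, in order
        # While row t is being built, state holds only rows 0..t-1.
        state.append([update(state, X, t, i) for i in range(n)])
    return state


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
result = reference_scan(
    X,
    init=lambda rows: [rows[0]],                        # mirrors s_init
    update=lambda s, rows, t, i: s[t - 1][i] + rows[t][i],  # mirrors s_update
)
print(result)  # [[1.0, 2.0], [4.0, 6.0], [9.0, 12.0]]
```

Because row `t` is appended only after it is fully computed, an update that tried to read `s[t]` would raise an `IndexError`, which mirrors the rule that s_state may not be referenced at the current or a later timestep.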

m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])

Schedule the Scan Cell

We can schedule the body of the scan by scheduling the update and init parts separately. Note that it is invalid to schedule the first iteration dimension (the time axis) of the update part. To split on the time iteration, users can schedule on scan_op.scan_axis instead.
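One way to see why the time axis is treated differently from the other axes: each timestep reads the state written by the previous one, so timesteps cannot be reordered or run concurrently the way the column axis can. The plain-Python sketch below (both function names are hypothetical) contrasts the sequential order with a deliberately wrong "all timesteps at once" order.

```python
def scan_sequential(X):
    """Cumulative sum over the time axis, one timestep after another."""
    n = len(X[0])
    state = [list(X[0])]
    for t in range(1, len(X)):
        state.append([state[t - 1][i] + X[t][i] for i in range(n)])
    return state


def scan_timesteps_concurrently(X):
    """Hypothetical *incorrect* schedule: every timestep reads a snapshot of the
    state taken before any update ran, as if all timesteps executed at once."""
    n = len(X[0])
    snapshot = [list(X[0])] + [[0.0] * n for _ in range(len(X) - 1)]
    state = [list(row) for row in snapshot]
    for t in range(1, len(X)):
        state[t] = [snapshot[t - 1][i] + X[t][i] for i in range(n)]
    return state


X = [[1.0], [1.0], [1.0], [1.0]]
print(scan_sequential(X))              # [[1.0], [2.0], [3.0], [4.0]]
print(scan_timesteps_concurrently(X))  # [[1.0], [2.0], [1.0], [1.0]]
```

The columns, by contrast, are independent within a timestep, which is why the schedule below splits and binds only axis[1].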

s = te.create_schedule(s_scan.op)
num_thread = 256
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis("threadIdx.x")
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_x)
s[s_init].bind(xi, thread_x)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
print(tvm.lower(s, [X, s_scan], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(X: T.handle, scan: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        X_1 = T.match_buffer(X, (m, n), strides=("stride", "stride"), buffer_type="auto")
        scan_1 = T.match_buffer(scan, (m, n), strides=("stride", "stride"), buffer_type="auto")
        blockIdx_x = T.env_thread("blockIdx.x")
        threadIdx_x = T.env_thread("threadIdx.x")
        scan_2 = T.Buffer((scan_1.strides[0] * m,), data=scan_1.data, buffer_type="auto")
        X_2 = T.Buffer((X_1.strides[0] * m,), data=X_1.data, buffer_type="auto")
        with T.launch_thread(blockIdx_x, (n + 255) // 256):
            T.launch_thread(threadIdx_x, 256)
            if T.likely(blockIdx_x * 256 + threadIdx_x < n):
                scan_2[(blockIdx_x * 256 + threadIdx_x) * scan_1.strides[1]] = X_2[(blockIdx_x * 256 + threadIdx_x) * X_1.strides[1]]
        for scan_idx in range(m - 1):
            T.launch_thread(blockIdx_x, (n + 255) // 256)
            T.launch_thread(threadIdx_x, 256)
            if T.likely(blockIdx_x * 256 + threadIdx_x < n):
                cse_var_1: T.int32 = scan_idx + 1
                scan_2[cse_var_1 * scan_1.strides[0] + (blockIdx_x * 256 + threadIdx_x) * scan_1.strides[1]] = scan_2[scan_idx * scan_1.strides[0] + (blockIdx_x * 256 + threadIdx_x) * scan_1.strides[1]] + X_2[cse_var_1 * X_1.strides[0] + (blockIdx_x * 256 + threadIdx_x) * X_1.strides[1]]
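To read the lowered IR above, it can help to emulate it sequentially. In the sketch below (a hypothetical helper, not TVM code), the blockIdx.x/threadIdx.x loops that the GPU runs in parallel become ordinary loops, the `if i < n` check plays the role of the `T.likely(...)` bounds guard, and the scan axis stays a sequential loop in which each iteration corresponds to a separate kernel launch.

```python
def emulate_lowered_scan(X, num_thread=256):
    """Hypothetical sequential emulation of the lowered scan kernel (not TVM code)."""
    m, n = len(X), len(X[0])
    num_blocks = (n + num_thread - 1) // num_thread  # (n + 255) // 256 in the IR
    scan = [[0.0] * n for _ in range(m)]
    # Init stage: a single launch that copies row 0 of X.
    for bx in range(num_blocks):
        for tx in range(num_thread):
            i = bx * num_thread + tx
            if i < n:  # threads past the end of the row are masked off
                scan[0][i] = X[0][i]
    # Update stage: the scan axis is a sequential loop around each launch.
    for scan_idx in range(m - 1):
        for bx in range(num_blocks):
            for tx in range(num_thread):
                i = bx * num_thread + tx
                if i < n:
                    scan[scan_idx + 1][i] = scan[scan_idx][i] + X[scan_idx + 1][i]
    return scan


X = [[0.0, 1.0, 2.0] for _ in range(4)]
print(emulate_lowered_scan(X, num_thread=2))
# [[0.0, 1.0, 2.0], [0.0, 2.0, 4.0], [0.0, 3.0, 6.0], [0.0, 4.0, 8.0]]
```

With `num_thread=2` and `n=3`, two blocks are launched per timestep and one thread in the second block is masked off, the same rounding-up that `(n + 255) // 256` performs in the IR.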

Build and Verify

We can build the scan kernel like any other TVM kernel. Here we use numpy to verify the correctness of the result.

fscan = tvm.build(s, [X, s_scan], "cuda", name="myscan")
dev = tvm.cuda(0)
n = 1024
m = 10
a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros((m, n), dtype=s_scan.dtype), dev)
fscan(a, b)
tvm.testing.assert_allclose(b.numpy(), np.cumsum(a_np, axis=0))

Multi-Stage Scan Cell

In the above example we described the scan cell using one Tensor computation stage in s_update. It is possible to use multiple Tensor stages in the scan cell.

The following lines demonstrate a scan with two computation stages in the scan cell.

m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update_s1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] * 2, name="s1")
s_update_s2 = te.compute((m, n), lambda t, i: s_update_s1[t, i] + X[t, i], name="s2")
s_scan = tvm.te.scan(s_init, s_update_s2, s_state, inputs=[X])
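Taken together, the two stages compute state[t] = 2 * state[t-1] + X[t]. The plain-Python sketch below (`reference_two_stage_scan` is a hypothetical name, not a TVM API) models each stage as its own intermediate row, with s2 becoming the new state.

```python
def reference_two_stage_scan(X):
    """Hypothetical plain-Python model of the two-stage scan cell above."""
    m, n = len(X), len(X[0])
    state = [list(X[0])]                              # s_init: copy row 0 of X
    for t in range(1, m):
        s1 = [state[t - 1][i] * 2 for i in range(n)]  # stage s1: double the previous state
        s2 = [s1[i] + X[t][i] for i in range(n)]      # stage s2: add the input row
        state.append(s2)                              # s2 is the state at timestep t
    return state


print(reference_two_stage_scan([[1.0], [1.0], [1.0]]))  # [[1.0], [3.0], [7.0]]
```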

These intermediate tensors can also be scheduled normally. To ensure correctness, TVM creates a group constraint that forbids the body of the scan from being compute_at-ed to locations outside the scan loop.

s = te.create_schedule(s_scan.op)
xo, xi = s[s_update_s2].split(s_update_s2.op.axis[1], factor=32)
s[s_update_s1].compute_at(s[s_update_s2], xo)
print(tvm.lower(s, [X, s_scan], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(X: T.handle, scan: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        X_1 = T.match_buffer(X, (m, n), strides=("stride", "stride"), buffer_type="auto")
        scan_1 = T.match_buffer(scan, (m, n), strides=("stride", "stride"), buffer_type="auto")
        s1 = T.allocate([32], "float32", "global")
        scan_2 = T.Buffer((scan_1.strides[0] * m,), data=scan_1.data, buffer_type="auto")
        X_2 = T.Buffer((X_1.strides[0] * m,), data=X_1.data, buffer_type="auto")
        for i in range(n):
            scan_2[i * scan_1.strides[1]] = X_2[i * X_1.strides[1]]
        for scan_idx, i_outer in T.grid(m - 1, (n + 31) // 32):
            s1_1 = T.Buffer((32,), data=s1)
            for i in range(32):
                if T.likely(i_outer * 32 + i < n):
                    s1_1[i] = scan_2[scan_idx * scan_1.strides[0] + (i_outer * 32 + i) * scan_1.strides[1]] * T.float32(2.0)
            for i_inner in range(32):
                if T.likely(i_outer * 32 + i_inner < n):
                    cse_var_2: T.int32 = scan_idx + 1
                    cse_var_1: T.int32 = i_outer * 32 + i_inner
                    scan_2[cse_var_2 * scan_1.strides[0] + cse_var_1 * scan_1.strides[1]] = s1_1[i_inner] + X_2[cse_var_2 * X_1.strides[0] + cse_var_1 * X_1.strides[1]]
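The IR above shows where compute_at placed stage s1: inside the scan_idx loop, recomputed for each 32-wide tile of the column axis just before s2 consumes it. The sketch below is a hypothetical sequential emulation of that loop nest (not TVM code); the per-tile list stands in for the 32-element s1 buffer, and both stages run inside the scan loop, as the group constraint requires.

```python
def emulate_compute_at_scan(X, tile=32):
    """Hypothetical sequential emulation of the lowered two-stage IR (not TVM code)."""
    m, n = len(X), len(X[0])
    scan = [[0.0] * n for _ in range(m)]
    for i in range(n):                       # init stage
        scan[0][i] = X[0][i]
    num_tiles = (n + tile - 1) // tile       # (n + 31) // 32 in the IR
    for scan_idx in range(m - 1):
        for i_outer in range(num_tiles):
            s1 = [0.0] * tile                # stands in for the s1 tile buffer
            for i in range(tile):            # stage s1, computed at i_outer
                if i_outer * tile + i < n:
                    s1[i] = scan[scan_idx][i_outer * tile + i] * 2.0
            for i_inner in range(tile):      # stage s2 reads the freshly computed tile
                col = i_outer * tile + i_inner
                if col < n:
                    scan[scan_idx + 1][col] = s1[i_inner] + X[scan_idx + 1][col]
    return scan


X = [[1.0] * 5 for _ in range(3)]
print(emulate_compute_at_scan(X, tile=2))
# [[1.0, 1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0, 3.0], [7.0, 7.0, 7.0, 7.0, 7.0]]
```

If s1 were instead computed once outside the scan loop, it would read rows of the state that had not been written yet, which is the kind of placement the group constraint rules out.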

Multiple States

For complicated applications like RNNs, we might need more than one recurrent state. Scan supports multiple recurrent states. The following example demonstrates how to build a recurrence with two states.

m = te.var("m")
n = te.var("n")
l = te.var("l")
X = te.placeholder((m, n), name="X")
s_state1 = te.placeholder((m, n))
s_state2 = te.placeholder((m, l))
s_init1 = te.compute((1, n), lambda _, i: X[0, i])
s_init2 = te.compute((1, l), lambda _, i: 0.0)
s_update1 = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + X[t, i])
s_update2 = te.compute((m, l), lambda t, i: s_state2[t - 1, i] + s_state1[t - 1, 0])
s_scan1, s_scan2 = tvm.te.scan(
    [s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X]
)
s = te.create_schedule(s_scan1.op)
print(tvm.lower(s, [X, s_scan1, s_scan2], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(X: T.handle, scan: T.handle, scan_1: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        X_1 = T.match_buffer(X, (m, n), strides=("stride", "stride"), buffer_type="auto")
        scan_2 = T.match_buffer(scan, (m, n), strides=("stride", "stride"), buffer_type="auto")
        l = T.int32()
        scan_3 = T.match_buffer(scan_1, (m, l), strides=("stride", "stride"), buffer_type="auto")
        scan_4 = T.Buffer((scan_2.strides[0] * m,), data=scan_2.data, buffer_type="auto")
        X_2 = T.Buffer((X_1.strides[0] * m,), data=X_1.data, buffer_type="auto")
        for i in range(n):
            scan_4[i * scan_2.strides[1]] = X_2[i * X_1.strides[1]]
        scan_5 = T.Buffer((scan_3.strides[0] * m,), data=scan_3.data, buffer_type="auto")
        for i in range(l):
            scan_5[i * scan_3.strides[1]] = T.float32(0.0)
        for scan_idx in range(m - 1):
            for i in range(n):
                cse_var_1: T.int32 = scan_idx + 1
                scan_4[cse_var_1 * scan_2.strides[0] + i * scan_2.strides[1]] = scan_4[scan_idx * scan_2.strides[0] + i * scan_2.strides[1]] + X_2[cse_var_1 * X_1.strides[0] + i * X_1.strides[1]]
            for i in range(l):
                scan_5[(scan_idx + 1) * scan_3.strides[0] + i * scan_3.strides[1]] = scan_5[scan_idx * scan_3.strides[0] + i * scan_3.strides[1]] + scan_4[scan_idx * scan_2.strides[0]]
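A plain-Python model of the two-state recurrence is sketched below (`reference_two_state_scan` is a hypothetical name, not a TVM API). Both updates read only timestep t - 1, so the order in which the two new rows are computed within a timestep does not affect the result; the IR above happens to compute the first state's row before the second's.

```python
def reference_two_state_scan(X, l):
    """Hypothetical plain-Python model of the two-state scan above."""
    m, n = len(X), len(X[0])
    state1 = [list(X[0])]  # s_init1: row 0 of X
    state2 = [[0.0] * l]   # s_init2: zeros
    for t in range(1, m):
        new1 = [state1[t - 1][i] + X[t][i] for i in range(n)]          # s_update1
        new2 = [state2[t - 1][i] + state1[t - 1][0] for i in range(l)]  # s_update2
        state1.append(new1)
        state2.append(new2)
    return state1, state2


s1, s2 = reference_two_state_scan([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]], l=2)
print(s1)  # [[1.0, 10.0], [3.0, 30.0], [6.0, 60.0]]
print(s2)  # [[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]]
```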

Summary

This tutorial provides a walkthrough of the scan primitive.

  • Describe scan with init and update.

  • Schedule the scan cell like any normal schedule.

  • For complicated workloads, use multiple states and stages in the scan cell.
