.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_work_with_schedules_scan.py:

Scan and Recurrent Kernel
=========================
**Author**: Tianqi Chen

This tutorial introduces recurrent computation in TVM.
Recurrent computation is a typical pattern in neural networks.

.. code-block:: default

    from __future__ import absolute_import, print_function

    import tvm
    import tvm.testing
    from tvm import te
    import numpy as np

TVM supports a scan operator to describe a symbolic loop.
The following scan op computes the cumulative sum down each column of X.

The scan is carried out over the first (outermost) dimension of the tensor.
:code:`s_state` is a placeholder that describes the transition state of the scan.
:code:`s_init` describes how to initialize the first k timesteps.
Here, since the first dimension of :code:`s_init` is 1, it describes how
to initialize the state at the first timestep.

:code:`s_update` describes how to update the value at timestep t. The update
value can refer back to the values of the previous timestep via the state placeholder.
Note that it is invalid to refer to :code:`s_state` at the current or a later timestep.

The scan takes in the state placeholder, the initial value and the update description.
It is also recommended (although not necessary) to list the inputs to the scan cell.
The result of the scan is a tensor that gives the value of :code:`s_state`
after the update over the time domain.

.. code-block:: default

    m = te.var("m")
    n = te.var("n")
    X = te.placeholder((m, n), name="X")
    # Placeholder for the recurrent state; same shape as the scan result.
    s_state = te.placeholder((m, n))
    # Timestep 0 is copied from the first row of X.
    s_init = te.compute((1, n), lambda _, i: X[0, i])
    # Timestep t adds X[t] to the state at timestep t - 1.
    s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
    s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])

Schedule the Scan Cell
----------------------
We can schedule the body of the scan by scheduling the update and
init parts separately.
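
To make the two parts concrete, the init and update steps of the cumulative-sum
scan defined above can be modelled in plain Python. This is only a sketch of the
semantics under the definitions given earlier, not how TVM executes the scan.
The helper name ``scan_cumsum_reference`` is hypothetical and not part of the TVM API.

.. code-block:: python

    def scan_cumsum_reference(x):
        """Plain-Python model of the scan; x is a list of m rows with n values each.

        Hypothetical helper for illustration only, not a TVM function.
        """
        m, n = len(x), len(x[0])
        # Init part: the state at timestep 0 is copied from X[0, :].
        state = [list(x[0])]
        # Update part: timestep t reads only the state at timestep t - 1.
        for t in range(1, m):
            state.append([state[t - 1][i] + x[t][i] for i in range(n)])
        return state


    x = [[1, 2], [3, 4], [5, 6]]
    print(scan_cumsum_reference(x))  # [[1, 2], [4, 6], [9, 12]]

Each row ``t`` depends only on row ``t - 1``, so the timesteps must be produced
in order, while the values within a row are independent of each other.
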
Note that it is invalid to schedule the first iteration dimension
of the update part.
To split the time iteration, schedule :code:`scan_op.scan_axis` instead.

.. code-block:: default

    s = te.create_schedule(s_scan.op)
    num_thread = 256
    block_x = te.thread_axis("blockIdx.x")
    thread_x = te.thread_axis("threadIdx.x")
    xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
    s[s_init].bind(xo, block_x)
    s[s_init].bind(xi, thread_x)
    xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
    s[s_update].bind(xo, block_x)
    s[s_update].bind(xi, thread_x)
    print(tvm.lower(s, [X, s_scan], simple_mode=True))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    primfn(X_1: handle, scan_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {scan: Buffer(scan_2: Pointer(float32), float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
                 X: Buffer(X_2: Pointer(float32), float32, [m, n], [stride_2: int32, stride_3: int32], type="auto")}
      buffer_map = {X_1: X, scan_1: scan} {
      attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 255), 256);
      attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 256;
      if @tir.likely((((blockIdx.x*256) + threadIdx.x) < n), dtype=bool) {
        scan_2[(((blockIdx.x*256) + threadIdx.x)*stride_1)] = (float32*)X_2[(((blockIdx.x*256) + threadIdx.x)*stride_3)]
      }
      for (scan.idx: int32, 0, (m - 1)) {
        attr [IterVar(blockIdx.x, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 255), 256);
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 256;
        if @tir.likely((((blockIdx.x*256) + threadIdx.x) < n), dtype=bool) {
          scan_2[(((scan.idx + 1)*stride) + (((blockIdx.x*256) + threadIdx.x)*stride_1))] = ((float32*)scan_2[((scan.idx*stride) + (((blockIdx.x*256) + threadIdx.x)*stride_1))] + (float32*)X_2[(((scan.idx + 1)*stride_2) + (((blockIdx.x*256) + threadIdx.x)*stride_3))])
        }
      }
    }

Build and Verify
----------------
We can build the scan kernel like any other TVM kernel. Here we use NumPy
to verify the correctness of the result.

.. code-block:: default

    fscan = tvm.build(s, [X, s_scan], "cuda", name="myscan")
    dev = tvm.cuda(0)
    n = 1024
    m = 10
    a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(np.zeros((m, n), dtype=s_scan.dtype), dev)
    fscan(a, b)
    tvm.testing.assert_allclose(b.numpy(), np.cumsum(a_np, axis=0))

Multi-Stage Scan Cell
---------------------
In the example above, we described the scan cell using a single tensor
computation stage, :code:`s_update`. It is also possible to use multiple
tensor stages in the scan cell.

The following lines demonstrate a scan with two stages in the scan cell.

.. code-block:: default

    m = te.var("m")
    n = te.var("n")
    X = te.placeholder((m, n), name="X")
    s_state = te.placeholder((m, n))
    s_init = te.compute((1, n), lambda _, i: X[0, i])
    s_update_s1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] * 2, name="s1")
    s_update_s2 = te.compute((m, n), lambda t, i: s_update_s1[t, i] + X[t, i], name="s2")
    s_scan = tvm.te.scan(s_init, s_update_s2, s_state, inputs=[X])

These intermediate tensors can also be scheduled normally.
To ensure correctness, TVM creates a group constraint that forbids
the body of the scan from being placed, via :code:`compute_at`, at
locations outside the scan loop.

.. code-block:: default

    s = te.create_schedule(s_scan.op)
    xo, xi = s[s_update_s2].split(s_update_s2.op.axis[1], factor=32)
    s[s_update_s1].compute_at(s[s_update_s2], xo)
    print(tvm.lower(s, [X, s_scan], simple_mode=True))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    primfn(X_1: handle, scan_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {scan: Buffer(scan_2: Pointer(float32), float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
                 X: Buffer(X_2: Pointer(float32), float32, [m, n], [stride_2: int32, stride_3: int32], type="auto")}
      buffer_map = {X_1: X, scan_1: scan} {
      allocate(s1: Pointer(global float32), float32, [32]), storage_scope = global {
        for (i: int32, 0, n) {
          scan_2[(i*stride_1)] = (float32*)X_2[(i*stride_3)]
        }
        for (scan.idx: int32, 0, (m - 1)) {
          for (i.outer: int32, 0, floordiv((n + 31), 32)) {
            for (i_1: int32, 0, 32) {
              if @tir.likely((((i.outer*32) + i_1) < n), dtype=bool) {
                s1[i_1] = ((float32*)scan_2[((scan.idx*stride) + (((i.outer*32) + i_1)*stride_1))]*2f32)
              }
            }
            for (i.inner: int32, 0, 32) {
              if @tir.likely((((i.outer*32) + i.inner) < n), dtype=bool) {
                scan_2[(((scan.idx + 1)*stride) + (((i.outer*32) + i.inner)*stride_1))] = ((float32*)s1[i.inner] + (float32*)X_2[(((scan.idx + 1)*stride_2) + (((i.outer*32) + i.inner)*stride_3))])
              }
            }
          }
        }
      }
    }

Multiple States
---------------
For complicated applications such as RNNs, we might need more than one
recurrent state. Scan supports multiple recurrent states.
The following example demonstrates how to build a recurrence with two states.

.. code-block:: default

    m = te.var("m")
    n = te.var("n")
    l = te.var("l")
    X = te.placeholder((m, n), name="X")
    s_state1 = te.placeholder((m, n))
    s_state2 = te.placeholder((m, l))
    s_init1 = te.compute((1, n), lambda _, i: X[0, i])
    s_init2 = te.compute((1, l), lambda _, i: 0.0)
    s_update1 = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + X[t, i])
    s_update2 = te.compute((m, l), lambda t, i: s_state2[t - 1, i] + s_state1[t - 1, 0])
    s_scan1, s_scan2 = tvm.te.scan(
        [s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X]
    )
    s = te.create_schedule(s_scan1.op)
    print(tvm.lower(s, [X, s_scan1, s_scan2], simple_mode=True))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    primfn(X_1: handle, scan_2: handle, scan_3: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {scan_1: Buffer(scan_4: Pointer(float32), float32, [m: int32, l: int32], [stride: int32, stride_1: int32], type="auto"),
                 X: Buffer(X_2: Pointer(float32), float32, [m, n: int32], [stride_2: int32, stride_3: int32], type="auto"),
                 scan: Buffer(scan_5: Pointer(float32), float32, [m, n], [stride_4: int32, stride_5: int32], type="auto")}
      buffer_map = {X_1: X, scan_2: scan, scan_3: scan_1} {
      for (i: int32, 0, n) {
        scan_5[(i*stride_5)] = (float32*)X_2[(i*stride_3)]
      }
      for (i_1: int32, 0, l) {
        scan_4[(i_1*stride_1)] = 0f32
      }
      for (scan.idx: int32, 0, (m - 1)) {
        for (i_2: int32, 0, n) {
          scan_5[(((scan.idx + 1)*stride_4) + (i_2*stride_5))] = ((float32*)scan_5[((scan.idx*stride_4) + (i_2*stride_5))] + (float32*)X_2[(((scan.idx + 1)*stride_2) + (i_2*stride_3))])
        }
        for (i_3: int32, 0, l) {
          scan_4[(((scan.idx + 1)*stride) + (i_3*stride_1))] = ((float32*)scan_4[((scan.idx*stride) + (i_3*stride_1))] + (float32*)scan_5[(scan.idx*stride_4)])
        }
      }
    }

Summary
-------
This tutorial provides a walkthrough of the scan primitive.

- Describe a scan with init and update.
- Schedule the scan cells like a normal schedule.
- For complicated workloads, use multiple states and stages in the scan cell.
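
The two-state recurrence from the Multiple States section can be cross-checked
against a plain-Python model. The sketch below only mirrors the init and update
rules written in that section and is not how TVM evaluates the scan. The helper
name ``two_state_scan_reference`` and its ``l_size`` parameter are hypothetical
and not part of the TVM API.

.. code-block:: python

    def two_state_scan_reference(x, l_size):
        """Plain-Python model of the two-state scan (hypothetical helper).

        x is a list of m rows with n values each; l_size is the width of state 2.
        """
        m, n = len(x), len(x[0])
        # Init parts: state 1 starts from X[0, :], state 2 starts from zeros.
        s1 = [list(x[0])]
        s2 = [[0.0] * l_size]
        for t in range(1, m):
            # State 1 accumulates X, reading only its own previous timestep.
            s1.append([s1[t - 1][i] + x[t][i] for i in range(n)])
            # State 2 adds element 0 of state 1 from the previous timestep.
            s2.append([s2[t - 1][i] + s1[t - 1][0] for i in range(l_size)])
        return s1, s2


    s1, s2 = two_state_scan_reference([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 3)
    print(s1)  # [[1.0, 2.0], [4.0, 6.0], [9.0, 12.0]]
    print(s2)  # [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [5.0, 5.0, 5.0]]

Both updates read only values from timestep ``t - 1``, so the order in which
the two states are updated within a timestep does not change the result.
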