.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/work_with_schedules/scan.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_work_with_schedules_scan.py:


Scan and Recurrent Kernel
=========================
**Author**: Tianqi Chen

This tutorial introduces recurrent computation in TVM. Recurrent
computation is a common pattern in neural networks.

.. GENERATED FROM PYTHON SOURCE LINES 25-33

.. code-block:: default

    from __future__ import absolute_import, print_function

    import tvm
    import tvm.testing
    from tvm import te
    import numpy as np

.. GENERATED FROM PYTHON SOURCE LINES 39-57

TVM supports a scan operator to describe a symbolic loop.
The following scan op computes the cumulative sum over the columns of X.

The scan is carried over the highest (first) dimension of the tensor.
:code:`s_state` is a placeholder that describes the transition state of the scan.
:code:`s_init` describes how to initialize the first k timesteps.
Here, since the first dimension of :code:`s_init` is 1, it describes how to
initialize the state at the first timestep.

:code:`s_update` describes how to update the value at timestep t. The update
value can refer back to the values of the previous timestep via the state placeholder.
Note that it is invalid to refer to :code:`s_state` at the current or a later timestep.

The scan takes the state placeholder, the initial value and the update description.
It is also recommended (although not required) to list the inputs to the scan cell.
The result of the scan is a tensor that holds the value of :code:`s_state` after the
update has been applied over the time domain.

.. GENERATED FROM PYTHON SOURCE LINES 57-65

.. code-block:: default

    m = te.var("m")
    n = te.var("n")
    X = te.placeholder((m, n), name="X")
    # Placeholder for the recurrent state; row t holds the state at timestep t.
    s_state = te.placeholder((m, n))
    # The state at timestep 0 is the first row of X.
    s_init = te.compute((1, n), lambda _, i: X[0, i])
    # The state at timestep t is the previous state plus row t of X.
    s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
    s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])

.. GENERATED FROM PYTHON SOURCE LINES 66-73

Schedule the Scan Cell
----------------------
We can schedule the body of the scan by scheduling the update and
init parts separately. Note that it is invalid to schedule the
first iteration dimension of the update part.
To split on the time iteration, schedule on :code:`scan_op.scan_axis` instead.

.. GENERATED FROM PYTHON SOURCE LINES 73-85

.. code-block:: default

    s = te.create_schedule(s_scan.op)
    num_thread = 256
    block_x = te.thread_axis("blockIdx.x")
    thread_x = te.thread_axis("threadIdx.x")
    # Split and bind only the non-time axis (axis[1]) of the init and update parts.
    xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
    s[s_init].bind(xo, block_x)
    s[s_init].bind(xi, thread_x)
    xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
    s[s_update].bind(xo, block_x)
    s[s_update].bind(xi, thread_x)
    print(tvm.lower(s, [X, s_scan], simple_mode=True))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    @main = primfn(X_1: handle, scan_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {X: Buffer(X_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 scan: Buffer(scan_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
      buffer_map = {X_1: X, scan_1: scan}
      preflattened_buffer_map = {X_1: X_3: Buffer(X_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), scan_1: scan_3: Buffer(scan_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
      attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 255), 256);
      attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 256;
      if @tir.likely((((blockIdx.x*256) + threadIdx.x) < n), dtype=bool) {
        scan[(((blockIdx.x*256) + threadIdx.x)*stride_3)] = X[(((blockIdx.x*256) + threadIdx.x)*stride_2)]
      }
      for (scan.idx: int32, 0, (m - 1)) {
        attr [IterVar(blockIdx.x, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 255), 256);
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 256;
        if @tir.likely((((blockIdx.x*256) + threadIdx.x) < n), dtype=bool) {
          let cse_var_1: int32 = (scan.idx + 1)
          scan[((cse_var_1*stride_1) + (((blockIdx.x*256) + threadIdx.x)*stride_3))] = (scan[((scan.idx*stride_1) + (((blockIdx.x*256) + threadIdx.x)*stride_3))] + X[((cse_var_1*stride) + (((blockIdx.x*256) + threadIdx.x)*stride_2))])
        }
      }
    }

.. GENERATED FROM PYTHON SOURCE LINES 86-91

Build and Verify
----------------
We can build the scan kernel like any other TVM kernel. Here we use
numpy to verify the correctness of the result.

.. GENERATED FROM PYTHON SOURCE LINES 91-101

.. code-block:: default

    fscan = tvm.build(s, [X, s_scan], "cuda", name="myscan")
    dev = tvm.cuda(0)
    # Concrete sizes for the symbolic shapes used above.
    n = 1024
    m = 10
    a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(np.zeros((m, n), dtype=s_scan.dtype), dev)
    fscan(a, b)
    # The scan over the first dimension should match numpy's cumsum along axis 0.
    tvm.testing.assert_allclose(b.numpy(), np.cumsum(a_np, axis=0))

.. GENERATED FROM PYTHON SOURCE LINES 102-111

Multi-Stage Scan Cell
---------------------
In the example above, the scan cell is described by a single tensor
computation stage, :code:`s_update`. It is also possible to use multiple
tensor stages in the scan cell.

The following lines demonstrate a scan whose cell consists of two stages.

.. GENERATED FROM PYTHON SOURCE LINES 111-120

.. code-block:: default

    m = te.var("m")
    n = te.var("n")
    X = te.placeholder((m, n), name="X")
    s_state = te.placeholder((m, n))
    s_init = te.compute((1, n), lambda _, i: X[0, i])
    # Stage 1 doubles the previous state; stage 2 adds the current input row.
    s_update_s1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] * 2, name="s1")
    s_update_s2 = te.compute((m, n), lambda t, i: s_update_s1[t, i] + X[t, i], name="s2")
    s_scan = tvm.te.scan(s_init, s_update_s2, s_state, inputs=[X])

.. GENERATED FROM PYTHON SOURCE LINES 121-125

These intermediate tensors can also be scheduled normally.
To ensure correctness, TVM creates a group constraint that forbids
the body of the scan from being moved by :code:`compute_at` to locations
outside the scan loop.

.. GENERATED FROM PYTHON SOURCE LINES 125-130

.. code-block:: default

    s = te.create_schedule(s_scan.op)
    xo, xi = s[s_update_s2].split(s_update_s2.op.axis[1], factor=32)
    # Compute stage 1 inside the outer loop of stage 2, which lies within the scan loop.
    s[s_update_s1].compute_at(s[s_update_s2], xo)
    print(tvm.lower(s, [X, s_scan], simple_mode=True))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    @main = primfn(X_1: handle, scan_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {X: Buffer(X_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 scan: Buffer(scan_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
      buffer_map = {X_1: X, scan_1: scan}
      preflattened_buffer_map = {X_1: X_3: Buffer(X_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), scan_1: scan_3: Buffer(scan_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
      allocate(s1: Pointer(global float32), float32, [32]), storage_scope = global {
        for (i: int32, 0, n) {
          scan[(i*stride_3)] = X[(i*stride_2)]
        }
        for (scan.idx: int32, 0, (m - 1)) {
          for (i.outer: int32, 0, floordiv((n + 31), 32)) {
            for (i_1: int32, 0, 32) {
              if @tir.likely((((i.outer*32) + i_1) < n), dtype=bool) {
                s1_1: Buffer(s1, float32, [32], [])[i_1] = (scan[((scan.idx*stride_1) + (((i.outer*32) + i_1)*stride_3))]*2f32)
              }
            }
            for (i.inner: int32, 0, 32) {
              if @tir.likely((((i.outer*32) + i.inner) < n), dtype=bool) {
                let cse_var_2: int32 = (scan.idx + 1)
                let cse_var_1: int32 = ((i.outer*32) + i.inner)
                scan[((cse_var_2*stride_1) + (cse_var_1*stride_3))] = (s1_1[i.inner] + X[((cse_var_2*stride) + (cse_var_1*stride_2))])
              }
            }
          }
        }
      }
    }

.. GENERATED FROM PYTHON SOURCE LINES 131-137

Multiple States
---------------
Complicated applications such as RNNs might need more than one
recurrent state. Scan supports multiple recurrent states.
The following example demonstrates how to build a recurrence with two states.

.. GENERATED FROM PYTHON SOURCE LINES 137-153

.. code-block:: default

    m = te.var("m")
    n = te.var("n")
    l = te.var("l")
    X = te.placeholder((m, n), name="X")
    s_state1 = te.placeholder((m, n))
    s_state2 = te.placeholder((m, l))
    s_init1 = te.compute((1, n), lambda _, i: X[0, i])
    s_init2 = te.compute((1, l), lambda _, i: 0.0)
    s_update1 = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + X[t, i])
    # The second state also reads the first state at the previous timestep.
    s_update2 = te.compute((m, l), lambda t, i: s_state2[t - 1, i] + s_state1[t - 1, 0])
    s_scan1, s_scan2 = tvm.te.scan(
        [s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X]
    )
    s = te.create_schedule(s_scan1.op)
    print(tvm.lower(s, [X, s_scan1, s_scan2], simple_mode=True))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    @main = primfn(X_1: handle, scan_2: handle, scan_3: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {X: Buffer(X_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 scan: Buffer(scan_4: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
                 scan_1: Buffer(scan_5: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
      buffer_map = {X_1: X, scan_2: scan, scan_3: scan_1}
      preflattened_buffer_map = {X_1: X_3: Buffer(X_2, float32, [m, n: int32], [stride, stride_3: int32], type="auto"), scan_2: scan_6: Buffer(scan_4, float32, [m, n], [stride_1, stride_4: int32], type="auto"), scan_3: scan_7: Buffer(scan_5, float32, [m, l: int32], [stride_2, stride_5: int32], type="auto")} {
      for (i: int32, 0, n) {
        scan[(i*stride_4)] = X[(i*stride_3)]
      }
      for (i_1: int32, 0, l) {
        scan_1[(i_1*stride_5)] = 0f32
      }
      for (scan.idx: int32, 0, (m - 1)) {
        for (i_2: int32, 0, n) {
          let cse_var_1: int32 = (scan.idx + 1)
          scan[((cse_var_1*stride_1) + (i_2*stride_4))] = (scan[((scan.idx*stride_1) + (i_2*stride_4))] + X[((cse_var_1*stride) + (i_2*stride_3))])
        }
        for (i_3: int32, 0, l) {
          scan_1[(((scan.idx + 1)*stride_2) + (i_3*stride_5))] = (scan_1[((scan.idx*stride_2) + (i_3*stride_5))] + scan[(scan.idx*stride_1)])
        }
      }
    }

.. GENERATED FROM PYTHON SOURCE LINES 154-161

Summary
-------
This tutorial provides a walkthrough of the scan primitive.

- Describe a scan with init and update.
- Schedule the scan cells like a normal schedule.
- For complicated workloads, use multiple states and stages in the scan cell.
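The recurrences described in this tutorial can also be written as ordinary sequential loops, which gives a reference to compare against when no GPU is available. The following is a minimal sketch in plain Python, not TVM code: ``reference_scan`` and ``two_stage`` are hypothetical helper names introduced here for illustration, and the input values are made up. It mirrors the single-stage cumulative-sum cell and the two-stage cell (double the previous state, then add the current input row).

```python
# Plain-Python reference for the scan recurrences (no TVM required).
# `reference_scan` is a hypothetical helper, not part of the TVM API.


def reference_scan(init_row, update, rows):
    """Return every state of a scan over the first (time) dimension.

    states[0] is init_row and states[t] = update(states[t - 1], rows[t])
    for t >= 1, so the update only ever reads the previous timestep.
    """
    states = [list(init_row)]
    for t in range(1, len(rows)):
        states.append(update(states[t - 1], rows[t]))
    return states


# Illustrative input with m = 3 timesteps and n = 2 columns.
X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# Single-stage cell: state[t, i] = state[t - 1, i] + X[t, i] (cumulative sum).
cumsum = reference_scan(X[0], lambda prev, x: [p + v for p, v in zip(prev, x)], X)


def two_stage(prev, x):
    """Two-stage cell: s1 = prev * 2, then s2 = s1 + x."""
    s1 = [p * 2 for p in prev]
    return [a + v for a, v in zip(s1, x)]


doubled = reference_scan(X[0], two_stage, X)

print(cumsum)   # [[1.0, 2.0], [4.0, 6.0], [9.0, 12.0]]
print(doubled)  # [[1.0, 2.0], [5.0, 8.0], [15.0, 22.0]]
```

Because the update function receives only the previous state, this sketch makes the rule stated earlier concrete: a scan cell may read the state at timestep t - 1, but never at the current or a later timestep.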