.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/work_with_schedules/reduction.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_work_with_schedules_reduction.py:


Reduction
=========
**Author**: Tianqi Chen

This is introductory material on how to do reduction in TVM.
Associative reduction operators like sum/max/min are typical
building blocks of linear algebra operations.

In this tutorial, we will demonstrate how to do reduction in TVM.

.. code-block:: default


    from __future__ import absolute_import, print_function

    import tvm
    import tvm.testing
    from tvm import te
    import numpy as np


Describe Sum of Rows
--------------------
Assume we want to compute the sum of rows as our example.
In numpy semantics this can be written as :code:`B = numpy.sum(A, axis=1)`.

The following lines describe the row sum operation.
To create a reduction formula, we declare a reduction axis using
:any:`te.reduce_axis`. :any:`te.reduce_axis` takes the range of the reduction.
:any:`te.sum` takes the expression to be reduced as well as the reduction
axis, and computes the sum of the values over all k in the declared range.

The equivalent C code is as follows:

.. code-block:: c

    for (int i = 0; i < n; ++i) {
      B[i] = 0;
      for (int k = 0; k < m; ++k) {
        B[i] = B[i] + A[i][k];
      }
    }

.. code-block:: default

    n = te.var("n")
    m = te.var("m")
    A = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")


Schedule the Reduction
----------------------
There are several ways to schedule a reduction.
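
These rewrites are possible because addition is associative: the loop over k can be restructured without changing what each B[i] ends up holding. As a plain-Python illustration (not TVM code; ``row_sum`` and ``row_sum_split_k`` are hypothetical helper names), the sketch below splits k into an outer and an inner loop with a factor of 16, including the bounds check that is needed when m is not a multiple of 16, and compares the result with the unsplit sum.

.. code-block:: python

    def row_sum(a):
        """Reference semantics of the row sum: B[i] = sum of a[i][k] over all k."""
        return [sum(row) for row in a]


    def row_sum_split_k(a, factor=16):
        """Row sum with the reduction loop k split into k_outer and k_inner.

        k = k_outer * factor + k_inner. The guard skips the padded tail when m
        is not a multiple of factor (the lowered IR marks it with tir.likely).
        """
        n, m = len(a), len(a[0])
        b = [0] * n
        for i in range(n):
            b[i] = 0
            for k_outer in range((m + factor - 1) // factor):  # ceil(m / factor)
                for k_inner in range(factor):
                    k = k_outer * factor + k_inner
                    if k < m:  # bounds check for the last, partial chunk
                        b[i] = b[i] + a[i][k]
        return b


    # Integer data keeps the comparison exact; m = 37 is not a multiple of 16.
    data = [[(3 * i + k) % 7 for k in range(37)] for i in range(5)]
    print(row_sum_split_k(data) == row_sum(data))  # True

This split keeps the original order of additions. Schedules that reorder additions, such as the factored and cross-thread versions in this tutorial, can change floating-point results in the last bits, which is presumably why the final correctness check uses a relative tolerance.
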

Before doing anything, let us print out the IR code of the default schedule.

.. code-block:: default

    s = te.create_schedule(B.op)
    print(tvm.lower(s, [A, B], simple_mode=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
      for (i: int32, 0, n) {
        B[(i*stride_1)] = 0f32
        for (k: int32, 0, m) {
          B[(i*stride_1)] = (B[(i*stride_1)] + A[((i*stride) + (k*stride_2))])
        }
      }
    }

You can see that the IR code is quite similar to the C code.
The reduction axis is similar to a normal axis; it can be split.

In the following code, we split both the row axis of B and the
reduction axis by different factors. The result is a nested reduction.

.. code-block:: default

    ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
    xo, xi = s[B].split(B.op.axis[0], factor=32)
    print(tvm.lower(s, [A, B], simple_mode=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
      for (i.outer: int32, 0, floordiv((n + 31), 32)) {
        for (i.inner: int32, 0, 32) {
          if @tir.likely((((i.outer*32) + i.inner) < n), dtype=bool) {
            B[(((i.outer*32) + i.inner)*stride_1)] = 0f32
          }
          if @tir.likely((((i.outer*32) + i.inner) < n), dtype=bool) {
            for (k.outer: int32, 0, floordiv((m + 15), 16)) {
              for (k.inner: int32, 0, 16) {
                if @tir.likely((((k.outer*16) + k.inner) < m), dtype=bool) {
                  let cse_var_1: int32 = ((i.outer*32) + i.inner)
                  B[(cse_var_1*stride_1)] = (B[(cse_var_1*stride_1)] + A[((cse_var_1*stride) + (((k.outer*16) + k.inner)*stride_2))])
                }
              }
            }
          }
        }
      }
    }

If we are building a GPU kernel, we can bind the rows of B to GPU threads.

.. code-block:: default

    s[B].bind(xo, te.thread_axis("blockIdx.x"))
    s[B].bind(xi, te.thread_axis("threadIdx.x"))
    print(tvm.lower(s, [A, B], simple_mode=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
      attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 31), 32);
      attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
        if @tir.likely((((blockIdx.x*32) + threadIdx.x) < n), dtype=bool) {
          B[(((blockIdx.x*32) + threadIdx.x)*stride_1)] = 0f32
        }
        for (k.outer: int32, 0, floordiv((m + 15), 16)) {
          for (k.inner: int32, 0, 16) {
            if @tir.likely((((blockIdx.x*32) + threadIdx.x) < n), dtype=bool) {
              if @tir.likely((((k.outer*16) + k.inner) < m), dtype=bool) {
                B[(((blockIdx.x*32) + threadIdx.x)*stride_1)] = (B[(((blockIdx.x*32) + threadIdx.x)*stride_1)] + A[((((blockIdx.x*32) + threadIdx.x)*stride) + (((k.outer*16) + k.inner)*stride_2))])
              }
            }
          }
        }
      }
    }


Reduction Factoring and Parallelization
---------------------------------------
One problem with building a reduction is that we cannot simply
parallelize over the reduction axis. We need to divide the computation
of the reduction and store the local reduction results in a temporary array
before doing a reduction over that array.

The rfactor primitive performs such a rewrite of the computation.
In the following schedule, the result of B is written to a temporary
result B.rf. The factored dimension becomes the first dimension of B.rf.
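
To make the rewrite concrete, here is a plain-Python sketch of the two stages (an illustration only, not how TVM implements rfactor; ``row_sum_rfactor`` is a hypothetical name). Stage one fills a 16 x n temporary ``b_rf`` in which ``b_rf[k_inner][i]`` accumulates the elements of row i whose column index is congruent to ``k_inner`` modulo 16; stage two sums ``b_rf`` over its first axis. The loop order follows the lowered code printed for the rfactor schedule.

.. code-block:: python

    def row_sum_rfactor(a, factor=16):
        """Two-stage row sum mirroring s.rfactor(B, ki) after splitting k by factor."""
        n, m = len(a), len(a[0])

        # Stage 1 (B.rf): one partial sum per (k_inner, i). The factored axis
        # k_inner is the first dimension, and the partial sums do not depend on
        # each other, so this stage could run in parallel over k_inner.
        b_rf = [[0] * n for _ in range(factor)]
        for k_inner in range(factor):
            for i in range(n):
                for k_outer in range((m + factor - 1) // factor):
                    k = k_outer * factor + k_inner
                    if k < m:
                        b_rf[k_inner][i] += a[i][k]

        # Stage 2 (B): reduce the temporary array over its first axis.
        b = [0] * n
        for i in range(n):
            for k_inner in range(factor):
                b[i] += b_rf[k_inner][i]
        return b_rf, b


    data = [[i * 100 + k for k in range(40)] for i in range(3)]
    b_rf, b = row_sum_rfactor(data)
    print(len(b_rf), len(b_rf[0]))          # 16 3
    print(b == [sum(row) for row in data])  # True

The second stage only has to combine 16 values per row, which is what makes it cheap to finish the reduction after the parallel first stage.
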

.. code-block:: default

    s = te.create_schedule(B.op)
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
    BF = s.rfactor(B, ki)
    print(tvm.lower(s, [A, B], simple_mode=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
      allocate(B.rf: Pointer(global float32), float32, [(n*16)]), storage_scope = global {
        for (k.inner: int32, 0, 16) {
          for (i: int32, 0, n) {
            B.rf_1: Buffer(B.rf, float32, [(16*n)], [])[((k.inner*n) + i)] = 0f32
            for (k.outer: int32, 0, floordiv((m + 15), 16)) {
              if @tir.likely((((k.outer*16) + k.inner) < m), dtype=bool) {
                B.rf_1[((k.inner*n) + i)] = (B.rf_1[((k.inner*n) + i)] + A[((i*stride) + (((k.outer*16) + k.inner)*stride_2))])
              }
            }
          }
        }
        for (ax0: int32, 0, n) {
          B[(ax0*stride_1)] = 0f32
          for (k.inner.v: int32, 0, 16) {
            B[(ax0*stride_1)] = (B[(ax0*stride_1)] + B.rf_1[((k.inner.v*n) + ax0)])
          }
        }
      }
    }

The scheduled operator of B is also rewritten to be a sum over
the first axis of the reduced result in B.rf.

.. code-block:: default

    print(s[B].op.body)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]), source=[B.rf[k.inner.v, ax0]], init=[], axis=[iter_var(k.inner.v, range(min=0, ext=16))], where=(bool)1, value_index=0)]


Cross Thread Reduction
----------------------
We can now parallelize over the factored axis.
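
Once the factored axis is spread over threads, each of the 16 threads that share a row holds one partial sum, and those partial sums still have to be combined. A common GPU technique is a tree reduction over warp shuffles, which needs log2(16) = 4 steps. The plain-Python simulation below is a simplified illustration (``shuffle_down_tree_sum`` is a hypothetical helper): it models one group of 16 lanes in isolation, whereas the generated CUDA code of this section shuffles within a 32-lane warp and masks each 16-lane row. List index t stands for lane t, and for offsets 8, 4, 2 and 1 every lane adds the value held ``offset`` lanes above it, the same pattern as the ``__shfl_down_sync`` calls in that kernel.

.. code-block:: python

    def shuffle_down_tree_sum(lane_values):
        """Simulate a shuffle-down tree reduction over one group of lanes.

        lane_values[t] is the partial sum held by lane t; the lane count must be a
        power of two. In each step, every lane reads the value held `offset` lanes
        above it (a lane whose source is out of range reads its own value, as
        __shfl_down_sync does) and adds it to its own. Only lane 0 is guaranteed
        to end up with the complete sum.
        """
        width = len(lane_values)
        assert width > 0 and width & (width - 1) == 0, "lane count must be a power of two"
        red = list(lane_values)
        offset = width // 2
        while offset >= 1:
            # Every lane reads before any lane writes, like one warp-wide shuffle.
            shuffled = [red[t + offset] if t + offset < width else red[t] for t in range(width)]
            red = [red[t] + shuffled[t] for t in range(width)]
            offset //= 2
        return red[0]


    partials = list(range(16))  # stand-in for the 16 partial sums of one row
    print(shuffle_down_tree_sum(partials))  # 120

Because only lane 0 holds the full result, the kernel lets just the thread with threadIdx.x == 0 write to B.
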

In the following schedule, the reduction axis of B is bound to a thread.
TVM allows a reduction axis to be bound to a thread if it is the only
axis in the reduction and cross-thread reduction is supported by the device.
This is indeed the case after the factoring.
We can also compute BF directly at the reduction axis.
The store predicate ``tx.var.equal(0)`` makes only the thread with
threadIdx.x == 0 write the final result to B.
The final generated kernel divides the rows by blockIdx.x and threadIdx.y,
divides the columns by threadIdx.x, and finally performs a cross-thread
reduction over threadIdx.x.

.. code-block:: default

    xo, xi = s[B].split(s[B].op.axis[0], factor=32)
    # Rows: the outer part goes to blocks, the inner part to threadIdx.y.
    s[B].bind(xo, te.thread_axis("blockIdx.x"))
    s[B].bind(xi, te.thread_axis("threadIdx.y"))
    # The single remaining (factored) reduction axis goes to threadIdx.x.
    tx = te.thread_axis("threadIdx.x")
    s[B].bind(s[B].op.reduce_axis[0], tx)
    # Compute the partial sums of BF at the reduction axis of B.
    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
    # Only the thread with threadIdx.x == 0 stores the result.
    s[B].set_store_predicate(tx.var.equal(0))
    fcuda = tvm.build(s, [A, B], "cuda")
    print(fcuda.imported_modules[0].get_source())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
    #define __shfl_sync(mask, var, lane, width) \
            __shfl((var), (lane), (width))

    #define __shfl_down_sync(mask, var, offset, width) \
            __shfl_down((var), (offset), (width))

    #define __shfl_up_sync(mask, var, offset, width) \
            __shfl_up((var), (offset), (width))
    #endif


    #ifdef _WIN32
      using uint = unsigned int;
      using uchar = unsigned char;
      using ushort = unsigned short;
      using int64_t = long long;
      using uint64_t = unsigned long long;
    #else
      #define uint unsigned int
      #define uchar unsigned char
      #define ushort unsigned short
      #define int64_t long long
      #define uint64_t unsigned long long
    #endif
    extern "C" __global__ void __launch_bounds__(512) default_function_kernel0(float* __restrict__ A, float* __restrict__ B, int m, int n, int stride, int stride_1, int stride_2) {
      float B_rf[1];
      float red_buf0[1];
      B_rf[0] = 0.000000e+00f;
      for (int k_outer = 0; k_outer < (m >> 4); ++k_outer) {
        if (((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) < n) {
          B_rf[0] = (B_rf[0] + A[((((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) * stride) + (((k_outer * 16) + ((int)threadIdx.x)) * stride_1))]);
        }
      }
      for (int k_outer_1 = 0; k_outer_1 < (((m & 15) + 15) >> 4); ++k_outer_1) {
        if (((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) < n) {
          if (((((m >> 4) * 16) + (k_outer_1 * 16)) + ((int)threadIdx.x)) < m) {
            B_rf[0] = (B_rf[0] + A[((((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) * stride) + (((((m >> 4) * 16) + (k_outer_1 * 16)) + ((int)threadIdx.x)) * stride_1))]);
          }
        }
      }
      uint mask[1];
      float t0[1];
      red_buf0[0] = B_rf[0];
      mask[0] = (__activemask() & ((uint)(65535 << (((int)threadIdx.y) * 16))));
      t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 8, 32);
      red_buf0[0] = (red_buf0[0] + t0[0]);
      t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 4, 32);
      red_buf0[0] = (red_buf0[0] + t0[0]);
      t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 2, 32);
      red_buf0[0] = (red_buf0[0] + t0[0]);
      t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
      red_buf0[0] = (red_buf0[0] + t0[0]);
      red_buf0[0] = __shfl_sync(mask[0], red_buf0[0], (((int)threadIdx.y) * 16), 32);
      if (((int)threadIdx.x) == 0) {
        B[(((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) * stride_2)] = red_buf0[0];
      }
    }

Verify the correctness of the resulting kernel by comparing it with numpy.

.. code-block:: default

    nn = 128
    dev = tvm.cuda(0)  # requires a CUDA-capable GPU
    a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), dev)
    b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), dev)
    fcuda(a, b)
    tvm.testing.assert_allclose(b.numpy(), np.sum(a.numpy(), axis=1), rtol=1e-4)


Describe Convolution via 2D Reduction
-------------------------------------
In TVM, we can describe convolution via 2D reduction in a simple way.
Here is an example of a 2D convolution with a 3 x 3 filter and a stride
of 1 in both dimensions, so an n x n input yields an (n - 2) x (n - 2) output.
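
The compute rule below reads as a double reduction: each output element sums the products over both filter offsets di and dj. As a reference for those semantics, here is a plain-Python sketch (not TVM code; ``conv2d_valid`` is a hypothetical name) of the same stride-1, unpadded convolution, with di and dj as the two reduction loops.

.. code-block:: python

    def conv2d_valid(inp, filt):
        """Output[i][j] = sum over (di, dj) of inp[i + di][j + dj] * filt[di][dj].

        Stride 1 and no padding: an n x n input and an r x r filter give an
        (n - r + 1) x (n - r + 1) output, i.e. (n - 2) x (n - 2) for r = 3.
        The di and dj loops play the role of the two reduce axes.
        """
        n = len(inp)
        r = len(filt)  # filter assumed square, 3 in the example
        out_size = n - r + 1
        out = [[0] * out_size for _ in range(out_size)]
        for i in range(out_size):
            for j in range(out_size):
                acc = 0
                for di in range(r):  # first reduction axis
                    for dj in range(r):  # second reduction axis
                        acc += inp[i + di][j + dj] * filt[di][dj]
                out[i][j] = acc
        return out


    image = [[4 * i + j for j in range(4)] for i in range(4)]  # 4 x 4 input, values 0..15
    box = [[1] * 3 for _ in range(3)]  # 3 x 3 filter of ones
    print(conv2d_valid(image, box))  # [[45, 54], [81, 90]]
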

.. code-block:: default

    n = te.var("n")
    Input = te.placeholder((n, n), name="Input")
    Filter = te.placeholder((3, 3), name="Filter")
    di = te.reduce_axis((0, 3), name="di")
    dj = te.reduce_axis((0, 3), name="dj")
    Output = te.compute(
        (n - 2, n - 2),
        lambda i, j: te.sum(Input[i + di, j + dj] * Filter[di, dj], axis=[di, dj]),
        name="Output",
    )
    s = te.create_schedule(Output.op)
    print(tvm.lower(s, [Input, Filter, Output], simple_mode=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(Input_1: handle, Filter_1: handle, Output_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {Input: Buffer(Input_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
                 Filter: Buffer(Filter_2: Pointer(float32), float32, [9], []),
                 Output: Buffer(Output_2: Pointer(float32), float32, [((n - 2)*(n - 2))], [])}
      buffer_map = {Input_1: Input, Filter_1: Filter, Output_1: Output}
      preflattened_buffer_map = {Input_1: Input_3: Buffer(Input_2, float32, [n, n], [stride, stride_1: int32], type="auto"), Filter_1: Filter_3: Buffer(Filter_2, float32, [3, 3], []), Output_1: Output_3: Buffer(Output_2, float32, [(n - 2), (n - 2)], [])} {
      for (i: int32, 0, (n - 2)) {
        for (j: int32, 0, (n - 2)) {
          Output[((i*(n - 2)) + j)] = 0f32
          for (di: int32, 0, 3) {
            for (dj: int32, 0, 3) {
              Output[((i*(n - 2)) + j)] = (Output[((i*(n - 2)) + j)] + (Input[(((i + di)*stride) + ((j + dj)*stride_1))]*Filter[((di*3) + dj)]))
            }
          }
        }
      }
    }

.. _general-reduction:

Define General Commutative Reduction Operation
----------------------------------------------
Besides the built-in reduction operations like :any:`te.sum`,
:any:`tvm.te.min` and :any:`tvm.te.max`, you can also define your own
commutative reduction operation with :any:`te.comm_reducer`.
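
A commutative reducer is described by two pieces, matching the two arguments of the ``te.comm_reducer`` call that follows: a combiner (here ``lambda x, y: x * y``) and a function giving the identity element (here a constant 1). The plain-Python model below is a conceptual illustration only (``make_comm_reducer`` is a hypothetical helper, not a TVM function): it folds the combiner over the reduction axis, starting from the identity.

.. code-block:: python

    from functools import reduce


    def make_comm_reducer(fcombine, identity):
        """Hypothetical stand-in for a commutative reducer (not the TVM API).

        fcombine must be associative and commutative, and identity must satisfy
        fcombine(identity, x) == x; together they let the fold start from nothing
        and be split or reordered freely.
        """

        def reducer(values):
            # Fold the combiner over the reduction axis, starting from the identity.
            return reduce(fcombine, values, identity)

        return reducer


    # Analogue of the `product` reducer: combiner x * y, identity 1.
    product = make_comm_reducer(lambda x, y: x * y, 1)
    # The same recipe gives max: combiner max(x, y), identity -infinity.
    maximum = make_comm_reducer(max, float("-inf"))

    data = [[1, 2, 3], [4, 5, 6]]
    print([product(row) for row in data])  # [6, 120]
    print([maximum(row) for row in data])  # [3, 6]
    print(product([]))                     # 1 (an empty reduction yields the identity)

The identity is what a reduction starts from before any element is combined, which is why it must be neutral for the combiner.
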

.. code-block:: default

    n = te.var("n")
    m = te.var("m")
    product = te.comm_reducer(
        lambda x, y: x * y, lambda t: tvm.tir.const(1, dtype=t), name="product"
    )
    A = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), name="k")
    B = te.compute((n,), lambda i: product(A[i, k], axis=k), name="B")

.. note::

  Sometimes we would like to perform a reduction that involves multiple
  values, such as :code:`argmax`. This can be done with tuple inputs.
  See :ref:`reduction-with-tuple-inputs` for more detail.


Summary
-------
This tutorial provides a walkthrough of reduction schedules.

- Describe a reduction with :any:`te.reduce_axis`.
- Use rfactor to factor out an axis if we need parallelism.
- Define a new reduction operation with :any:`te.comm_reducer`.