{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%%shell\n# Installs the latest dev build of TVM from PyPI, with CUDA enabled. To use this,\n# you must request a Google Colab instance with a GPU by going to Runtime ->\n# Change runtime type -> Hardware accelerator -> GPU. If you wish to build from\n# source, see https://tvm.apache.org/docs/install/from_source.html\npip install tlcpack-nightly-cu113 --pre -f https://tlcpack.ai/wheels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Reduction\n**Author**: [Tianqi Chen](https://tqchen.github.io)\n\nThis is an introduction to reduction in TVM.\nAssociative reduction operators like sum/max/min are typical\nbuilding blocks of linear algebra operations.\n\nIn this tutorial, we will demonstrate how to describe and schedule reductions in TVM.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from __future__ import absolute_import, print_function\n\n\nimport tvm\nimport tvm.testing\nfrom tvm import te\nimport numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Describe Sum of Rows\nAssume we want to compute the sum of rows as our example.\nIn NumPy semantics, this can be written as :code:`B = numpy.sum(A, axis=1)`.\n\nThe following lines describe the row sum operation.\nTo create a reduction formula, we declare a reduction axis using\n:any:`te.reduce_axis`. 
:any:`te.reduce_axis` takes the range of the reduction.\n:any:`te.sum` takes the expression to be reduced as well as the reduction\naxis and computes the sum of the values over all k in the declared range.\n\nThe equivalent C code is as follows:\n\n```c\nfor (int i = 0; i < n; ++i) {\n  B[i] = 0;\n  for (int k = 0; k < m; ++k) {\n    B[i] = B[i] + A[i][k];\n  }\n}\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n = te.var(\"n\")\nm = te.var(\"m\")\nA = te.placeholder((n, m), name=\"A\")\nk = te.reduce_axis((0, m), \"k\")\nB = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name=\"B\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Schedule the Reduction\nThere are several ways to schedule a reduction.\nBefore doing anything, let us print out the IR code of the default schedule.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "s = te.create_schedule(B.op)\nprint(tvm.lower(s, [A, B], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The IR code closely resembles the C code.\nThe reduction axis is similar to a normal axis: it can be split.\n\nIn the following code we split both the row axis of B and the\nreduction axis by different factors. 
The result is a nested reduction.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)\nxo, xi = s[B].split(B.op.axis[0], factor=32)\nprint(tvm.lower(s, [A, B], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we are building a GPU kernel, we can bind the rows of B to GPU threads.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "s[B].bind(xo, te.thread_axis(\"blockIdx.x\"))\ns[B].bind(xi, te.thread_axis(\"threadIdx.x\"))\nprint(tvm.lower(s, [A, B], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reduction Factoring and Parallelization\nOne problem with building a reduction is that we cannot simply\nparallelize over the reduction axis. We need to divide the computation\nof the reduction and store the local reduction results in a temporary array\nbefore doing a reduction over that array.\n\nThe rfactor primitive performs this rewrite of the computation.\nIn the following schedule, the result of B is written to a temporary\nresult B.rf. 
The factored dimension becomes the first dimension of B.rf.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "s = te.create_schedule(B.op)\nko, ki = s[B].split(B.op.reduce_axis[0], factor=16)\nBF = s.rfactor(B, ki)\nprint(tvm.lower(s, [A, B], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The scheduled operator of B also gets rewritten to be a sum over\nthe first axis of the reduced result B.rf.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(s[B].op.body)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cross Thread Reduction\nWe can now parallelize over the factored axis.\nHere the reduction axis of B is marked to be a thread.\nTVM allows a reduction axis to be marked as a thread if it is the only\naxis in the reduction and cross thread reduction is possible on the device.\n\nThis is indeed the case after the factoring.\nWe can directly compute BF at the reduction axis as well.\nThe final generated kernel will divide the rows by blockIdx.x and threadIdx.y,\nthe columns by threadIdx.x, and finally do a cross thread reduction over threadIdx.x.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "xo, xi = s[B].split(s[B].op.axis[0], factor=32)\ns[B].bind(xo, te.thread_axis(\"blockIdx.x\"))\ns[B].bind(xi, te.thread_axis(\"threadIdx.y\"))\ntx = te.thread_axis(\"threadIdx.x\")\ns[B].bind(s[B].op.reduce_axis[0], tx)\ns[BF].compute_at(s[B], s[B].op.reduce_axis[0])\ns[B].set_store_predicate(tx.var.equal(0))\nfcuda = tvm.build(s, [A, B], \"cuda\")\nprint(fcuda.imported_modules[0].get_source())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Verify the correctness of the resulting kernel by comparing it to numpy.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, 
"outputs": [], "source": [ "nn = 128\ndev = tvm.cuda(0)\na = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), dev)\nb = tvm.nd.array(np.zeros(nn, dtype=B.dtype), dev)\nfcuda(a, b)\ntvm.testing.assert_allclose(b.numpy(), np.sum(a.numpy(), axis=1), rtol=1e-4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Describe Convolution via 2D Reduction\nIn TVM, we can describe convolution via 2D reduction in a simple way.\nHere is an example for 2D convolution with filter size = [3, 3] and strides = [1, 1].\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n = te.var(\"n\")\nInput = te.placeholder((n, n), name=\"Input\")\nFilter = te.placeholder((3, 3), name=\"Filter\")\ndi = te.reduce_axis((0, 3), name=\"di\")\ndj = te.reduce_axis((0, 3), name=\"dj\")\nOutput = te.compute(\n (n - 2, n - 2),\n lambda i, j: te.sum(Input[i + di, j + dj] * Filter[di, dj], axis=[di, dj]),\n name=\"Output\",\n)\ns = te.create_schedule(Output.op)\nprint(tvm.lower(s, [Input, Filter, Output], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n## Define General Commutative Reduction Operation\nBesides the built-in reduction operations like :any:`te.sum`,\n:any:`tvm.te.min` and :any:`tvm.te.max`, you can also define your\ncommutative reduction operation by :any:`te.comm_reducer`.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n = te.var(\"n\")\nm = te.var(\"m\")\nproduct = te.comm_reducer(lambda x, y: x * y, lambda t: tvm.tir.const(1, dtype=t), name=\"product\")\nA = te.placeholder((n, m), name=\"A\")\nk = te.reduce_axis((0, m), name=\"k\")\nB = te.compute((n,), lambda i: product(A[i, k], axis=k), name=\"B\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
Sometimes we would like to perform a reduction that involves multiple\n values, like :code:`argmax`, which can be done with tuple inputs.\n See `reduction-with-tuple-inputs` for more details.