.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/work_with_schedules/schedule_primitives.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_how_to_work_with_schedules_schedule_primitives.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_work_with_schedules_schedule_primitives.py:


.. _schedule_primitives:

Schedule Primitives in TVM
==========================
**Author**: `Ziheng Jiang <https://github.com/ZihengJiang>`_

TVM is a domain-specific language for efficient kernel construction.

In this tutorial, we will show you how to schedule a computation with the
various primitives provided by TVM.

.. GENERATED FROM PYTHON SOURCE LINES 29-36

.. code-block:: default


    from __future__ import absolute_import, print_function

    import tvm
    from tvm import te
    import numpy as np


.. GENERATED FROM PYTHON SOURCE LINES 42-50

There often exist several ways to compute the same result; however, different
methods result in different locality and performance. So TVM asks the user to
provide a description of how to execute the computation, called a **Schedule**.

A **Schedule** is a set of transformations that rearrange the loops of the
computation in the program.

.. GENERATED FROM PYTHON SOURCE LINES 51-56

.. code-block:: default


    # declare some variables for use later
    n = te.var("n")
    m = te.var("m")


.. GENERATED FROM PYTHON SOURCE LINES 57-59

A schedule can be created from a list of ops. By default, the schedule computes
the tensors serially in row-major order.

.. GENERATED FROM PYTHON SOURCE LINES 59-72

.. code-block:: default


    # declare a matrix element-wise multiply
    A = te.placeholder((m, n), name="A")
    B = te.placeholder((m, n), name="B")
    C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

    s = te.create_schedule([C.op])
    # lower will transform the computation from its definition into a real
    # callable function. With the argument `simple_mode=True`, it returns a
    # readable C-like statement; we use it here to print the schedule result.
    print(tvm.lower(s, [A, B, C], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
                 C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B, C_1: C}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_3: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_4: int32], type="auto"), C_1: C_3: Buffer(C_2, float32, [m, n], [stride_2, stride_5: int32], type="auto")} {
      for (i: int32, 0, m) {
        for (j: int32, 0, n) {
          C[((i*stride_2) + (j*stride_5))] = (A[((i*stride) + (j*stride_3))]*B[((i*stride_1) + (j*stride_4))])
        }
      }
    }


.. GENERATED FROM PYTHON SOURCE LINES 73-76

A schedule is composed of multiple stages, and a **Stage** represents the
schedule for one operation. We provide various methods to schedule each stage.

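Before looking at the individual primitives, it may help to inspect the stages
directly. The snippet below is a small sketch we add for illustration (reusing
``s`` and ``C`` from the block above); it is not part of the original example.

.. code-block:: python

    # A schedule contains one Stage per operation; s[C] looks up the stage for C.
    print(s.stages)
    # The loop axes (i, j) of the operation are what the primitives below manipulate.
    print(C.op.axis)
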
.. GENERATED FROM PYTHON SOURCE LINES 78-82

split
-----
:code:`split` can split a specified axis into two axes by :code:`factor`.

.. GENERATED FROM PYTHON SOURCE LINES 82-89

.. code-block:: default


    A = te.placeholder((m,), name="A")
    B = te.compute((m,), lambda i: A[i] * 2, name="B")

    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=32)
    print(tvm.lower(s, [A, B], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto")} {
      for (i.outer: int32, 0, floordiv((m + 31), 32)) {
        for (i.inner: int32, 0, 32) {
          if @tir.likely((((i.outer*32) + i.inner) < m), dtype=bool) {
            let cse_var_1: int32 = ((i.outer*32) + i.inner)
            B[(cse_var_1*stride_1)] = (A[(cse_var_1*stride)]*2f32)
          }
        }
      }
    }


.. GENERATED FROM PYTHON SOURCE LINES 90-92

You can also split an axis by :code:`nparts`, which splits the axis the
opposite way to :code:`factor`: it fixes the extent of the outer loop instead
of the inner one.

.. GENERATED FROM PYTHON SOURCE LINES 92-99

.. code-block:: default


    A = te.placeholder((m,), name="A")
    B = te.compute((m,), lambda i: A[i], name="B")

    s = te.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], nparts=32)
    print(tvm.lower(s, [A, B], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto")} {
      for (i.outer: int32, 0, 32) {
        for (i.inner: int32, 0, floordiv((m + 31), 32)) {
          if @tir.likely(((i.inner + (i.outer*floordiv((m + 31), 32))) < m), dtype=bool) {
            B[((i.inner + (i.outer*floordiv((m + 31), 32)))*stride_1)] = A[((i.inner + (i.outer*floordiv((m + 31), 32)))*stride)]
          }
        }
      }
    }

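Note that splits compose: an axis returned by one :code:`split` can be split
again. The snippet below is a small sketch we add for illustration, following
the same pattern as the examples above; it is not part of the original
tutorial output.

.. code-block:: python

    A = te.placeholder((m,), name="A")
    B = te.compute((m,), lambda i: A[i] * 2, name="B")

    s = te.create_schedule(B.op)
    # first split i into (i.outer, i.inner) with an inner extent of 32 ...
    xo, xi = s[B].split(B.op.axis[0], factor=32)
    # ... then split i.inner again into chunks of 4
    xio, xii = s[B].split(xi, factor=4)
    print(tvm.lower(s, [A, B], simple_mode=True))
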
.. GENERATED FROM PYTHON SOURCE LINES 100-104

tile
----
:code:`tile` helps you execute the computation tile by tile over two axes.

.. GENERATED FROM PYTHON SOURCE LINES 104-111

.. code-block:: default


    A = te.placeholder((m, n), name="A")
    B = te.compute((m, n), lambda i, j: A[i, j], name="B")

    s = te.create_schedule(B.op)
    xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
    print(tvm.lower(s, [A, B], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
      for (i.outer: int32, 0, floordiv((m + 9), 10)) {
        for (j.outer: int32, 0, floordiv((n + 4), 5)) {
          for (i.inner: int32, 0, 10) {
            if @tir.likely((((i.outer*10) + i.inner) < m), dtype=bool) {
              for (j.inner: int32, 0, 5) {
                if @tir.likely((((j.outer*5) + j.inner) < n), dtype=bool) {
                  let cse_var_2: int32 = ((j.outer*5) + j.inner)
                  let cse_var_1: int32 = ((i.outer*10) + i.inner)
                  B[((cse_var_1*stride_1) + (cse_var_2*stride_3))] = A[((cse_var_1*stride) + (cse_var_2*stride_2))]
                }
              }
            }
          }
        }
      }
    }


.. GENERATED FROM PYTHON SOURCE LINES 112-115

fuse
----
:code:`fuse` can fuse two consecutive axes of one computation.

.. GENERATED FROM PYTHON SOURCE LINES 115-125

.. code-block:: default


    A = te.placeholder((m, n), name="A")
    B = te.compute((m, n), lambda i, j: A[i, j], name="B")

    s = te.create_schedule(B.op)
    # tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
    xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
    # then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
    fused = s[B].fuse(xi, yi)
    print(tvm.lower(s, [A, B], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
      for (i.outer: int32, 0, floordiv((m + 9), 10)) {
        for (j.outer: int32, 0, floordiv((n + 4), 5)) {
          for (i.inner.j.inner.fused: int32, 0, 50) {
            if @tir.likely((((i.outer*10) + floordiv(i.inner.j.inner.fused, 5)) < m), dtype=bool) {
              if @tir.likely((((j.outer*5) + floormod(i.inner.j.inner.fused, 5)) < n), dtype=bool) {
                let cse_var_2: int32 = ((j.outer*5) + floormod(i.inner.j.inner.fused, 5))
                let cse_var_1: int32 = ((i.outer*10) + floordiv(i.inner.j.inner.fused, 5))
                B[((cse_var_1*stride_1) + (cse_var_2*stride_3))] = A[((cse_var_1*stride) + (cse_var_2*stride_2))]
              }
            }
          }
        }
      }
    }

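Any two consecutive axes can be fused, not only the innermost pair. The
snippet below is a small sketch we add for illustration (a variation on the
example above, not part of the original tutorial output), fusing the two outer
tile axes instead.

.. code-block:: python

    A = te.placeholder((m, n), name="A")
    B = te.compute((m, n), lambda i, j: A[i, j], name="B")

    s = te.create_schedule(B.op)
    xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
    # (i.outer, j.outer) are consecutive as well, so they can be fused into one loop
    fused_outer = s[B].fuse(xo, yo)
    print(tvm.lower(s, [A, B], simple_mode=True))
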
.. GENERATED FROM PYTHON SOURCE LINES 126-129

reorder
-------
:code:`reorder` can reorder the axes into the specified order.

.. GENERATED FROM PYTHON SOURCE LINES 129-139

.. code-block:: default


    A = te.placeholder((m, n), name="A")
    B = te.compute((m, n), lambda i, j: A[i, j], name="B")

    s = te.create_schedule(B.op)
    # tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
    xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
    # then reorder the axes: (i.inner, j.outer, i.outer, j.inner)
    s[B].reorder(xi, yo, xo, yi)
    print(tvm.lower(s, [A, B], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
      for (i.inner: int32, 0, 10) {
        for (j.outer: int32, 0, floordiv((n + 4), 5)) {
          for (i.outer: int32, 0, floordiv((m + 9), 10)) {
            if @tir.likely((((i.outer*10) + i.inner) < m), dtype=bool) {
              for (j.inner: int32, 0, 5) {
                if @tir.likely((((j.outer*5) + j.inner) < n), dtype=bool) {
                  let cse_var_2: int32 = ((j.outer*5) + j.inner)
                  let cse_var_1: int32 = ((i.outer*10) + i.inner)
                  B[((cse_var_1*stride_1) + (cse_var_2*stride_3))] = A[((cse_var_1*stride) + (cse_var_2*stride_2))]
                }
              }
            }
          }
        }
      }
    }


.. GENERATED FROM PYTHON SOURCE LINES 140-144

bind
----
:code:`bind` can bind a specified axis to a thread axis; this is often used
in GPU programming.

.. GENERATED FROM PYTHON SOURCE LINES 144-153

.. code-block:: default


    A = te.placeholder((n,), name="A")
    B = te.compute(A.shape, lambda i: A[i] * 2, name="B")

    s = te.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))
    print(tvm.lower(s, [A, B], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
      attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 63), 64);
      attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
      if @tir.likely((((blockIdx.x*64) + threadIdx.x) < n), dtype=bool) {
        B[(((blockIdx.x*64) + threadIdx.x)*stride_1)] = (A[(((blockIdx.x*64) + threadIdx.x)*stride)]*2f32)
      }
    }

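Once the axes are bound to GPU thread axes, the schedule can be compiled for a
GPU target. The lines below are a sketch we add for illustration only; they
assume a CUDA-enabled TVM build and device, are not executed as part of this
tutorial, and the function name ``mul2`` is arbitrary.

.. code-block:: python

    # compile the bound schedule for the CUDA target (assumes CUDA is available;
    # "mul2" is just an arbitrary kernel name for this sketch)
    fmul2 = tvm.build(s, [A, B], target="cuda", name="mul2")
    # inspect the generated device kernel source
    print(fmul2.imported_modules[0].get_source())
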
.. GENERATED FROM PYTHON SOURCE LINES 154-158

compute_at
----------
For a schedule that consists of multiple operators, TVM will, by default,
compute each tensor at the root, in its own loop nest.

.. GENERATED FROM PYTHON SOURCE LINES 158-165

.. code-block:: default


    A = te.placeholder((m,), name="A")
    B = te.compute((m,), lambda i: A[i] + 1, name="B")
    C = te.compute((m,), lambda i: B[i] * 2, name="C")

    s = te.create_schedule(C.op)
    print(tvm.lower(s, [A, B, C], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
                 C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B, C_1: C}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto"), C_1: C_3: Buffer(C_2, float32, [m], [stride_2], type="auto")} {
      for (i: int32, 0, m) {
        B[(i*stride_1)] = (A[(i*stride)] + 1f32)
      }
      for (i_1: int32, 0, m) {
        C[(i_1*stride_2)] = (B[(i_1*stride_1)]*2f32)
      }
    }


.. GENERATED FROM PYTHON SOURCE LINES 166-168

:code:`compute_at` can move the computation of `B` into the first axis of the
computation of `C`.

.. GENERATED FROM PYTHON SOURCE LINES 168-176

.. code-block:: default


    A = te.placeholder((m,), name="A")
    B = te.compute((m,), lambda i: A[i] + 1, name="B")
    C = te.compute((m,), lambda i: B[i] * 2, name="C")

    s = te.create_schedule(C.op)
    s[B].compute_at(s[C], C.op.axis[0])
    print(tvm.lower(s, [A, B, C], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
                 C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B, C_1: C}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto"), C_1: C_3: Buffer(C_2, float32, [m], [stride_2], type="auto")} {
      for (i: int32, 0, m) {
        B[(i*stride_1)] = (A[(i*stride)] + 1f32)
        C[(i*stride_2)] = (B[(i*stride_1)]*2f32)
      }
    }

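:code:`compute_at` also composes with the other primitives: the anchor can be
an axis produced by :code:`split` or :code:`tile`. The snippet below is a
small sketch we add for illustration (a variation on the example above, not
part of the original tutorial output), producing a 32-element chunk of `B`
right before it is consumed.

.. code-block:: python

    A = te.placeholder((m,), name="A")
    B = te.compute((m,), lambda i: A[i] + 1, name="B")
    C = te.compute((m,), lambda i: B[i] * 2, name="C")

    s = te.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=32)
    # compute B at the outer axis of the split C
    s[B].compute_at(s[C], xo)
    print(tvm.lower(s, [A, B, C], simple_mode=True))
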
.. GENERATED FROM PYTHON SOURCE LINES 177-182

compute_inline
--------------
:code:`compute_inline` can mark one stage as inline; the body of the
computation is then expanded and inserted at the location where the tensor is
required.

.. GENERATED FROM PYTHON SOURCE LINES 182-190

.. code-block:: default


    A = te.placeholder((m,), name="A")
    B = te.compute((m,), lambda i: A[i] + 1, name="B")
    C = te.compute((m,), lambda i: B[i] * 2, name="C")

    s = te.create_schedule(C.op)
    s[B].compute_inline()
    print(tvm.lower(s, [A, B, C], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
                 C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B, C_1: C}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto"), C_1: C_3: Buffer(C_2, float32, [m], [stride_2], type="auto")} {
      for (i: int32, 0, m) {
        C[(i*stride_2)] = ((A[(i*stride)] + 1f32)*2f32)
      }
    }


.. GENERATED FROM PYTHON SOURCE LINES 191-194

compute_root
------------
:code:`compute_root` can move the computation of one stage back to the root.

.. GENERATED FROM PYTHON SOURCE LINES 194-203

.. code-block:: default


    A = te.placeholder((m,), name="A")
    B = te.compute((m,), lambda i: A[i] + 1, name="B")
    C = te.compute((m,), lambda i: B[i] * 2, name="C")

    s = te.create_schedule(C.op)
    s[B].compute_at(s[C], C.op.axis[0])
    s[B].compute_root()
    print(tvm.lower(s, [A, B, C], simple_mode=True))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
                 B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
                 C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
      buffer_map = {A_1: A, B_1: B, C_1: C}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto"), C_1: C_3: Buffer(C_2, float32, [m], [stride_2], type="auto")} {
      for (i: int32, 0, m) {
        B[(i*stride_1)] = (A[(i*stride)] + 1f32)
      }
      for (i_1: int32, 0, m) {
        C[(i_1*stride_2)] = (B[(i_1*stride_1)]*2f32)
      }
    }


.. GENERATED FROM PYTHON SOURCE LINES 204-217

Summary
-------
This tutorial provides an introduction to schedule primitives in TVM, which
allow users to schedule computations easily and flexibly.

In order to obtain a kernel implementation with good performance, the general
workflow is often:

- Describe your computation via a series of operations.
- Try to schedule the computation with primitives.
- Compile and run to see the performance difference.
- Adjust your schedule according to the running results.


.. _sphx_glr_download_how_to_work_with_schedules_schedule_primitives.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: schedule_primitives.py <schedule_primitives.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: schedule_primitives.ipynb <schedule_primitives.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_