Schedule Primitives in TVM

Author: Ziheng Jiang

TVM is a domain specific language for efficient kernel construction.

In this tutorial, we will show you how to schedule a computation using the various primitives provided by TVM.

from __future__ import absolute_import, print_function

import tvm
from tvm import te
import numpy as np

There are often several ways to compute the same result; however, different methods result in different locality and performance. TVM therefore asks the user to specify how the computation should be executed, and this specification is called a Schedule.

A Schedule is a set of transformations of the computation that transform the loops of computation in the program.

# Declare symbolic variables; the examples below use them as tensor
# dimensions, so the lowered loops have extents m and n.
n = te.var('n')
m = te.var('m')

A schedule can be created from a list of ops. By default, the schedule computes the tensor serially in row-major order.

# Declare an element-wise multiplication of two (m, n) matrices.
A = te.placeholder((m, n), name='A')
B = te.placeholder((m, n), name='B')
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name='C')

# Create a schedule for C's operation without applying any primitive.
s = te.create_schedule([C.op])
# tvm.lower transforms the computation from its definition into a real
# callable function. With the argument `simple_mode=True` it returns a
# readable C-like statement, which we print here to inspect the
# schedule result.
print(tvm.lower(s, [A, B, C], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: handle, float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
             B: Buffer(B_2: handle, float32, [m, n], [stride_2: int32, stride_3: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m, n], [stride_4: int32, stride_5: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, m) {
    for (j: int32, 0, n) {
      C_2[((i*stride) + (j*stride_1))] = ((float32*)A_2[((i*stride_4) + (j*stride_5))]*(float32*)B_2[((i*stride_2) + (j*stride_3))])
    }
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

A schedule is composed of multiple stages, and each Stage represents the schedule for one operation. We provide various methods to schedule every stage.

split

split can split a specified axis into two axes by factor.

# B doubles every element of the 1-D tensor A.
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i]*2, name='B')

s = te.create_schedule(B.op)
# Split B's only axis with factor=32. In the lowered output the inner loop
# has extent 32 and the outer loop has extent floordiv(m + 31, 32); a
# bounds check guards iterations past m.
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: handle, float32, [m: int32], [stride: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m], [stride_1: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B} {
  for (i.outer: int32, 0, floordiv((m + 31), 32)) {
    for (i.inner: int32, 0, 32) {
      if @tir.likely((((i.outer*32) + i.inner) < m), dtype=bool) {
        B_2[(((i.outer*32) + i.inner)*stride)] = ((float32*)A_2[(((i.outer*32) + i.inner)*stride_1)]*2f32)
      }
    }
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

You can also split an axis by nparts, which is the opposite of factor: nparts fixes the extent of the outer axis, whereas factor fixes the extent of the inner axis.

# B is an element-wise copy of the 1-D tensor A.
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i], name='B')

s = te.create_schedule(B.op)
# Split B's only axis with nparts=32. In the lowered output the outer loop
# has extent 32 and the inner loop has extent floordiv(m + 31, 32).
bx, tx = s[B].split(B.op.axis[0], nparts=32)
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: handle, float32, [m: int32], [stride: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m], [stride_1: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B} {
  for (i.outer: int32, 0, 32) {
    for (i.inner: int32, 0, floordiv((m + 31), 32)) {
      if @tir.likely(((i.inner + (i.outer*floordiv((m + 31), 32))) < m), dtype=bool) {
        B_2[((i.inner + (i.outer*floordiv((m + 31), 32)))*stride)] = (float32*)A_2[((i.inner + (i.outer*floordiv((m + 31), 32)))*stride_1)]
      }
    }
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

tile

tile helps you execute the computation tile by tile over two axes.

# B is an element-wise copy of the (m, n) matrix A.
A = te.placeholder((m, n), name='A')
B = te.compute((m, n), lambda i, j: A[i, j], name='B')

s = te.create_schedule(B.op)
# Tile both axes. The lowered output has four loops, in the order
# (i.outer, j.outer, i.inner, j.inner), with inner extents 10 and 5.
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: handle, float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m, n], [stride_2: int32, stride_3: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B} {
  for (i.outer: int32, 0, floordiv((m + 9), 10)) {
    for (j.outer: int32, 0, floordiv((n + 4), 5)) {
      for (i.inner: int32, 0, 10) {
        for (j.inner: int32, 0, 5) {
          if @tir.likely((((i.outer*10) + i.inner) < m), dtype=bool) {
            if @tir.likely((((j.outer*5) + j.inner) < n), dtype=bool) {
              B_2[((((i.outer*10) + i.inner)*stride) + (((j.outer*5) + j.inner)*stride_1))] = (float32*)A_2[((((i.outer*10) + i.inner)*stride_2) + (((j.outer*5) + j.inner)*stride_3))]
            }
          }
        }
      }
    }
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

fuse

fuse can fuse two consecutive axes of one computation.

# B is an element-wise copy of the (m, n) matrix A.
A = te.placeholder((m, n), name='A')
B = te.compute((m, n), lambda i, j: A[i, j], name='B')

s = te.create_schedule(B.op)
# First tile into four axes: (i.outer, j.outer, i.inner, j.inner).
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# Then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused).
# In the lowered output the fused loop has extent 50 (10 * 5), and the
# original indices are recovered with floordiv/floormod by 5.
fused = s[B].fuse(xi, yi)
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: handle, float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m, n], [stride_2: int32, stride_3: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B} {
  for (i.outer: int32, 0, floordiv((m + 9), 10)) {
    for (j.outer: int32, 0, floordiv((n + 4), 5)) {
      for (i.inner.j.inner.fused: int32, 0, 50) {
        if @tir.likely((((i.outer*10) + floordiv(i.inner.j.inner.fused, 5)) < m), dtype=bool) {
          if @tir.likely((((j.outer*5) + floormod(i.inner.j.inner.fused, 5)) < n), dtype=bool) {
            B_2[((((i.outer*10) + floordiv(i.inner.j.inner.fused, 5))*stride) + (((j.outer*5) + floormod(i.inner.j.inner.fused, 5))*stride_1))] = (float32*)A_2[((((i.outer*10) + floordiv(i.inner.j.inner.fused, 5))*stride_2) + (((j.outer*5) + floormod(i.inner.j.inner.fused, 5))*stride_3))]
          }
        }
      }
    }
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

reorder

reorder can reorder the axes in the specified order.

# B is an element-wise copy of the (m, n) matrix A.
A = te.placeholder((m, n), name='A')
B = te.compute((m, n), lambda i, j: A[i, j], name='B')

s = te.create_schedule(B.op)
# First tile into four axes: (i.outer, j.outer, i.inner, j.inner).
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# Then reorder the axes to (i.inner, j.outer, i.outer, j.inner); the
# lowered loop nest follows this order from outermost to innermost.
s[B].reorder(xi, yo, xo, yi)
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: handle, float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m, n], [stride_2: int32, stride_3: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B} {
  for (i.inner: int32, 0, 10) {
    for (j.outer: int32, 0, floordiv((n + 4), 5)) {
      for (i.outer: int32, 0, floordiv((m + 9), 10)) {
        for (j.inner: int32, 0, 5) {
          if @tir.likely((((i.outer*10) + i.inner) < m), dtype=bool) {
            if @tir.likely((((j.outer*5) + j.inner) < n), dtype=bool) {
              B_2[((((i.outer*10) + i.inner)*stride) + (((j.outer*5) + j.inner)*stride_1))] = (float32*)A_2[((((i.outer*10) + i.inner)*stride_2) + (((j.outer*5) + j.inner)*stride_3))]
            }
          }
        }
      }
    }
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

bind

bind can bind a specified axis to a thread axis; it is often used in GPU programming.

# B doubles every element of the 1-D tensor A and reuses A's shape.
A = te.placeholder((n,), name='A')
B = te.compute(A.shape, lambda i: A[i] * 2, name='B')

s = te.create_schedule(B.op)
# Split the axis with factor=64, then bind the outer part to blockIdx.x
# and the inner part to threadIdx.x. In the lowered output the loops are
# replaced by thread_extent attributes (floordiv(n + 63, 64) and 64).
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: handle, float32, [n: int32], [stride: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [n], [stride_1: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 63), 64);
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
  if @tir.likely((((blockIdx.x*64) + threadIdx.x) < n), dtype=bool) {
    B_2[(((blockIdx.x*64) + threadIdx.x)*stride)] = ((float32*)A_2[(((blockIdx.x*64) + threadIdx.x)*stride_1)]*2f32)
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

compute_at

For a schedule that consists of multiple operators, TVM will compute tensors at the root separately by default.

# Two chained element-wise stages: B = A + 1, then C = B * 2.
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i]+1, name='B')
C = te.compute((m,), lambda i: B[i]*2, name='C')

# No primitive is applied, so the lowered output computes B and C in two
# separate loops.
s = te.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: handle, float32, [m: int32], [stride: int32], type="auto"),
             B: Buffer(B_2: handle, float32, [m], [stride_1: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m], [stride_2: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, m) {
    B_2[(i*stride_1)] = ((float32*)A_2[(i*stride_2)] + 1f32)
  }
  for (i_1: int32, 0, m) {
    C_2[(i_1*stride)] = ((float32*)B_2[(i_1*stride_1)]*2f32)
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

compute_at can move computation of B into the first axis of computation of C.

# Two chained element-wise stages: B = A + 1, then C = B * 2.
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i]+1, name='B')
C = te.compute((m,), lambda i: B[i]*2, name='C')

s = te.create_schedule(C.op)
# Attach B's computation at the first axis of C; in the lowered output
# both stores share a single loop over i.
s[B].compute_at(s[C], C.op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: handle, float32, [m: int32], [stride: int32], type="auto"),
             B: Buffer(B_2: handle, float32, [m], [stride_1: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m], [stride_2: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, m) {
    B_2[(i*stride_1)] = ((float32*)A_2[(i*stride_2)] + 1f32)
    C_2[(i*stride)] = ((float32*)B_2[(i*stride_1)]*2f32)
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

compute_inline

compute_inline can mark one stage as inline; the body of its computation is then expanded and inserted at the place where the tensor is required.

# Two chained element-wise stages: B = A + 1, then C = B * 2.
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i]+1, name='B')
C = te.compute((m,), lambda i: B[i]*2, name='C')

s = te.create_schedule(C.op)
# Inline B into its consumer; the lowered output no longer stores to B
# and C is computed directly as (A[i] + 1) * 2.
s[B].compute_inline()
print(tvm.lower(s, [A, B, C], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: handle, float32, [m: int32], [stride: int32], type="auto"),
             C: Buffer(C_2: handle, float32, [m], [stride_1: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m], [stride_2: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, m) {
    C_2[(i*stride_1)] = (((float32*)A_2[(i*stride_2)] + 1f32)*2f32)
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

compute_root

compute_root can move computation of one stage to the root.

# Two chained element-wise stages: B = A + 1, then C = B * 2.
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i]+1, name='B')
C = te.compute((m,), lambda i: B[i]*2, name='C')

s = te.create_schedule(C.op)
# Attach B under C's first axis, then move it back to the root; the
# lowered output again computes B and C in two separate loops.
s[B].compute_at(s[C], C.op.axis[0])
s[B].compute_root()
print(tvm.lower(s, [A, B, C], simple_mode=True))

Out:

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: handle, float32, [m: int32], [stride: int32], type="auto"),
             B: Buffer(B_2: handle, float32, [m], [stride_1: int32], type="auto"),
             A: Buffer(A_2: handle, float32, [m], [stride_2: int32], type="auto")}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, m) {
    B_2[(i*stride_1)] = ((float32*)A_2[(i*stride_2)] + 1f32)
  }
  for (i_1: int32, 0, m) {
    C_2[(i_1*stride)] = ((float32*)B_2[(i_1*stride_1)]*2f32)
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

Summary

This tutorial provides an introduction to schedule primitives in TVM, which let users schedule the computation easily and flexibly.

To get a well-performing kernel implementation, the general workflow is often:

  • Describe your computation via a series of operations.

  • Try to schedule the computation with primitives.

  • Compile and run to see the performance difference.

  • Adjust your schedule according to the running result.

Gallery generated by Sphinx-Gallery