Compute and Reduce with Tuple Inputs
Author: Ziheng Jiang
Often we want to compute multiple outputs with the same shape within
a single loop, or perform a reduction that involves multiple values, such as
argmax. These problems can be addressed by tuple inputs.
In this tutorial, we will introduce the usage of tuple inputs in TVM.
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
Describe Batchwise Computation
For operators that have the same shape, we can put them together in the
tuple returned by the body of te.compute if we want them to be scheduled
together in the subsequent scheduling step.
n = te.var("n")
m = te.var("m")
A0 = te.placeholder((m, n), name="A0")
A1 = te.placeholder((m, n), name="A1")
B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name="B")
# The generated IR code would be:
s = te.create_schedule(B0.op)
print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(A0: T.handle, A1: T.handle, B: T.handle, B_1: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A0_1 = T.match_buffer(A0, (m, n), strides=("stride", "stride"), buffer_type="auto")
        A1_1 = T.match_buffer(A1, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_2 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_3 = T.match_buffer(B_1, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i, j in T.grid(m, n):
            B_4 = T.Buffer((B_2.strides[0] * m,), data=B_2.data, buffer_type="auto")
            A0_2 = T.Buffer((A0_1.strides[0] * m,), data=A0_1.data, buffer_type="auto")
            B_4[i * B_2.strides[0] + j * B_2.strides[1]] = A0_2[i * A0_1.strides[0] + j * A0_1.strides[1]] + T.float32(2)
            B_5 = T.Buffer((B_3.strides[0] * m,), data=B_3.data, buffer_type="auto")
            A1_2 = T.Buffer((A1_1.strides[0] * m,), data=A1_1.data, buffer_type="auto")
            B_5[i * B_3.strides[0] + j * B_3.strides[1]] = A1_2[i * A1_1.strides[0] + j * A1_1.strides[1]] * T.float32(3)
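For intuition, the loop nest above can be read as the plain-Python sketch below. It is only an illustration, assuming row-major nested lists in place of TVM buffers, and the helper name `batchwise_reference` is hypothetical. It mirrors the lowered IR in one respect that matters here: both tuple outputs are written inside the same (i, j) iteration rather than in two separate loops.

```python
def batchwise_reference(a0, a1):
    """Plain-Python model of the fused loop nest (hypothetical helper).

    a0 and a1 are m x n nested lists of floats; returns (b0, b1).
    """
    m, n = len(a0), len(a0[0])
    b0 = [[0.0] * n for _ in range(m)]
    b1 = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            # Both tuple elements are produced in the same iteration,
            # like the B_4 and B_5 stores in the lowered IR.
            b0[i][j] = a0[i][j] + 2.0
            b1[i][j] = a1[i][j] * 3.0
    return b0, b1


b0, b1 = batchwise_reference([[1.0, 2.0], [3.0, 4.0]], [[0.5, 1.0], [1.5, 2.0]])
print(b0)  # [[3.0, 4.0], [5.0, 6.0]]
print(b1)  # [[1.5, 3.0], [4.5, 6.0]]
```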
Describe Reduction with Collaborative Inputs
Sometimes we need multiple inputs to express a reduction operator, and
those inputs work together; argmax is one example.
During the reduction, argmax needs to compare the values of the
operands while also keeping the index of the operand with the largest
value. This can be expressed with te.comm_reducer() as below:
# x and y are the operands of the reduction; each of them is a tuple of
# (index, value).
def fcombine(x, y):
    lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
    rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
    return lhs, rhs

# Our identity element also needs to be a tuple, so `fidentity` accepts
# two dtypes as inputs.
def fidentity(t0, t1):
    return tvm.tir.const(-1, t0), tvm.te.min_value(t1)

argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
# describe the reduction computation
m = te.var("m")
n = te.var("n")
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="int32")
k = te.reduce_axis((0, n), "k")
T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
# the generated IR code would be:
s = te.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(idx: T.handle, val: T.handle, T: T.handle, T_1: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        idx_1 = T.match_buffer(idx, (m, n), "int32", strides=("stride", "stride"), buffer_type="auto")
        val_1 = T.match_buffer(val, (m, n), "int32", strides=("stride", "stride"), buffer_type="auto")
        T_2 = T.match_buffer(T, (m,), "int32", strides=("stride",), buffer_type="auto")
        T_3 = T.match_buffer(T_1, (m,), "int32", strides=("stride",), buffer_type="auto")
        for i in range(m):
            T_4 = T.Buffer((T_2.strides[0] * m,), "int32", data=T_2.data, buffer_type="auto")
            T_4[i * T_2.strides[0]] = -1
            T_5 = T.Buffer((T_3.strides[0] * m,), "int32", data=T_3.data, buffer_type="auto")
            T_5[i * T_3.strides[0]] = -2147483648
            for k in range(n):
                val_2 = T.Buffer((val_1.strides[0] * m,), "int32", data=val_1.data, buffer_type="auto")
                idx_2 = T.Buffer((idx_1.strides[0] * m,), "int32", data=idx_1.data, buffer_type="auto")
                T_4[i * T_2.strides[0]] = T.if_then_else(val_2[i * val_1.strides[0] + k * val_1.strides[1]] <= T_5[i * T_3.strides[0]], T_4[i * T_2.strides[0]], idx_2[i * idx_1.strides[0] + k * idx_1.strides[1]])
                T_5[i * T_3.strides[0]] = T.if_then_else(val_2[i * val_1.strides[0] + k * val_1.strides[1]] <= T_5[i * T_3.strides[0]], T_5[i * T_3.strides[0]], val_2[i * val_1.strides[0] + k * val_1.strides[1]])
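The lowered IR can be traced by hand with a plain-Python transcription such as the sketch below (the helper `lowered_argmax` is hypothetical, and nested lists stand in for the int32 buffers). It makes the tie-breaking rule visible: because the accumulator is kept whenever `val <= T_5`, a later element with the same maximum value does not replace the earlier one, so the first index that reaches the maximum wins.

```python
INT32_MIN = -2147483648  # the constant the lowered IR stores into T_5


def lowered_argmax(idx, val):
    """Plain-Python transcription of the lowered loop nest (hypothetical helper).

    idx and val are m x n nested lists of ints; returns the outputs (t0, t1).
    """
    m = len(val)
    t0 = [0] * m
    t1 = [0] * m
    for i in range(m):
        # Initialize with the identity element (-1, INT32_MIN).
        t0[i] = -1
        t1[i] = INT32_MIN
        for k in range(len(val[i])):
            # Both stores test the same condition against the old T_5 value,
            # so evaluating it once here is equivalent.
            keep = val[i][k] <= t1[i]
            t0[i] = t0[i] if keep else idx[i][k]
            t1[i] = t1[i] if keep else val[i][k]
    return t0, t1


t0, t1 = lowered_argmax([[0, 1, 2], [0, 1, 2]], [[3, 7, 7], [-4, -9, -1]])
print(t0, t1)  # [1, 2] [7, -1]
```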
Note
If you are not familiar with reduction, please refer to Define General Commutative Reduction Operation.
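The combiner and identity passed to te.comm_reducer() in this section can also be mirrored in plain Python to check the reduction semantics without building anything. This is a sketch under stated assumptions: the functions below are hypothetical stand-ins, `fidentity` takes no dtype arguments because Python ints are untyped, and INT32_MIN is hard-coded to the value that min_value("int32") produces in the IR above. The reduction is modeled as a left fold starting from the identity element.

```python
import functools

INT32_MIN = -(2 ** 31)  # value of the int32 minimum used as the identity


def fcombine(x, y):
    # x and y are (index, value) pairs; keep the pair with the larger value,
    # preferring x when the values tie (mirrors the x[1] >= y[1] test).
    return (x[0], x[1]) if x[1] >= y[1] else (y[0], y[1])


def fidentity():
    # -1 marks "no element seen yet"; INT32_MIN never beats any int32 value.
    return (-1, INT32_MIN)


def argmax_row(indices, values):
    """Fold fcombine over one row, starting from the identity element."""
    return functools.reduce(fcombine, zip(indices, values), fidentity())


print(argmax_row([0, 1, 2, 3], [5, 9, 2, 9]))  # (1, 9)
print(argmax_row([], []))  # (-1, -2147483648)
```

Starting the fold from the identity is what makes an empty row well defined, which is why the identity element has to be a tuple of the same arity as the operands.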
Schedule Operation with Tuple Inputs
It is worth mentioning that although one batch operation gives you multiple outputs, those outputs can only be scheduled together as a single operation, not individually as tensors.
n = te.var("n")
m = te.var("m")
A0 = te.placeholder((m, n), name="A0")
B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name="B")
A1 = te.placeholder((m, n), name="A1")
C = te.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name="C")
s = te.create_schedule(C.op)
s[B0].compute_at(s[C], C.op.axis[0])
# compute_at on B0 also moves B1, because both belong to the same
# operation, as you can see in the generated IR code below:
print(tvm.lower(s, [A0, A1, C], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(A0: T.handle, A1: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A0_1 = T.match_buffer(A0, (m, n), strides=("stride", "stride"), buffer_type="auto")
        A1_1 = T.match_buffer(A1, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_v0 = T.allocate([n], "float32", "global")
        B_v1 = T.allocate([n], "float32", "global")
        for i in range(m):
            B_v0_1 = T.Buffer((n,), data=B_v0)
            for j in range(n):
                A0_2 = T.Buffer((A0_1.strides[0] * m,), data=A0_1.data, buffer_type="auto")
                B_v0_1[j] = A0_2[i * A0_1.strides[0] + j * A0_1.strides[1]] + T.float32(2)
                B_v1_1 = T.Buffer((n,), data=B_v1)
                B_v1_1[j] = A0_2[i * A0_1.strides[0] + j * A0_1.strides[1]] * T.float32(3)
            for j in range(n):
                C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
                A1_2 = T.Buffer((A1_1.strides[0] * m,), data=A1_1.data, buffer_type="auto")
                C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A1_2[i * A1_1.strides[0] + j * A1_1.strides[1]] + B_v0_1[j]
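A plain-Python transcription of this IR makes the per-operation scheduling concrete. The helper `compute_at_reference` is hypothetical and nested lists stand in for the buffers. Two row-sized scratch buffers play the roles of B_v0 and B_v1, and both are filled at every row i, even though C only reads B0: moving B0 with compute_at moved the whole operation B, including B1.

```python
def compute_at_reference(a0, a1):
    """Plain-Python transcription of the IR after compute_at (hypothetical helper).

    a0 and a1 are m x n nested lists of floats; returns C as nested lists.
    """
    m, n = len(a0), len(a0[0])
    c = [[0.0] * n for _ in range(m)]
    # One row-sized scratch buffer per tuple output, like B_v0 and B_v1.
    b_v0 = [0.0] * n
    b_v1 = [0.0] * n
    for i in range(m):
        # Row i of both B0 and B1 is produced here, at C's outer loop.
        for j in range(n):
            b_v0[j] = a0[i][j] + 2.0
            b_v1[j] = a0[i][j] * 3.0  # computed even though C never reads it
        for j in range(n):
            c[i][j] = a1[i][j] + b_v0[j]
    return c


print(compute_at_reference([[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]))
# [[13.0, 24.0], [35.0, 46.0]]
```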
Summary
This tutorial introduces the usage of operations with tuple inputs:
Describe normal batchwise computation.
Describe a reduction operation with tuple inputs.
Notice that you can only schedule computation in terms of operations, not individual tensors.