Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
Introduction to TOPI
Author: Ehsan M. Kermani
This is an introductory tutorial to TVM Operator Inventory (TOPI). TOPI provides numpy-style generic operations and schedules at a higher level of abstraction than TVM. In this tutorial, we will see how TOPI can save us from writing boilerplate code in TVM.
import tvm
import tvm.testing
from tvm import te
from tvm import topi
import numpy as np
Basic example
Let’s revisit the sum of rows operation (equivalent to B = numpy.sum(A, axis=1)). To compute the sum of rows of a two-dimensional TVM tensor A, we should
specify the symbolic operation as well as the schedule, as follows.
To examine the IR code in a human-readable format, we can then do:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
n, m = T.int32(), T.int32()
A_1 = T.match_buffer(A, (n, m), strides=("stride", "stride"), buffer_type="auto")
B = T.allocate([n], "float32", "global")
for i in range(n):
B_1 = T.Buffer((n,), data=B)
B_1[i] = T.float32(0.0)
for k in range(m):
A_2 = T.Buffer((A_1.strides[0] * n,), data=A_1.data, buffer_type="auto")
B_1[i] = B_1[i] + A_2[i * A_1.strides[0] + k * A_1.strides[1]]
However, for such a common operation we had to define the reduce axis ourselves, as well as the explicit computation with te.compute.
Imagine how many more details we would need to provide for more complicated operations.
Fortunately, we can replace those two lines with a simple topi.sum, much like numpy.sum:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
n, m = T.int32(), T.int32()
A_1 = T.match_buffer(A, (n, m), strides=("stride", "stride"), buffer_type="auto")
A_red = T.allocate([n], "float32", "global")
for ax0 in range(n):
A_red_1 = T.Buffer((n,), data=A_red)
A_red_1[ax0] = T.float32(0.0)
for k1 in range(m):
A_2 = T.Buffer((A_1.strides[0] * n,), data=A_1.data, buffer_type="auto")
A_red_1[ax0] = A_red_1[ax0] + A_2[ax0 * A_1.strides[0] + k1 * A_1.strides[1]]
Numpy-style operator overloading
We can add two tensors that have compatible (broadcastable) shapes using topi.broadcast_add.
Even more concisely, TOPI provides operator overloading for such common operations. For example, a + b on two tensors is equivalent to topi.broadcast_add(a, b).
With the same overloaded syntax, TOPI also handles broadcasting a primitive (int or float) to a tensor, as in d - 3.14.
Generic schedules and fusing operations
Up to now, we have seen an example of how TOPI can save us from writing explicit computations in the lower-level API.
But it doesn’t stop there: so far we still did the scheduling by hand, as before. TOPI also provides higher-level
scheduling recipes depending on a given context. For example, for CUDA,
we can schedule the following series of operations ending with topi.sum
using only topi.generic.schedule_reduce.
/workspace/python/tvm/target/target.py:446: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.
warnings.warn("Try specifying cuda arch by adding 'arch=sm_xx' to your target.")
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(a: T.Buffer((100, 10, 10), "float32"), b: T.Buffer((10, 10), "float32")):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
T_divide_red = T.allocate([1], "float32", "global")
threadIdx_x = T.launch_thread("threadIdx.x", 1024)
T_divide_red_rf = T.allocate([1], "float32", "local")
reduce_temp0 = T.allocate([1], "float32", "local")
T_divide_red_rf_1 = T.Buffer((1,), data=T_divide_red_rf, scope="local", align=4)
T_divide_red_rf_1[0] = T.float32(0.0)
for k0_k1_fused_k2_fused_outer in range(10):
if T.likely(k0_k1_fused_k2_fused_outer * 64 + threadIdx_x // 16 < 625 and k0_k1_fused_k2_fused_outer * 64 + threadIdx_x // 16 < 625 and k0_k1_fused_k2_fused_outer * 64 + threadIdx_x // 16 < 625):
a_1 = T.Buffer((10000,), data=a.data)
b_1 = T.Buffer((100,), data=b.data)
T_divide_red_rf_1[0] = T_divide_red_rf_1[0] + (a_1[k0_k1_fused_k2_fused_outer * 1024 + threadIdx_x] + b_1[(k0_k1_fused_k2_fused_outer * 24 + threadIdx_x) % 100] + a_1[k0_k1_fused_k2_fused_outer * 1024 + threadIdx_x] * b_1[(k0_k1_fused_k2_fused_outer * 24 + threadIdx_x) % 100]) * T.float32(0.5)
reduce_temp0_1 = T.Buffer((1,), data=reduce_temp0, scope="local")
with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
T.tvm_thread_allreduce(T.uint32(1), T_divide_red_rf_1[0], T.bool(True), reduce_temp0_1[0], threadIdx_x)
if threadIdx_x == 0:
T_divide_red_1 = T.Buffer((1,), data=T_divide_red, align=4)
T_divide_red_1[0] = reduce_temp0_1[0]
As you can see, the scheduled stages of the computation have been accumulated, and we can examine them with
# Print every stage the schedule ``sg`` has accumulated: the input
# placeholders, the element-wise/broadcast stages and the rfactor-split
# reduction stages.
print(sg.stages)
[stage(a, placeholder(a, 0xb6f83c0)), stage(b, placeholder(b, 0x18229b30)), stage(T_add, compute(T_add, body=[a[ax0, ax1, ax2] + b[ax1, ax2]], axis=[T.iter_var(ax0, T.Range(0, 100), "DataPar", ""), T.iter_var(ax1, T.Range(0, 10), "DataPar", ""), T.iter_var(ax2, T.Range(0, 10), "DataPar", "")], reduce_axis=[], tag=broadcast, attrs={})), stage(T_multiply, compute(T_multiply, body=[a[ax0, ax1, ax2] * b[ax1, ax2]], axis=[T.iter_var(ax0, T.Range(0, 100), "DataPar", ""), T.iter_var(ax1, T.Range(0, 10), "DataPar", ""), T.iter_var(ax2, T.Range(0, 10), "DataPar", "")], reduce_axis=[], tag=broadcast, attrs={})), stage(T_elemwise_sum, compute(T_elemwise_sum, body=[T_add[ax0, ax1, ax2] + T_multiply[ax0, ax1, ax2]], axis=[T.iter_var(ax0, T.Range(0, 100), "DataPar", ""), T.iter_var(ax1, T.Range(0, 10), "DataPar", ""), T.iter_var(ax2, T.Range(0, 10), "DataPar", "")], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide, compute(T_divide, body=[T_elemwise_sum[ax0, ax1, ax2] / T.float32(2.0)], axis=[T.iter_var(ax0, T.Range(0, 100), "DataPar", ""), T.iter_var(ax1, T.Range(0, 10), "DataPar", ""), T.iter_var(ax2, T.Range(0, 10), "DataPar", "")], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide_red.rf, compute(T_divide_red.rf, body=[T.reduce(T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)]), source=[T_divide[(k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) // 10 // 10, (k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) // 10 % 10, (k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) % 10]], init=[], axis=[T.iter_var(k0_k1_fused_k2_fused_outer, T.Range(0, 10), "CommReduce", "")], condition=T.likely((k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) // 10 // 10 < 100 and (k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) // 10 < 1000 and k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024 < 10000), value_index=0)], axis=[T.iter_var(k0_k1_fused_k2_fused_inner, T.Range(0, 1024), "DataPar", "")], 
reduce_axis=[T.iter_var(k0_k1_fused_k2_fused_outer, T.Range(0, 10), "CommReduce", "")], tag=, attrs={})), stage(T_divide_red, compute(T_divide_red.repl, body=[T.reduce(T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)]), source=[T_divide_red.rf[k0_k1_fused_k2_fused_inner_v]], init=[], axis=[T.iter_var(k0_k1_fused_k2_fused_inner_v, T.Range(0, 1024), "CommReduce", "")], condition=T.bool(True), value_index=0)], axis=[], reduce_axis=[T.iter_var(k0_k1_fused_k2_fused_inner_v, T.Range(0, 1024), "CommReduce", "")], tag=, attrs={}))]
We can test the correctness by comparing with the numpy
result as follows:
# Compile the fused schedule for CUDA; argument order is inputs a, b, then
# the reduced output g.
func = tvm.build(sg, [a, b, g], "cuda")
dev = tvm.cuda(0)

# Random host inputs matching the placeholders' shapes and dtypes, plus the
# NumPy reference: ((a + b) + a * b) / 2 summed over every element.
a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)
b_np = np.random.uniform(size=(y, y)).astype(b.dtype)
g_np = ((a_np + b_np + a_np * b_np) / 2.0).sum()

# Copy the inputs to the device and give the kernel a zero-filled output
# with the same shape/dtype as the reference result.
a_nd, b_nd = (tvm.nd.array(host, dev) for host in (a_np, b_np))
g_nd = tvm.nd.array(np.zeros_like(g_np), dev)

# Run the kernel and compare against NumPy.
func(a_nd, b_nd, g_nd)
tvm.testing.assert_allclose(g_nd.numpy(), g_np, rtol=1e-5)
TOPI also provides common neural network operations, such as softmax, with optimized schedules:
# Declare a 512x512 input (te.placeholder defaults to float32, as the lowered
# IR below shows) and the generic TOPI softmax computation over it.
tarray = te.placeholder((512, 512), name="tarray")
softmax_topi = topi.nn.softmax(tarray)

# The CUDA-specific softmax schedule must be created under a CUDA target
# scope; the lowered IR is printed for inspection.
with tvm.target.Target("cuda"):
    sst = topi.cuda.schedule_softmax(softmax_topi)
    print(tvm.lower(sst, [tarray], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(tarray: T.Buffer((512, 512), "float32")):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
T_softmax_norm = T.allocate([65536], "float32x4", "global")
blockIdx_x = T.launch_thread("blockIdx.x", 512)
normal_reduce_temp0 = T.allocate([1], "float32", "local")
reduce_temp0 = T.allocate([1], "float32", "local")
T_softmax_exp = T.allocate([512], "float32", "warp")
normal_reduce_temp0_1 = T.allocate([1], "float32", "local")
reduce_temp0_1 = T.allocate([1], "float32", "local")
threadIdx_x = T.env_thread("threadIdx.x")
T_softmax_exp_1 = T.Buffer((512,), data=T_softmax_exp, scope="warp")
with T.launch_thread(threadIdx_x, 32):
normal_reduce_temp0_2 = T.Buffer((1,), data=normal_reduce_temp0, scope="local")
normal_reduce_temp0_2[0] = T.float32(-340282346638528859811704183484516925440.0)
tarray_1 = T.Buffer((262144,), data=tarray.data)
for k_inner in range(16):
normal_reduce_temp0_2[0] = T.max(normal_reduce_temp0_2[0], tarray_1[blockIdx_x * 512 + threadIdx_x * 16 + k_inner])
with T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.float32(-340282346638528859811704183484516925440.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
reduce_temp0_2 = T.Buffer((1,), data=reduce_temp0, scope="local")
T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0_2[0], T.bool(True), reduce_temp0_2[0], threadIdx_x)
for i1_inner_outer in range(4):
cse_var_1: T.int32 = i1_inner_outer * 4
reduce_temp0_2 = T.Buffer((1,), data=reduce_temp0, scope="local", align=4)
T_softmax_exp_1[threadIdx_x * 16 + cse_var_1:threadIdx_x * 16 + cse_var_1 + 4] = T.exp(tarray_1[blockIdx_x * 512 + threadIdx_x * 16 + cse_var_1:blockIdx_x * 512 + threadIdx_x * 16 + cse_var_1 + 4] - T.Broadcast(reduce_temp0_2[0], 4))
T.launch_thread(threadIdx_x, 32)
normal_reduce_temp0_2 = T.Buffer((1,), data=normal_reduce_temp0_1, scope="local")
normal_reduce_temp0_2[0] = T.float32(0.0)
for k_inner in range(16):
normal_reduce_temp0_2[0] = normal_reduce_temp0_2[0] + T_softmax_exp_1[threadIdx_x * 16 + k_inner]
with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
reduce_temp0_2 = T.Buffer((1,), data=reduce_temp0_1, scope="local")
T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0_2[0], T.bool(True), reduce_temp0_2[0], threadIdx_x)
for i1_inner_outer in range(4):
T_softmax_norm_1 = T.Buffer((65536,), "float32x4", data=T_softmax_norm)
reduce_temp0_2 = T.Buffer((1,), data=reduce_temp0_1, scope="local", align=4)
T_softmax_norm_1[blockIdx_x * 128 + threadIdx_x * 4 + i1_inner_outer] = T_softmax_exp_1[threadIdx_x * 16 + i1_inner_outer * 4:threadIdx_x * 16 + i1_inner_outer * 4 + 4] / T.Broadcast(reduce_temp0_2[0], 4)
Fusing convolutions
We can fuse topi.nn.conv2d and topi.nn.relu together.
Note
TOPI functions are all generic functions. They have different implementations for different backends to optimize for performance. For each backend, it is necessary to call them under a target scope for both the compute declaration and the schedule. TVM will choose the right function to call based on the target information.
# NCHW input (batch 1, 3 channels, 224x224) and a kernel of 10 filters of
# shape 3x5x5; both default to float32.
data = te.placeholder((1, 3, 224, 224))
kernel = te.placeholder((10, 3, 5, 5))

# Per the note above, both the compute declaration and the schedule are
# created under the CUDA target scope so the CUDA implementations are used.
with tvm.target.Target("cuda"):
    # Positional arguments after data/kernel: strides=1, padding=2, dilation=1
    # (padding 2 keeps the 224x224 spatial size for a 5x5 kernel).
    conv = topi.cuda.conv2d_nchw(data, kernel, 1, 2, 1)
    out = topi.nn.relu(conv)
    # Scheduling the final ReLU output lets conv2d be fused into the same
    # kernel: the ReLU is applied directly to the conv2d accumulators.
    sconv = topi.cuda.schedule_conv2d_nchw([out])
    # NOTE: ``out`` is not in the argument list, so the lowered function keeps
    # the result in an internal buffer; pass [data, kernel, out] when building
    # a kernel whose result you want to read back.
    print(tvm.lower(sconv, [data, kernel], simple_mode=True))
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(placeholder: T.Buffer((1, 3, 224, 224), "float32"), placeholder_1: T.Buffer((10, 3, 5, 5), "float32")):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
compute = T.allocate([501760], "float32", "global")
blockIdx_z = T.launch_thread("blockIdx.z", 5)
conv2d_nchw = T.allocate([14], "float32", "local")
pad_temp_shared = T.allocate([112], "float32", "shared")
placeholder_shared = T.allocate([2], "float32", "shared")
blockIdx_y = T.launch_thread("blockIdx.y", 224)
blockIdx_x = T.launch_thread("blockIdx.x", 2)
threadIdx_z = T.launch_thread("threadIdx.z", 1)
threadIdx_y = T.launch_thread("threadIdx.y", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 16)
conv2d_nchw_1 = T.Buffer((4,), data=conv2d_nchw, scope="local", align=8)
conv2d_nchw_1[0] = T.float32(0.0)
conv2d_nchw_1[2] = T.float32(0.0)
conv2d_nchw_1[4] = T.float32(0.0)
conv2d_nchw_1[6] = T.float32(0.0)
conv2d_nchw_1[8] = T.float32(0.0)
conv2d_nchw_1[10] = T.float32(0.0)
conv2d_nchw_1[12] = T.float32(0.0)
conv2d_nchw_1[1] = T.float32(0.0)
conv2d_nchw_1[3] = T.float32(0.0)
conv2d_nchw_1[5] = T.float32(0.0)
conv2d_nchw_1[7] = T.float32(0.0)
conv2d_nchw_1[9] = T.float32(0.0)
conv2d_nchw_1[11] = T.float32(0.0)
conv2d_nchw_1[13] = T.float32(0.0)
for rc_outer, ry_outer in T.grid(3, 5):
threadIdx_x_1 = T.env_thread("threadIdx.x")
pad_temp_shared_1 = T.Buffer((112,), data=pad_temp_shared, scope="shared")
placeholder_2 = T.Buffer((150528,), data=placeholder.data)
with T.launch_thread("threadIdx.z", 1) as threadIdx_z_1:
threadIdx_y_1 = T.launch_thread("threadIdx.y", 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and 1 <= blockIdx_x * 56 + threadIdx_x_1 * 7 // 2, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 450], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and 1 <= blockIdx_x * 56 + (threadIdx_x_1 * 7 + 1) // 2, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 449], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 448], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 447], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0.0))
threadIdx_x_2 = T.env_thread("threadIdx.x")
placeholder_shared_1 = T.Buffer((2,), data=placeholder_shared, scope="shared", align=8)
placeholder_3 = T.Buffer((750,), data=placeholder_1.data)
with T.launch_thread("threadIdx.z", 1) as threadIdx_z_1:
threadIdx_y_1 = T.launch_thread("threadIdx.y", 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
threadIdx_z_1 = T.env_thread("threadIdx.z")
threadIdx_y_1 = T.env_thread("threadIdx.y")
with T.launch_thread(threadIdx_z_1, 1):
T.launch_thread(threadIdx_y_1, 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and 1 <= blockIdx_x * 56 + (threadIdx_x_1 * 7 + 1) // 2, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 449], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 448], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 447], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 443], T.float32(0.0))
threadIdx_z_2 = T.env_thread("threadIdx.z")
threadIdx_y_2 = T.env_thread("threadIdx.y")
with T.launch_thread(threadIdx_z_2, 1):
T.launch_thread(threadIdx_y_2, 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5 + 1]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
with T.launch_thread(threadIdx_z_1, 1):
T.launch_thread(threadIdx_y_1, 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 448], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 447], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 443], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 442], T.float32(0.0))
with T.launch_thread(threadIdx_z_2, 1):
T.launch_thread(threadIdx_y_2, 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5 + 2]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
with T.launch_thread(threadIdx_z_1, 1):
T.launch_thread(threadIdx_y_1, 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 447], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 443], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 442], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and blockIdx_x * 16 + threadIdx_x_1 < 31, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 441], T.float32(0.0))
with T.launch_thread(threadIdx_z_2, 1):
T.launch_thread(threadIdx_y_2, 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5 + 3]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
with T.launch_thread(threadIdx_z_1, 1):
T.launch_thread(threadIdx_y_1, 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 443], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 442], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and blockIdx_x * 16 + threadIdx_x_1 < 31, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 441], T.float32(0.0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and blockIdx_x * 112 + threadIdx_x_1 * 7 < 216, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 440], T.float32(0.0))
with T.launch_thread(threadIdx_z_2, 1):
T.launch_thread(threadIdx_y_2, 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5 + 4]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
compute_1 = T.Buffer((501760,), data=compute)
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x] = T.max(conv2d_nchw_1[0], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 16] = T.max(conv2d_nchw_1[2], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 32] = T.max(conv2d_nchw_1[4], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 48] = T.max(conv2d_nchw_1[6], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 64] = T.max(conv2d_nchw_1[8], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 80] = T.max(conv2d_nchw_1[10], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 96] = T.max(conv2d_nchw_1[12], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50176] = T.max(conv2d_nchw_1[1], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50192] = T.max(conv2d_nchw_1[3], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50208] = T.max(conv2d_nchw_1[5], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50224] = T.max(conv2d_nchw_1[7], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50240] = T.max(conv2d_nchw_1[9], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50256] = T.max(conv2d_nchw_1[11], T.float32(0.0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50272] = T.max(conv2d_nchw_1[13], T.float32(0.0))
Summary
In this tutorial, we have seen:
- How to use the TOPI API for common operations with numpy-style operators.
- How TOPI facilitates generic schedules and operator fusion for a given context to generate optimized kernel code.