.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorial_intro_topi.py:

.. _tutorial-topi:

Introduction to TOPI
====================
**Author**: Ehsan M. Kermani

This is an introductory tutorial to the TVM Operator Inventory (TOPI). TOPI provides numpy-style generic operations and schedules with higher-level abstractions than TVM. In this tutorial, we will see how TOPI can save us from writing boilerplate code in TVM.

.. code-block:: default


    from __future__ import absolute_import, print_function

    import tvm
    import tvm.testing
    from tvm import te
    from tvm import topi
    import numpy as np


Basic example
-------------
Let's revisit the sum-of-rows operation (equivalent to :code:`B = numpy.sum(A, axis=1)`). To compute the sum of rows of a two-dimensional TVM tensor :code:`A`, we specify the symbolic operation as well as a schedule, as follows

.. code-block:: default


    n = te.var("n")
    m = te.var("m")
    A = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
    s = te.create_schedule(B.op)


and to examine the IR code in a human-readable format, we can do

.. code-block:: default


    print(tvm.lower(s, [A], simple_mode=True))


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    primfn(A_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [n: int32, m: int32], [stride: int32, stride_1: int32], type="auto")}
      buffer_map = {A_1: A} {
      allocate(B: Pointer(global float32), float32, [n]), storage_scope = global;
      for (i: int32, 0, n) {
        B[i] = 0f32
        for (k: int32, 0, m) {
          B[i] = ((float32*)B[i] + (float32*)A_2[((i*stride) + (k*stride_1))])
        }
      }
    }


However, for such a common operation we had to define the reduce axis ourselves, as well as the explicit computation with :code:`te.compute`. Imagine how many details we would need to provide for more complicated operations. Fortunately, we can replace those two lines with a single call to :code:`topi.sum`, much like :code:`numpy.sum`

.. code-block:: default


    C = topi.sum(A, axis=1)
    ts = te.create_schedule(C.op)
    print(tvm.lower(ts, [A], simple_mode=True))


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    primfn(A_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [n: int32, m: int32], [stride: int32, stride_1: int32], type="auto")}
      buffer_map = {A_1: A} {
      allocate(A_red: Pointer(global float32), float32, [n]), storage_scope = global;
      for (ax0: int32, 0, n) {
        A_red[ax0] = 0f32
        for (k1: int32, 0, m) {
          A_red[ax0] = ((float32*)A_red[ax0] + (float32*)A_2[((ax0*stride) + (k1*stride_1))])
        }
      }
    }


Numpy-style operator overloading
--------------------------------
We can add two tensors with broadcastable shapes using :code:`topi.broadcast_add`. Even shorter, TOPI provides operator overloading for such common operations. For example,

.. code-block:: default


    x, y = 100, 10
    a = te.placeholder((x, y, y), name="a")
    b = te.placeholder((y, y), name="b")
    c = a + b  # same as topi.broadcast_add
    d = a * b  # same as topi.broadcast_mul


Overloaded with the same syntax, TOPI also handles broadcasting a primitive (:code:`int`, :code:`float`) against a tensor, for example :code:`d - 3.14`.
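To see that the scalar really is broadcast, here is a minimal sketch (not part of the original tutorial; it reuses :code:`a`, :code:`b` and :code:`d` from above and assumes an LLVM-enabled TVM build) that checks the result against numpy:

.. code-block:: default


    # A sketch: build d - 3.14 on CPU with a default schedule and compare
    # against the numpy equivalent.
    h = d - 3.14  # the float is broadcast across every element of d
    sh = te.create_schedule(h.op)
    fh = tvm.build(sh, [a, b, h], "llvm")

    a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)
    b_np = np.random.uniform(size=(y, y)).astype(b.dtype)
    h_nd = tvm.nd.array(np.zeros((x, y, y), dtype=h.dtype))
    fh(tvm.nd.array(a_np), tvm.nd.array(b_np), h_nd)
    tvm.testing.assert_allclose(h_nd.numpy(), a_np * b_np - 3.14, rtol=1e-5)
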
Generic schedules and fusing operations
---------------------------------------
Up to now, we have seen an example of how TOPI can save us from writing explicit computations in the lower-level API. But it doesn't stop there: we still did the scheduling as before. TOPI also provides higher-level scheduling recipes depending on a given context. For example, for CUDA, we can schedule the following series of operations ending with :code:`topi.sum` using only :code:`topi.cuda.schedule_reduce`

.. code-block:: default


    e = topi.elemwise_sum([c, d])
    f = e / 2.0
    g = topi.sum(f)
    with tvm.target.cuda():
        sg = topi.cuda.schedule_reduce(g)
        print(tvm.lower(sg, [a, b], simple_mode=True))


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    primfn(a_1: handle, b_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {b: Buffer(b_2: Pointer(float32), float32, [10, 10], []),
                 a: Buffer(a_2: Pointer(float32), float32, [100, 10, 10], [])}
      buffer_map = {a_1: a, b_1: b} {
      allocate(T_divide_red: Pointer(global float32), float32, [1]), storage_scope = global;
      attr [IterVar(threadIdx.x: int32, [0:1024], "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024;
      allocate(T_divide_red.rf: Pointer(local float32), float32, [1]), storage_scope = local;
      allocate(reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local {
        T_divide_red.rf[0] = 0f32
        for (k0.k1.fused.k2.fused.outer: int32, 0, 10) {
          if @tir.likely((((((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x) < 10000) && (((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x) < 10000)) && (((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x) < 10000)), dtype=bool) {
            T_divide_red.rf[0] = ((float32*)T_divide_red.rf[0] + ((((float32*)a_2[((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x)] + (float32*)b_2[floormod(((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x), 100)]) + ((float32*)a_2[((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x)]*(float32*)b_2[floormod(((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x), 100)]))*0.5f32))
          }
        }
        attr [meta[tir.CommReducer][0]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
        @tir.tvm_thread_allreduce(1u32, (float32*)T_divide_red.rf[0], True, reduce_temp0, threadIdx.x, dtype=handle)
        if (threadIdx.x == 0) {
          T_divide_red[0] = (float32*)reduce_temp0[0]
        }
      }
    }


As you can see, the scheduled stages of the computation have been accumulated, and we can examine them by

.. code-block:: default


    print(sg.stages)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [stage(a, placeholder(a, 0x4d31620)), stage(b, placeholder(b, 0x11fa0220)), stage(T_add, compute(T_add, body=[(a[ax0, ax1, ax2] + b[ax1, ax2])], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=broadcast, attrs={})), stage(T_multiply, compute(T_multiply, body=[(a[ax0, ax1, ax2]*b[ax1, ax2])], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=broadcast, attrs={})), stage(T_elemwise_sum, compute(T_elemwise_sum, body=[(T_add[ax0, ax1, ax2] + T_multiply[ax0, ax1, ax2])], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide, compute(T_divide, body=[(T_elemwise_sum[ax0, ax1, ax2]/2f)], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide_red.rf, compute(T_divide_red.rf, body=[reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]), source=[T_divide[floordiv(floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10), 10), floormod(floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10), 10), floormod((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10)]], init=[], axis=[iter_var(k0.k1.fused.k2.fused.outer, range(min=0, ext=10))], where=tir.likely((((floordiv(floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10), 10) < 100) && (floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10) < 1000)) && ((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)) < 10000))), value_index=0)], axis=[iter_var(k0.k1.fused.k2.fused.inner, range(min=0, ext=1024))], reduce_axis=[iter_var(k0.k1.fused.k2.fused.outer, range(min=0, ext=10))], tag=, attrs={})), stage(T_divide_red, compute(T_divide_red.repl, body=[reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]), source=[T_divide_red.rf[k0.k1.fused.k2.fused.inner.v]], init=[], axis=[iter_var(k0.k1.fused.k2.fused.inner.v, range(min=0, ext=1024))], where=(bool)1, value_index=0)], axis=[], reduce_axis=[iter_var(k0.k1.fused.k2.fused.inner.v, range(min=0, ext=1024))], tag=, attrs={}))]
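Each stage wraps an operation whose :code:`tag` records how TOPI classified it, which is what the scheduling recipes key on. As a quick programmatic complement, here is a sketch (not part of the original tutorial) that walks the stages:

.. code-block:: default


    # A sketch: list each stage's operation name and its TOPI tag
    # (e.g. "broadcast", "elemwise"); placeholders carry an empty tag.
    for stage in sg.stages:
        print(stage.op.name, getattr(stage.op, "tag", ""))
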
We can test the correctness by comparing against the :code:`numpy` result as follows

.. code-block:: default


    func = tvm.build(sg, [a, b, g], "cuda")
    dev = tvm.cuda(0)
    a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)
    b_np = np.random.uniform(size=(y, y)).astype(b.dtype)
    g_np = np.sum(np.add(a_np + b_np, a_np * b_np) / 2.0)
    a_nd = tvm.nd.array(a_np, dev)
    b_nd = tvm.nd.array(b_np, dev)
    g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), dev)
    func(a_nd, b_nd, g_nd)
    tvm.testing.assert_allclose(g_nd.numpy(), g_np, rtol=1e-5)

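Beyond correctness, we can also get a rough timing of the fused kernel. This is a sketch (not part of the original tutorial) that assumes the CUDA module :code:`func` and the device arrays built above:

.. code-block:: default


    # A sketch: average the kernel's runtime over a few invocations.
    evaluator = func.time_evaluator(func.entry_name, dev, number=10)
    print("reduction took %.6f ms" % (evaluator(a_nd, b_nd, g_nd).mean * 1e3))
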
TOPI also provides common neural-network operations such as *softmax*, with an optimized schedule

.. code-block:: default


    tarray = te.placeholder((512, 512), name="tarray")
    softmax_topi = topi.nn.softmax(tarray)
    with tvm.target.Target("cuda"):
        sst = topi.cuda.schedule_softmax(softmax_topi)
        print(tvm.lower(sst, [tarray], simple_mode=True))


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    primfn(tarray_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {tarray: Buffer(tarray_2: Pointer(float32), float32, [512, 512], [])}
      buffer_map = {tarray_1: tarray} {
      allocate(T_softmax_norm: Pointer(global float32x4), float32x4, [65536]), storage_scope = global;
      attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 512;
      allocate(normal_reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local;
      allocate(reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local;
      allocate(T_softmax_exp: Pointer(warp float32), float32, [512]), storage_scope = warp;
      allocate(normal_reduce_temp0_1: Pointer(local float32), float32, [1]), storage_scope = local;
      allocate(reduce_temp0_1: Pointer(local float32), float32, [1]), storage_scope = local {
        attr [IterVar(threadIdx.x: int32, [0:32], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
          normal_reduce_temp0[0] = -3.40282e+38f32
          for (k.inner: int32, 0, 16) {
            normal_reduce_temp0[0] = max((float32*)normal_reduce_temp0[0], (float32*)tarray_2[(((blockIdx.x*512) + (threadIdx.x*16)) + k.inner)])
          }
          attr [meta[tir.CommReducer][0]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
          @tir.tvm_thread_allreduce(1u32, (float32*)normal_reduce_temp0[0], True, reduce_temp0, threadIdx.x, dtype=handle)
          for (i1.inner.outer: int32, 0, 4) {
            T_softmax_exp[ramp(((threadIdx.x*16) + (i1.inner.outer*4)), 1, 4)] = @tir.exp(((float32x4*)tarray_2[ramp((((blockIdx.x*512) + (threadIdx.x*16)) + (i1.inner.outer*4)), 1, 4)] - broadcast((float32*)reduce_temp0[0], 4)), dtype=float32x4)
          }
        }
        attr [IterVar(threadIdx.x, [0:32], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
          normal_reduce_temp0_1[0] = 0f32
          for (k.inner_1: int32, 0, 16) {
            normal_reduce_temp0_1[0] = ((float32*)normal_reduce_temp0_1[0] + (float32*)T_softmax_exp[((threadIdx.x*16) + k.inner_1)])
          }
          attr [meta[tir.CommReducer][1]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
          @tir.tvm_thread_allreduce(1u32, (float32*)normal_reduce_temp0_1[0], True, reduce_temp0_1, threadIdx.x, dtype=handle)
          for (i1.inner.outer_1: int32, 0, 4) {
            T_softmax_norm[ramp((((blockIdx.x*512) + (threadIdx.x*16)) + (i1.inner.outer_1*4)), 1, 4)] = ((float32x4*)T_softmax_exp[ramp(((threadIdx.x*16) + (i1.inner.outer_1*4)), 1, 4)] / broadcast((float32*)reduce_temp0_1[0], 4))
          }
        }
      }
    }


Fusing convolutions
-------------------
We can fuse :code:`topi.nn.conv2d` and :code:`topi.nn.relu` together.

.. note::

   TOPI functions are all generic functions. They have different implementations for different backends, optimized for performance. For each backend, it is necessary to call them under a target scope for both the compute declaration and the schedule, so that TVM can choose the right function to call with the target information.

.. code-block:: default


    data = te.placeholder((1, 3, 224, 224))
    kernel = te.placeholder((10, 3, 5, 5))

    with tvm.target.Target("cuda"):
        conv = topi.cuda.conv2d_nchw(data, kernel, 1, 2, 1)  # stride=1, padding=2, dilation=1
        out = topi.nn.relu(conv)
        sconv = topi.cuda.schedule_conv2d_nchw([out])
        print(tvm.lower(sconv, [data, kernel], simple_mode=True))


.. rst-class:: sphx-glr-script-out

 Out: ..
code-block:: none primfn(placeholder_2: handle, placeholder_3: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} buffers = {placeholder_1: Buffer(placeholder_4: Pointer(float32), float32, [10, 3, 5, 5], []), placeholder: Buffer(placeholder_5: Pointer(float32), float32, [1, 3, 224, 224], [])} buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1} { allocate(compute: Pointer(global float32), float32, [501760]), storage_scope = global; attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 5; allocate(compute_1: Pointer(local float32), float32, [14]), storage_scope = local; allocate(pad_temp.shared: Pointer(shared float32), float32, [112]), storage_scope = shared; allocate(placeholder.shared: Pointer(shared float32), float32, [2]), storage_scope = shared; attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 224; attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 2; attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { compute_1[0] = 0f32 compute_1[2] = 0f32 compute_1[4] = 0f32 compute_1[6] = 0f32 compute_1[8] = 0f32 compute_1[10] = 0f32 compute_1[12] = 0f32 compute_1[1] = 0f32 compute_1[3] = 0f32 compute_1[5] = 0f32 compute_1[7] = 0f32 compute_1[9] = 0f32 compute_1[11] = 0f32 compute_1[13] = 0f32 for (rc.outer: int32, 0, 3) { for (ry.outer: int32, 0, 5) { attr [IterVar(threadIdx.z_1: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared[(threadIdx.x_1*7)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (2 <= ((blockIdx.x*112) + (threadIdx.x_1*7)))), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 450)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 1)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (1 <= ((blockIdx.x*112) + (threadIdx.x_1*7)))), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 449)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 448)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 
5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 6)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared[threadIdx.x_2] = (float32*)placeholder_4[((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5))] } compute_1[0] = ((float32*)compute_1[0] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[0])) compute_1[2] = ((float32*)compute_1[2] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[0])) compute_1[4] = ((float32*)compute_1[4] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[0])) compute_1[6] = ((float32*)compute_1[6] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[0])) compute_1[8] = ((float32*)compute_1[8] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[0])) compute_1[10] = ((float32*)compute_1[10] + ((float32*)pad_temp.shared[(threadIdx.x + 80)]*(float32*)placeholder.shared[0])) compute_1[12] = ((float32*)compute_1[12] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[0])) compute_1[1] = ((float32*)compute_1[1] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[1])) compute_1[3] = ((float32*)compute_1[3] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[1])) compute_1[5] = ((float32*)compute_1[5] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[1])) compute_1[7] = ((float32*)compute_1[7] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[1])) compute_1[9] = ((float32*)compute_1[9] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[1])) compute_1[11] = ((float32*)compute_1[11] + ((float32*)pad_temp.shared[(threadIdx.x + 80)]*(float32*)placeholder.shared[1])) compute_1[13] = ((float32*)compute_1[13] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[1])) attr [IterVar(threadIdx.z_1, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared[(threadIdx.x_1*7)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (1 <= ((blockIdx.x*112) + (threadIdx.x_1*7)))), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 449)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + 
(blockIdx.x*112)) + (threadIdx.x_1*7)) - 448)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 6)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_2, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared[threadIdx.x_2] = (float32*)placeholder_4[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 1)] } compute_1[0] = ((float32*)compute_1[0] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[0])) compute_1[2] = ((float32*)compute_1[2] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[0])) compute_1[4] = ((float32*)compute_1[4] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[0])) compute_1[6] = ((float32*)compute_1[6] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[0])) compute_1[8] = ((float32*)compute_1[8] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[0])) compute_1[10] = ((float32*)compute_1[10] + ((float32*)pad_temp.shared[(threadIdx.x + 80)]*(float32*)placeholder.shared[0])) compute_1[12] = ((float32*)compute_1[12] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[0])) compute_1[1] = ((float32*)compute_1[1] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[1])) compute_1[3] = ((float32*)compute_1[3] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[1])) compute_1[5] = ((float32*)compute_1[5] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[1])) compute_1[7] = ((float32*)compute_1[7] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[1])) compute_1[9] = ((float32*)compute_1[9] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[1])) compute_1[11] = ((float32*)compute_1[11] + ((float32*)pad_temp.shared[(threadIdx.x + 80)]*(float32*)placeholder.shared[1])) 
compute_1[13] = ((float32*)compute_1[13] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[1])) attr [IterVar(threadIdx.z_1, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared[(threadIdx.x_1*7)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 448)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 6)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 442)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_2, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared[threadIdx.x_2] = (float32*)placeholder_4[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 2)] } compute_1[0] = ((float32*)compute_1[0] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[0])) compute_1[2] = ((float32*)compute_1[2] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[0])) compute_1[4] = ((float32*)compute_1[4] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[0])) compute_1[6] = ((float32*)compute_1[6] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[0])) compute_1[8] = ((float32*)compute_1[8] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[0])) compute_1[10] = ((float32*)compute_1[10] + ((float32*)pad_temp.shared[(threadIdx.x + 
80)]*(float32*)placeholder.shared[0])) compute_1[12] = ((float32*)compute_1[12] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[0])) compute_1[1] = ((float32*)compute_1[1] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[1])) compute_1[3] = ((float32*)compute_1[3] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[1])) compute_1[5] = ((float32*)compute_1[5] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[1])) compute_1[7] = ((float32*)compute_1[7] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[1])) compute_1[9] = ((float32*)compute_1[9] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[1])) compute_1[11] = ((float32*)compute_1[11] + ((float32*)pad_temp.shared[(threadIdx.x + 80)]*(float32*)placeholder.shared[1])) compute_1[13] = ((float32*)compute_1[13] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[1])) attr [IterVar(threadIdx.z_1, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared[(threadIdx.x_1*7)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 442)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 6)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (((blockIdx.x*112) + (threadIdx.x_1*7)) < 217)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 441)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_2, 
(nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared[threadIdx.x_2] = (float32*)placeholder_4[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 3)] } compute_1[0] = ((float32*)compute_1[0] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[0])) compute_1[2] = ((float32*)compute_1[2] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[0])) compute_1[4] = ((float32*)compute_1[4] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[0])) compute_1[6] = ((float32*)compute_1[6] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[0])) compute_1[8] = ((float32*)compute_1[8] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[0])) compute_1[10] = ((float32*)compute_1[10] + ((float32*)pad_temp.shared[(threadIdx.x + 80)]*(float32*)placeholder.shared[0])) compute_1[12] = ((float32*)compute_1[12] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[0])) compute_1[1] = ((float32*)compute_1[1] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[1])) compute_1[3] = ((float32*)compute_1[3] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[1])) compute_1[5] = ((float32*)compute_1[5] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[1])) compute_1[7] = ((float32*)compute_1[7] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[1])) compute_1[9] = ((float32*)compute_1[9] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[1])) compute_1[11] = ((float32*)compute_1[11] + ((float32*)pad_temp.shared[(threadIdx.x + 80)]*(float32*)placeholder.shared[1])) compute_1[13] = ((float32*)compute_1[13] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[1])) attr [IterVar(threadIdx.z_1, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared[(threadIdx.x_1*7)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), (float32*)placeholder_5[((((((rc.outer*50176) + 
(blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 442)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 5)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (((blockIdx.x*112) + (threadIdx.x_1*7)) < 217)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 441)], 0f32, dtype=float32) pad_temp.shared[((threadIdx.x_1*7) + 6)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (((blockIdx.x*112) + (threadIdx.x_1*7)) < 216)), (float32*)placeholder_5[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 440)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_2, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared[threadIdx.x_2] = (float32*)placeholder_4[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 4)] } compute_1[0] = ((float32*)compute_1[0] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[0])) compute_1[2] = ((float32*)compute_1[2] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[0])) compute_1[4] = ((float32*)compute_1[4] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[0])) compute_1[6] = ((float32*)compute_1[6] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[0])) compute_1[8] = ((float32*)compute_1[8] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[0])) compute_1[10] = ((float32*)compute_1[10] + ((float32*)pad_temp.shared[(threadIdx.x + 80)]*(float32*)placeholder.shared[0])) compute_1[12] = ((float32*)compute_1[12] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[0])) compute_1[1] = ((float32*)compute_1[1] + ((float32*)pad_temp.shared[threadIdx.x]*(float32*)placeholder.shared[1])) compute_1[3] = ((float32*)compute_1[3] + ((float32*)pad_temp.shared[(threadIdx.x + 16)]*(float32*)placeholder.shared[1])) compute_1[5] = ((float32*)compute_1[5] + ((float32*)pad_temp.shared[(threadIdx.x + 32)]*(float32*)placeholder.shared[1])) compute_1[7] = ((float32*)compute_1[7] + ((float32*)pad_temp.shared[(threadIdx.x + 48)]*(float32*)placeholder.shared[1])) compute_1[9] = ((float32*)compute_1[9] + ((float32*)pad_temp.shared[(threadIdx.x + 64)]*(float32*)placeholder.shared[1])) compute_1[11] = ((float32*)compute_1[11] + ((float32*)pad_temp.shared[(threadIdx.x + 80)]*(float32*)placeholder.shared[1])) compute_1[13] = ((float32*)compute_1[13] + ((float32*)pad_temp.shared[(threadIdx.x + 96)]*(float32*)placeholder.shared[1])) } } compute[((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x)] = max((float32*)compute_1[0], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 16)] = max((float32*)compute_1[2], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 32)] = max((float32*)compute_1[4], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 48)] = max((float32*)compute_1[6], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + 
threadIdx.x) + 64)] = max((float32*)compute_1[8], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 80)] = max((float32*)compute_1[10], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 96)] = max((float32*)compute_1[12], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50176)] = max((float32*)compute_1[1], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50192)] = max((float32*)compute_1[3], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50208)] = max((float32*)compute_1[5], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50224)] = max((float32*)compute_1[7], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50240)] = max((float32*)compute_1[9], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50256)] = max((float32*)compute_1[11], 0f32) compute[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50272)] = max((float32*)compute_1[13], 0f32) } }


Summary
-------
In this tutorial, we have seen

- How to use the TOPI API for common operations with numpy-style operators.
- How TOPI facilitates generic schedules and operator fusion for a given context, to generate optimized kernel code.
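As a closing illustration, the whole workflow fits in a few lines on CPU. This is a sketch (not part of the original tutorial; the names are illustrative and it assumes an LLVM-enabled TVM build):

.. code-block:: default


    # A sketch: compose TOPI operators numpy-style, schedule, build, verify.
    inp = te.placeholder((8, 8), name="inp")
    row_sum = topi.sum(inp + 1.0, axis=1)   # operator overloading plus topi.sum
    sch = te.create_schedule(row_sum.op)    # default (unoptimized) schedule
    fn = tvm.build(sch, [inp, row_sum], "llvm")

    inp_np = np.random.uniform(size=(8, 8)).astype(inp.dtype)
    out_nd = tvm.nd.array(np.zeros(8, dtype=row_sum.dtype))
    fn(tvm.nd.array(inp_np), out_nd)
    tvm.testing.assert_allclose(out_nd.numpy(), np.sum(inp_np + 1.0, axis=1), rtol=1e-5)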