.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorial/intro_topi.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_tutorial_intro_topi.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorial_intro_topi.py:

.. _tutorial-topi:

Introduction to TOPI
====================
**Author**: `Ehsan M. Kermani `_

This is an introductory tutorial to TVM Operator Inventory (TOPI).
TOPI provides numpy-style generic operations and schedules with higher
abstractions than TVM. In this tutorial, we will see how TOPI can save us
from writing boilerplate code in TVM.

.. GENERATED FROM PYTHON SOURCE LINES 28-36

.. code-block:: default

    import tvm
    import tvm.testing
    from tvm import te
    from tvm import topi
    import numpy as np

.. GENERATED FROM PYTHON SOURCE LINES 42-48

Basic example
-------------
Let's revisit the sum-of-rows operation (equivalent to :code:`B = numpy.sum(A, axis=1)`).
To compute the sum of rows of a two-dimensional TVM tensor A, we have to specify both the
symbolic operation and the schedule, as follows

.. GENERATED FROM PYTHON SOURCE LINES 48-55

.. code-block:: default

    n = te.var("n")
    m = te.var("m")
    A = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
    s = te.create_schedule(B.op)

.. GENERATED FROM PYTHON SOURCE LINES 56-58

and to examine the IR code in human-readable format, we can do

.. GENERATED FROM PYTHON SOURCE LINES 58-60

.. code-block:: default

    print(tvm.lower(s, [A], simple_mode=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto")}
      buffer_map = {A_1: A}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_1: int32], type="auto")} {
      allocate(B: Pointer(global float32), float32, [n]), storage_scope = global;
      for (i: int32, 0, n) {
        B_1: Buffer(B, float32, [n], [])[i] = 0f32
        for (k: int32, 0, m) {
          B_1[i] = (B_1[i] + A[((i*stride) + (k*stride_1))])
        }
      }
    }

.. GENERATED FROM PYTHON SOURCE LINES 61-65

However, for such a common operation we had to define the reduce axis ourselves, as well as
write the explicit computation with :code:`te.compute`. Imagine how many details we would
have to provide for more complicated operations. Fortunately, we can replace those two lines
with a simple :code:`topi.sum`, much like :code:`numpy.sum`

.. GENERATED FROM PYTHON SOURCE LINES 65-69

.. code-block:: default

    C = topi.sum(A, axis=1)
    ts = te.create_schedule(C.op)
    print(tvm.lower(ts, [A], simple_mode=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(A_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto")}
      buffer_map = {A_1: A}
      preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_1: int32], type="auto")} {
      allocate(A_red: Pointer(global float32), float32, [n]), storage_scope = global;
      for (ax0: int32, 0, n) {
        A_red_1: Buffer(A_red, float32, [n], [])[ax0] = 0f32
        for (k1: int32, 0, m) {
          A_red_1[ax0] = (A_red_1[ax0] + A[((ax0*stride) + (k1*stride_1))])
        }
      }
    }
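The generated loop nest is equivalent to the hand-written version above. Although this page only
prints the IR, we could also build and run the :code:`topi.sum` version and check it against numpy.
The snippet below is a minimal sketch and not part of the generated output of this tutorial; it
assumes the :code:`A`, :code:`C` and :code:`ts` defined above are in scope and that TVM was built
with the LLVM (CPU) backend.

.. code-block:: python

    # Hypothetical check: build the topi.sum schedule for CPU and compare with numpy.sum(axis=1).
    fsum = tvm.build(ts, [A, C], "llvm", name="rowsum")
    cpu_dev = tvm.cpu(0)
    a_data = np.random.uniform(size=(32, 16)).astype(A.dtype)
    a_buf = tvm.nd.array(a_data, cpu_dev)
    c_buf = tvm.nd.array(np.zeros((32,), dtype=C.dtype), cpu_dev)
    fsum(a_buf, c_buf)
    tvm.testing.assert_allclose(c_buf.numpy(), a_data.sum(axis=1), rtol=1e-5)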
.. GENERATED FROM PYTHON SOURCE LINES 70-75

Numpy-style operator overloading
--------------------------------
We can add two tensors using :code:`topi.broadcast_add`, provided their shapes are
broadcastable. Even shorter, TOPI provides operator overloading for such common
operations. For example,

.. GENERATED FROM PYTHON SOURCE LINES 75-81

.. code-block:: default

    x, y = 100, 10
    a = te.placeholder((x, y, y), name="a")
    b = te.placeholder((y, y), name="b")
    c = a + b  # same as topi.broadcast_add
    d = a * b  # same as topi.broadcast_mul

.. GENERATED FROM PYTHON SOURCE LINES 82-83

With the same overloaded syntax, TOPI also handles broadcasting a primitive
(:code:`int`, :code:`float`) against a tensor, as in :code:`d - 3.14`.
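The overloaded operators return ordinary TOPI tensors, so we can schedule and run them like any
other tensor expression. The snippet below is a minimal sketch and not part of the generated
output; it assumes :code:`a`, :code:`b` and :code:`d` from above are in scope and an LLVM-enabled
TVM build.

.. code-block:: python

    # Hypothetical check: d - 3.14 broadcasts the scalar over every element of d = a * b.
    d_shift = d - 3.14
    s_bcast = te.create_schedule(d_shift.op)
    f_bcast = tvm.build(s_bcast, [a, b, d_shift], "llvm")
    cpu_dev = tvm.cpu(0)
    a_sample = np.random.uniform(size=(100, 10, 10)).astype(a.dtype)
    b_sample = np.random.uniform(size=(10, 10)).astype(b.dtype)
    out_buf = tvm.nd.array(np.zeros((100, 10, 10), dtype=d_shift.dtype), cpu_dev)
    f_bcast(tvm.nd.array(a_sample, cpu_dev), tvm.nd.array(b_sample, cpu_dev), out_buf)
    tvm.testing.assert_allclose(out_buf.numpy(), a_sample * b_sample - 3.14, rtol=1e-5)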
.. GENERATED FROM PYTHON SOURCE LINES 85-93

Generic schedules and fusing operations
---------------------------------------
Up to now, we have seen how TOPI can save us from writing explicit computations with the
lower-level API, but it doesn't stop there: we still wrote the schedule by hand, as before.
TOPI also provides higher-level scheduling recipes for a given context. For example, for CUDA,
we can schedule the following series of operations ending with :code:`topi.sum` using only
:code:`topi.cuda.schedule_reduce`

.. GENERATED FROM PYTHON SOURCE LINES 93-100

.. code-block:: default

    e = topi.elemwise_sum([c, d])
    f = e / 2.0
    g = topi.sum(f)
    with tvm.target.cuda():
        sg = topi.cuda.schedule_reduce(g)
        print(tvm.lower(sg, [a, b], simple_mode=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /workspace/python/tvm/target/target.py:389: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.
      warnings.warn("Try specifying cuda arch by adding 'arch=sm_xx' to your target.")
    @main = primfn(a_1: handle, b_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {a: Buffer(a_2: Pointer(float32), float32, [10000], []), b: Buffer(b_2: Pointer(float32), float32, [100], [])}
      buffer_map = {a_1: a, b_1: b}
      preflattened_buffer_map = {a_1: a_3: Buffer(a_2, float32, [100, 10, 10], []), b_1: b_3: Buffer(b_2, float32, [10, 10], [])} {
      allocate(T_divide_red: Pointer(global float32), float32, [1]), storage_scope = global;
      attr [IterVar(threadIdx.x: int32, [0:1024], "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024;
      allocate(T_divide_red.rf: Pointer(local float32), float32, [1]), storage_scope = local;
      allocate(reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local {
        T_divide_red.rf_1: Buffer(T_divide_red.rf, float32, [1], [], scope="local", align=4)[0] = 0f32
        for (k0.k1.fused.k2.fused.outer: int32, 0, 10) {
          if @tir.likely((((((k0.k1.fused.k2.fused.outer*64) + floordiv(threadIdx.x, 16)) < 625) && (((k0.k1.fused.k2.fused.outer*64) + floordiv(threadIdx.x, 16)) < 625)) && (((k0.k1.fused.k2.fused.outer*64) + floordiv(threadIdx.x, 16)) < 625)), dtype=bool) {
            T_divide_red.rf_1[0] = (T_divide_red.rf_1[0] + (((a[((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x)] + b[((floordiv(floormod(((k0.k1.fused.k2.fused.outer*12) + floordiv(threadIdx.x, 2)), 50), 5)*10) + floormod(((k0.k1.fused.k2.fused.outer*4) + threadIdx.x), 10))]) + (a[((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x)]*b[((floordiv(floormod(((k0.k1.fused.k2.fused.outer*12) + floordiv(threadIdx.x, 2)), 50), 5)*10) + floormod(((k0.k1.fused.k2.fused.outer*4) + threadIdx.x), 10))]))*0.5f32))
          }
        }
        attr [meta[tir.CommReducer][0]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
        @tir.tvm_thread_allreduce(1u32, T_divide_red.rf_1[0], True, reduce_temp0_1: Buffer(reduce_temp0, float32, [1], [], scope="local")[0], threadIdx.x, dtype=handle)
        if (threadIdx.x == 0) {
          T_divide_red_1: Buffer(T_divide_red, float32, [1], [], align=4)[0] = reduce_temp0_1[0]
        }
      }
    }

.. GENERATED FROM PYTHON SOURCE LINES 101-103

As you can see, the scheduled stages of the computation have been accumulated, and we can
examine them by

.. GENERATED FROM PYTHON SOURCE LINES 103-105

.. code-block:: default

    print(sg.stages)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [stage(a, placeholder(a, 0x230b05e0)), stage(b, placeholder(b, 0x2196e6e0)), stage(T_add, compute(T_add, body=[(a[ax0, ax1, ax2] + b[ax1, ax2])], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=broadcast, attrs={})), stage(T_multiply, compute(T_multiply, body=[(a[ax0, ax1, ax2]*b[ax1, ax2])], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=broadcast, attrs={})), stage(T_elemwise_sum, compute(T_elemwise_sum, body=[(T_add[ax0, ax1, ax2] + T_multiply[ax0, ax1, ax2])], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide, compute(T_divide, body=[(T_elemwise_sum[ax0, ax1, ax2]/2f)], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide_red.rf, compute(T_divide_red.rf, body=[reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]), source=[T_divide[floordiv(floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10), 10), floormod(floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10), 10), floormod((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10)]], init=[], axis=[iter_var(k0.k1.fused.k2.fused.outer, range(min=0, ext=10))], where=tir.likely((((floordiv(floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10), 10) < 100) && (floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10) < 1000)) && ((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)) < 10000))), value_index=0)], axis=[iter_var(k0.k1.fused.k2.fused.inner, range(min=0, ext=1024))], reduce_axis=[iter_var(k0.k1.fused.k2.fused.outer, range(min=0, ext=10))], tag=, attrs={})), stage(T_divide_red, compute(T_divide_red.repl, body=[reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]), source=[T_divide_red.rf[k0.k1.fused.k2.fused.inner.v]], init=[], axis=[iter_var(k0.k1.fused.k2.fused.inner.v, range(min=0, ext=1024))], where=(bool)1, value_index=0)], axis=[], reduce_axis=[iter_var(k0.k1.fused.k2.fused.inner.v, range(min=0, ext=1024))], tag=, attrs={}))]
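Each entry in :code:`sg.stages` is a :code:`te.Stage` that can still be inspected, or adjusted by
hand if needed. As a small illustration (a hypothetical sketch, not part of the generated output),
we could look up the stage that computes :code:`f = e / 2.0` and print its axes:

.. code-block:: python

    # Hypothetical: individual stages can be addressed through the schedule, e.g. by operation.
    div_stage = sg[f.op]                              # the stage computing f = e / 2.0
    print(div_stage.op.name)                          # "T_divide", as in the listing above
    print([iv.var.name for iv in div_stage.op.axis])  # ["ax0", "ax1", "ax2"]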
.. GENERATED FROM PYTHON SOURCE LINES 106-108

We can test the correctness by comparing with the :code:`numpy` result as follows

.. GENERATED FROM PYTHON SOURCE LINES 108-119

.. code-block:: default

    func = tvm.build(sg, [a, b, g], "cuda")
    dev = tvm.cuda(0)
    a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)
    b_np = np.random.uniform(size=(y, y)).astype(b.dtype)
    g_np = np.sum(np.add(a_np + b_np, a_np * b_np) / 2.0)
    a_nd = tvm.nd.array(a_np, dev)
    b_nd = tvm.nd.array(b_np, dev)
    g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), dev)
    func(a_nd, b_nd, g_nd)
    tvm.testing.assert_allclose(g_nd.numpy(), g_np, rtol=1e-5)

.. GENERATED FROM PYTHON SOURCE LINES 120-122

TOPI also provides common neural network operations such as *softmax* with an optimized schedule

.. GENERATED FROM PYTHON SOURCE LINES 122-128

.. code-block:: default

    tarray = te.placeholder((512, 512), name="tarray")
    softmax_topi = topi.nn.softmax(tarray)
    with tvm.target.Target("cuda"):
        sst = topi.cuda.schedule_softmax(softmax_topi)
        print(tvm.lower(sst, [tarray], simple_mode=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    @main = primfn(tarray_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {tarray: Buffer(tarray_2: Pointer(float32), float32, [262144], [])}
      buffer_map = {tarray_1: tarray}
      preflattened_buffer_map = {tarray_1: tarray_3: Buffer(tarray_2, float32, [512, 512], [])} {
      allocate(T_softmax_norm: Pointer(global float32x4), float32x4, [65536]), storage_scope = global;
      attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 512;
      allocate(normal_reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local;
      allocate(reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local;
      allocate(T_softmax_exp: Pointer(warp float32), float32, [512]), storage_scope = warp;
      allocate(normal_reduce_temp0_1: Pointer(local float32), float32, [1]), storage_scope = local;
      allocate(reduce_temp0_1: Pointer(local float32), float32, [1]), storage_scope = local {
        attr [IterVar(threadIdx.x: int32, [0:32], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
          normal_reduce_temp0_2: Buffer(normal_reduce_temp0, float32, [1], [], scope="local")[0] = -3.40282e+38f32
          for (k.inner: int32, 0, 16) {
            normal_reduce_temp0_2[0] = max(normal_reduce_temp0_2[0], tarray[(((blockIdx.x*512) + (threadIdx.x*16)) + k.inner)])
          }
          attr [meta[tir.CommReducer][0]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
          @tir.tvm_thread_allreduce(1u32, normal_reduce_temp0_2[0], True, reduce_temp0_2: Buffer(reduce_temp0, float32, [1], [], scope="local")[0], threadIdx.x, dtype=handle)
          for (i1.inner.outer: int32, 0, 4) {
            let cse_var_1: int32 = (i1.inner.outer*4)
            T_softmax_exp_1: Buffer(T_softmax_exp, float32, [512], [], scope="warp")[ramp(((threadIdx.x*16) + cse_var_1), 1, 4)] = @tir.exp((tarray[ramp((((blockIdx.x*512) + (threadIdx.x*16)) + cse_var_1), 1, 4)] - broadcast(reduce_temp0_3: Buffer(reduce_temp0, float32, [1], [], scope="local", align=4)[0], 4)), dtype=float32x4)
          }
        }
        attr [IterVar(threadIdx.x, [0:32], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
          normal_reduce_temp0_3: Buffer(normal_reduce_temp0_1, float32, [1], [], scope="local")[0] = 0f32
          for (k.inner_1: int32, 0, 16) {
            normal_reduce_temp0_3[0] = (normal_reduce_temp0_3[0] + T_softmax_exp_1[((threadIdx.x*16) + k.inner_1)])
          }
          attr [meta[tir.CommReducer][1]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
          @tir.tvm_thread_allreduce(1u32, normal_reduce_temp0_3[0], True, reduce_temp0_4: Buffer(reduce_temp0_1, float32, [1], [], scope="local")[0], threadIdx.x, dtype=handle)
          for (i1.inner.outer_1: int32, 0, 4) {
            T_softmax_norm_1:
Buffer(T_softmax_norm, float32x4, [65536], [])[(((blockIdx.x*128) + (threadIdx.x*4)) + i1.inner.outer_1)] = (T_softmax_exp_1[ramp(((threadIdx.x*16) + (i1.inner.outer_1*4)), 1, 4)] / broadcast(reduce_temp0_5: Buffer(reduce_temp0_1, float32, [1], [], scope="local", align=4)[0], 4)) } } } } .. GENERATED FROM PYTHON SOURCE LINES 129-140 Fusing convolutions ------------------- We can fuse :code:`topi.nn.conv2d` and :code:`topi.nn.relu` together. .. note:: TOPI functions are all generic functions. They have different implementations for different backends to optimize for performance. For each backend, it is necessary to call them under a target scope for both compute declaration and schedule. TVM will choose the right function to call with the target information. .. GENERATED FROM PYTHON SOURCE LINES 140-150 .. code-block:: default data = te.placeholder((1, 3, 224, 224)) kernel = te.placeholder((10, 3, 5, 5)) with tvm.target.Target("cuda"): conv = topi.cuda.conv2d_nchw(data, kernel, 1, 2, 1) out = topi.nn.relu(conv) sconv = topi.cuda.schedule_conv2d_nchw([out]) print(tvm.lower(sconv, [data, kernel], simple_mode=True)) .. rst-class:: sphx-glr-script-out .. code-block:: none @main = primfn(placeholder_2: handle, placeholder_3: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} buffers = {placeholder: Buffer(placeholder_4: Pointer(float32), float32, [150528], []), placeholder_1: Buffer(placeholder_5: Pointer(float32), float32, [750], [])} buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1} preflattened_buffer_map = {placeholder_2: placeholder_6: Buffer(placeholder_4, float32, [1, 3, 224, 224], []), placeholder_3: placeholder_7: Buffer(placeholder_5, float32, [10, 3, 5, 5], [])} { allocate(compute: Pointer(global float32), float32, [501760]), storage_scope = global; attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 5; allocate(conv2d_nchw: Pointer(local float32), float32, [14]), storage_scope = local; allocate(pad_temp.shared: Pointer(shared float32), float32, [112]), storage_scope = shared; allocate(placeholder.shared: Pointer(shared float32), float32, [2]), storage_scope = shared; attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 224; attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 2; attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { conv2d_nchw_1: Buffer(conv2d_nchw, float32, [4], [], scope="local", align=8)[0] = 0f32 conv2d_nchw_1[2] = 0f32 conv2d_nchw_1[4] = 0f32 conv2d_nchw_1[6] = 0f32 conv2d_nchw_1[8] = 0f32 conv2d_nchw_1[10] = 0f32 conv2d_nchw_1[12] = 0f32 conv2d_nchw_1[1] = 0f32 conv2d_nchw_1[3] = 0f32 conv2d_nchw_1[5] = 0f32 conv2d_nchw_1[7] = 0f32 conv2d_nchw_1[9] = 0f32 conv2d_nchw_1[11] = 0f32 conv2d_nchw_1[13] = 0f32 for (rc.outer: int32, 0, 3) { for (ry.outer: int32, 0, 5) { attr [IterVar(threadIdx.z_1: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared_1: Buffer(pad_temp.shared, float32, [112], [], 
scope="shared")[(threadIdx.x_1*7)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (1 <= ((blockIdx.x*56) + floordiv((threadIdx.x_1*7), 2)))), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 450)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (1 <= ((blockIdx.x*56) + floordiv(((threadIdx.x_1*7) + 1), 2)))), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 449)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 448)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared_1: Buffer(placeholder.shared, float32, [2], [], scope="shared", align=8)[threadIdx.x_2] = placeholder_1[((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5))] } conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0])) conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0])) conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0])) conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0])) conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0])) conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0])) conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0])) conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1])) conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1])) conv2d_nchw_1[5] = 
(conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1])) conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1])) conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1])) conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1])) conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1])) attr [IterVar(threadIdx.z_1, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared_1[(threadIdx.x_1*7)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (1 <= ((blockIdx.x*56) + floordiv(((threadIdx.x_1*7) + 1), 2)))), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 449)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 448)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_2, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared_1[threadIdx.x_2] = placeholder_1[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 1)] } conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0])) conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0])) conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0])) 
conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0])) conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0])) conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0])) conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0])) conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1])) conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1])) conv2d_nchw_1[5] = (conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1])) conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1])) conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1])) conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1])) conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1])) attr [IterVar(threadIdx.z_1, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared_1[(threadIdx.x_1*7)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 448)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 442)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr 
[IterVar(threadIdx.x_2, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared_1[threadIdx.x_2] = placeholder_1[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 2)] } conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0])) conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0])) conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0])) conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0])) conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0])) conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0])) conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0])) conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1])) conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1])) conv2d_nchw_1[5] = (conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1])) conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1])) conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1])) conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1])) conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1])) attr [IterVar(threadIdx.z_1, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared_1[(threadIdx.x_1*7)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + 
(blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 442)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (((blockIdx.x*56) + floordiv(((threadIdx.x_1*7) + 9), 2)) < 113)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 441)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_2, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared_1[threadIdx.x_2] = placeholder_1[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 3)] } conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0])) conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0])) conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0])) conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0])) conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0])) conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0])) conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0])) conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1])) conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1])) conv2d_nchw_1[5] = (conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1])) conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1])) conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1])) conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1])) conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1])) attr [IterVar(threadIdx.z_1, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_1, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16 { pad_temp.shared_1[(threadIdx.x_1*7)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + 
ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 442)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (((blockIdx.x*56) + floordiv(((threadIdx.x_1*7) + 9), 2)) < 113)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 441)], 0f32, dtype=float32) pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (((blockIdx.x*56) + floordiv((threadIdx.x_1*7), 2)) < 108)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 440)], 0f32, dtype=float32) } attr [IterVar(threadIdx.z_2, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1; attr [IterVar(threadIdx.y_2, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1; attr [IterVar(threadIdx.x_2, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16; if @tir.likely((threadIdx.x_2 < 2), dtype=bool) { placeholder.shared_1[threadIdx.x_2] = placeholder_1[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 4)] } conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0])) conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0])) conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0])) conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0])) conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0])) conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0])) conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0])) conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1])) conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1])) conv2d_nchw_1[5] = (conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1])) conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1])) conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1])) conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1])) conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1])) } } compute_1: Buffer(compute, float32, [501760], [])[((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x)] = max(conv2d_nchw_1[0], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 16)] = max(conv2d_nchw_1[2], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 32)] = max(conv2d_nchw_1[4], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) 
+ threadIdx.x) + 48)] = max(conv2d_nchw_1[6], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 64)] = max(conv2d_nchw_1[8], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 80)] = max(conv2d_nchw_1[10], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 96)] = max(conv2d_nchw_1[12], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50176)] = max(conv2d_nchw_1[1], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50192)] = max(conv2d_nchw_1[3], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50208)] = max(conv2d_nchw_1[5], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50224)] = max(conv2d_nchw_1[7], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50240)] = max(conv2d_nchw_1[9], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50256)] = max(conv2d_nchw_1[11], 0f32) compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50272)] = max(conv2d_nchw_1[13], 0f32) } }

.. GENERATED FROM PYTHON SOURCE LINES 151-157

Summary
-------
In this tutorial, we have seen

- How to use the TOPI API for common operations with numpy-style operators.
- How TOPI facilitates generic schedules and operator fusion for a given context, to generate optimized kernel code.

.. _sphx_glr_download_tutorial_intro_topi.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: intro_topi.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: intro_topi.ipynb `

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_