elementwise → smem

The smem variant lowers an elementwise op (sqrt, exp, add, fma, …) when all operands are in shared memory. Like the copy → gmem_smem variant, it synthesizes an [outer, threads, vec] partition from the execution scope, then applies the op to each (vectorized) element. Source: python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py.

What it accepts

is_smem_ewise(spec) builds the predicate:

def check(op_call, sctx):
    if not sctx.is_target("cuda"): return False, "non-cuda target"
    if sctx.scope_kind not in ("thread", "warp", "warpgroup", "cta"): ...
    ok, reason = _all_threads_active(sctx)              # full scope
    plan, msg = spec.parse(op_call)                     # parse the op's operands
    for br in buffer_regions(plan):
        if not br.buffer.scope().startswith("shared"):  # every operand shared*
            return False, f"operand scope {br.buffer.scope()} != shared*"
        if br.buffer.layout is None: ...
    # + spec.check_extras (dtype rules) and anchor-layout validation

| Property | Requirement |
| --- | --- |
| target / scope / priority | cuda; thread / warp / warpgroup / cta (all active); priority 10 |
| operands | every operand (inputs and output) in shared* |
| op | any op in the registry (unary sqrt/exp/zero…, binary add/mul…, fma); spec.check_extras validates the dtype combo |
| layout | operands have layouts; the layout sets the vector width (the partition itself is synthesized from the scope’s thread count, not the layout) |
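The acceptance rules above can be exercised in isolation with a small mock. This is a minimal sketch, not the real predicate: FakeBuffer, FakeContext, and accepts are hypothetical stand-ins invented for illustration, the rejection messages for the scope, activity, and layout checks are placeholders (the source elides the real ones), and the dtype and anchor-layout checks are omitted.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

# Hypothetical stand-ins for illustration only; these are NOT the real TVM classes.
@dataclass
class FakeBuffer:
    scope: str                          # e.g. "shared", "shared.dyn", "global"
    layout: Optional[str] = "tile"      # None models an operand without a layout

@dataclass
class FakeContext:
    target: str                         # e.g. "cuda"
    scope_kind: str                     # "thread", "warp", "warpgroup", or "cta"
    all_threads_active: bool = True

ALLOWED_SCOPES = ("thread", "warp", "warpgroup", "cta")

def accepts(operands: Sequence[FakeBuffer], sctx: FakeContext) -> Tuple[bool, str]:
    """Mirror the order of checks in the table: target, scope, operands, layout."""
    if sctx.target != "cuda":
        return False, "non-cuda target"
    if sctx.scope_kind not in ALLOWED_SCOPES:
        return False, "placeholder: unsupported scope"
    if not sctx.all_threads_active:
        return False, "placeholder: scope not fully active"
    for buf in operands:
        # Every operand, inputs and output alike, must live in shared*.
        if not buf.scope.startswith("shared"):
            return False, f"operand scope {buf.scope} != shared*"
        if buf.layout is None:
            return False, "placeholder: operand has no layout"
    return True, "ok"

cta = FakeContext(target="cuda", scope_kind="cta")
print(accepts([FakeBuffer("shared"), FakeBuffer("shared.dyn")], cta))
print(accepts([FakeBuffer("shared"), FakeBuffer("global")], cta))
```

The second call is rejected because one operand is in global memory, which is the case the copy variants would be expected to handle instead.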

Demonstration program

A CTA takes the elementwise sqrt of a 32×32 float32 shared tile (adapted from test_unary.py — here a 256-thread CTA, so the partition is one round):

s_layout = TileLayout(S[(32, 32)]); full = (slice(0, 32), slice(0, 32))

@T.prim_func
def unary_op(A_ptr: T.handle):
    A = T.match_buffer(A_ptr, (32, 32), "float32", layout=s_layout)
    T.device_entry(); T.cta_id([1]); T.warp_id([8]); T.lane_id([32]); T.thread_id([256])
    A_smem = T.alloc_buffer((32, 32), "float32", scope="shared", layout=s_layout)
    Tx.cta.copy(A_smem[full], A[full])
    Tx.cta.sqrt(A_smem[full], A_smem[full])   # elementwise smem dispatch
    Tx.cta.copy(A[full], A_smem[full])

Algorithm

1. Parse the op and check operands. spec.parse turns the call into a plan (inputs, output, the op); the predicate confirms every operand is shared.

2. Synthesize the partition from the scope’s thread count (as copy → gmem_smem does): split the region into [outer, threads, vec], with the vector width taken from the layout’s innermost contiguous run. For 32×32 = 1024 float32 elements over 256 threads, vec = 4 and outer = 1 (1024 / (256 × 4)).

3. Apply the op per element. Instead of a copy, each (thread, round) reads its vec elements, applies the op, and writes back — vectorized:

Generated TIRx IR

for f in range(1):                                # outer = 1
    A_smem[tid * 4 + vec] = T.sqrt(A_smem[tid * 4 + vec])

Generated CUDA

The vec = 4 element bundle becomes a float4 and the op is applied per component:

float4 v_ = *(float4*)(&A_smem_ptr[tid * 4]);
__1.x = sqrtf(v_.x);  __1.y = sqrtf(v_.y);
__1.z = sqrtf(v_.z);  __1.w = sqrtf(v_.w);

(Verified on sm_100a — the tile equals sqrt(A).)
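The partition and per-element application can also be emulated on the host to check the arithmetic. The following is a minimal sketch under stated assumptions: the tile is treated as a flat row-major list, the helper names partition and apply_elementwise are hypothetical, and the index order for multi-round cases ((f * threads + tid) * vec) is assumed rather than taken from the lowering.

```python
import math

def partition(num_elems, num_threads, vec):
    """Split a flat region into [outer, threads, vec]; require an exact fit."""
    per_round = num_threads * vec
    if num_elems % per_round != 0:
        raise ValueError("region does not divide evenly into rounds")
    return num_elems // per_round, num_threads, vec

def apply_elementwise(tile, op, num_threads, vec):
    """Sequentially emulate what each (round, thread) pair would do."""
    outer, threads, vec = partition(len(tile), num_threads, vec)
    out = list(tile)
    for f in range(outer):                      # rounds
        for tid in range(threads):              # emulated thread index
            base = (f * threads + tid) * vec    # assumed index order
            for v in range(vec):                # one vector bundle
                out[base + v] = op(tile[base + v])
    return out

tile = [float(i) for i in range(32 * 32)]       # flattened 32x32 tile
result = apply_elementwise(tile, math.sqrt, num_threads=256, vec=4)
print(partition(len(tile), 256, 4))             # (outer, threads, vec)
```

With 1024 elements, 256 threads, and vec = 4, the partition is a single round, matching the range(1) loop in the generated IR.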

How inputs change the algorithm

| input | effect |
| --- | --- |
| op | unary → sqrtf / expf / … per component; binary → the two inputs combined (a + b); fma → a * b + c |
| dtype | sets the vector width (vec = widest aligned), which determines the round count, as in copy → gmem_smem |
| scope | sets the thread axis and count, hence the synthesized partition |
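To make the dtype row concrete, the sketch below shows how element width could change the vector width and round count. It rests on assumptions: that the widest vector access is 128 bits (the size of a CUDA float4), and that vec is shrunk until it divides the layout's contiguous run. The helper names vector_width and round_count are hypothetical, and the actual selection logic in the lowering may differ.

```python
MAX_VECTOR_BITS = 128  # assumption: cap at a 128-bit access such as float4

def vector_width(dtype_bits, contiguous_elems):
    """Widest power-of-two element count that fits 128 bits and divides the run."""
    vec = MAX_VECTOR_BITS // dtype_bits
    while vec > 1 and contiguous_elems % vec != 0:
        vec //= 2
    return vec

def round_count(num_elems, num_threads, vec):
    """Rounds needed when every thread handles one vec bundle per round."""
    per_round = num_threads * vec
    return -(-num_elems // per_round)   # ceiling division

# A 64x64 tile (4096 elements, 64 contiguous per row) over a 256-thread CTA.
for name, bits in (("float64", 64), ("float32", 32), ("float16", 16)):
    vec = vector_width(bits, contiguous_elems=64)
    print(name, "vec =", vec, "rounds =", round_count(64 * 64, 256, vec))
```

Under these assumptions a narrower dtype packs more elements into each access, so the same tile needs fewer rounds.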