elementwise → reg
The reg variant lowers an elementwise op (sqrt, exp, add,
fma, …) when all operands are register (local) buffers. Like the
copy → reg variant, the partition is induced by the operands’ register
layout: the thread axes are dropped, leaving each thread its private bundle, and
the op is applied to every register in that bundle. Source:
python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py.
What it accepts
is_reg_ewise(spec) builds the predicate:

    def check(op_call, sctx):
        if not sctx.is_target("cuda"): return False, "non-cuda target"
        if sctx.scope_kind not in ("thread", "warp", "warpgroup", "cta"): ...
        ok, reason = _all_threads_active(sctx)
        plan, msg = spec.parse(op_call)
        for br in buffer_regions(plan):
            if br.buffer.scope() != "local":  # every operand must be a register buffer
                return False, f"operand scope {br.buffer.scope()} != local"
            if br.buffer.layout is None: ...
        # + spec.check_extras (dtype rules), pick_anchor + _validate_anchor_layout,
        # _validate_scope_level_anchor
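The predicate follows a reject-with-reason pattern: each check returns early with `(False, reason)`, and only a call that passes every check is accepted. The sketch below imitates that pattern in plain Python. `FakeBuffer`, `FakeContext` and `check_reg_operands` are hypothetical stand-ins invented for illustration, not TVM classes, and the reason strings for the scope-kind and missing-layout rejections are assumptions because the excerpt elides them. Only the target, scope-kind and operand checks are modelled.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FakeBuffer:
    """Hypothetical stand-in for a buffer: a name, a storage scope and a layout."""
    name: str
    scope: str
    layout: Optional[str] = None


@dataclass
class FakeContext:
    """Hypothetical stand-in for the dispatch context."""
    target: str
    scope_kind: str


ALLOWED_SCOPES = ("thread", "warp", "warpgroup", "cta")


def check_reg_operands(operands: List[FakeBuffer], ctx: FakeContext) -> Tuple[bool, str]:
    """Return (accepted, reason), rejecting on the first failed check."""
    if ctx.target != "cuda":
        return False, "non-cuda target"
    if ctx.scope_kind not in ALLOWED_SCOPES:
        # Reason text is an assumption; the excerpt elides it.
        return False, f"unsupported scope kind {ctx.scope_kind}"
    for buf in operands:
        if buf.scope != "local":
            return False, f"operand scope {buf.scope} != local"
        if buf.layout is None:
            # Reason text is an assumption; the excerpt elides it.
            return False, f"operand {buf.name} has no layout"
    return True, "ok"
```

A single shared-memory operand is enough to reject the call, which is what keeps this variant restricted to the all-register case.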
| Property | Requirement |
|---|---|
| target / scope / priority | cuda target; scope kind one of thread, warp, warpgroup, cta; … |
| operands | every operand in local scope (a register buffer) with a layout |
| op | any registry op (unary / binary / …) |
| register layout | the anchor register layout must validate, and its thread axis must match the scope (it induces the partition) |
Demonstration program
A warp takes the elementwise sqrt of a 32×8 float32 register tile
(register layout S[(32,8):(1@laneid,1)], so lane i owns row i):

    from tvm.tirx.layout import S, TileLayout, laneid

    r_layout = TileLayout(S[(32, 8) : (1 @ laneid, 1)])
    fs = (slice(0, 32), slice(0, 8))  # full-tile slice

    @T.prim_func
    def k(A_ptr: T.handle, B_ptr: T.handle):
        A = T.match_buffer(A_ptr, (32, 8), "float32")
        B = T.match_buffer(B_ptr, (32, 8), "float32")
        T.device_entry(); T.cta_id([1]); T.lane_id([32]); tid = T.thread_id([32])
        A_smem = T.alloc_buffer((32, 8), "float32", scope="shared", layout=TileLayout(S[(32, 8)]))
        Tx.warp.copy(A_smem[fs], A[fs]); T.cuda.cta_sync()  # global -> shared
        R = T.alloc_buffer((32, 8), "float32", scope="local", layout=r_layout)
        Tx.warp.copy(R[fs], A_smem[fs])                      # shared -> registers
        Tx.warp.sqrt(R[fs], R[fs])                           # elementwise reg dispatch
        Tx.warp.copy(A_smem[fs], R[fs]); T.cuda.cta_sync()  # registers -> shared
        Tx.warp.copy(B[fs], A_smem[fs])                      # shared -> global
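To make the register layout concrete, the following plain-Python sketch maps a tile element to the lane that owns it and to its position in that lane's bundle. It assumes that S[(32,8):(1@laneid,1)] means "axis 0 strides by 1 over laneid, axis 1 strides by 1 over the lane's own registers", which is an interpretation based on the "lane i owns row i" description rather than on the layout library itself. `owner` is a hypothetical helper.

```python
ROWS, COLS = 32, 8              # tile shape from the demonstration program
LANE_STRIDE, REG_STRIDE = 1, 1  # strides read off S[(32,8):(1@laneid,1)] (assumed meaning)


def owner(i: int, j: int) -> tuple:
    """Map tile element (i, j) to (lane, register index) under the assumed layout."""
    if not (0 <= i < ROWS and 0 <= j < COLS):
        raise IndexError((i, j))
    lane = i * LANE_STRIDE  # axis 0 is placed on the laneid thread axis
    reg = j * REG_STRIDE    # axis 1 stays in the lane's private registers
    return lane, reg
```

Under this reading, `owner(5, 3)` is lane 5, register 3, and every lane ends up with exactly 8 registers.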
Algorithm
1. Parse and check. spec.parse builds the op plan; the predicate confirms
every operand is a register buffer and the anchor register layout is valid.
2. Induce the partition from the anchor’s thread axis (laneid here): drop
the thread iters, leaving each thread its private bundle (8 elements per lane).
3. Apply the op per register — a per-thread loop over the bundle:
Generated TIRx IR

    buffer[f] = T.sqrt(buffer_1[f])  # over each register f in the lane's bundle

Generated CUDA

    r_local_ptr[f_2] = sqrtf(r_local_ptr[f_2]);  // per-register, private to the lane
(Verified on sm_100a — the result equals sqrt(A).)
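The three algorithm steps can be mimicked on the host to see what each lane computes. This is a simulation in plain Python, not the lowering code: `partition_by_lane` and `apply_per_register` are hypothetical names, and the sketch assumes the demonstration layout in which lane i holds row i of the tile.

```python
import math

ROWS, COLS = 32, 8  # tile shape; axis 0 is the laneid axis in the demo layout


def partition_by_lane(tile):
    """Drop the thread axis: lane i keeps a copy of row i as its private bundle."""
    return {lane: list(tile[lane]) for lane in range(ROWS)}


def apply_per_register(bundles, op):
    """Per-thread loop: each lane applies op to every register it owns."""
    for regs in bundles.values():
        for f in range(len(regs)):
            regs[f] = op(regs[f])  # mirrors buffer[f] = T.sqrt(buffer_1[f])
    return bundles


# Example input: element (i, j) holds the value i * COLS + j.
tile = [[float(i * COLS + j) for j in range(COLS)] for i in range(ROWS)]
bundles = apply_per_register(partition_by_lane(tile), math.sqrt)
```

No lane reads or writes another lane's registers, so the simulated loops are independent, which matches the "private to the lane" comment in the generated CUDA.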
How inputs change the algorithm
| input | effect |
|---|---|
| op | unary → … |
| dtype | the register element type (…) |
| register layout | the anchor’s thread axis sets the partition; a wider per-lane bundle means a longer loop |
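The effect of the register layout on loop length can be estimated with a small helper. This sketch assumes the simplified case in which each tile axis is either placed entirely on a thread axis or kept entirely in registers; `per_thread_trip_count` is a hypothetical function, not part of the backend.

```python
from math import prod


def per_thread_trip_count(shape, thread_axes):
    """Product of the extents that are not placed on a thread axis."""
    return prod(ext for ax, ext in enumerate(shape) if ax not in thread_axes)


demo = per_thread_trip_count((32, 8), {0})    # 8 registers per lane
wider = per_thread_trip_count((32, 16), {0})  # 16 registers per lane
```

Doubling the register axis from 8 to 16 doubles the per-lane loop, while the number of lanes taking part stays at 32.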