copy_async → dsmem
The dsmem variant lowers a copy_async whose source and destination are
both shared memory but in different CTAs of a cluster (distributed shared
memory). One elected thread on the source CTA maps the destination CTA’s shared
address into its own address space (PTX mapa) and issues a bulk copy
(cp.async.bulk.shared::cluster); when the bytes land, the hardware decrements the
pending byte count (tx-count) of the destination CTA’s mbarrier. Source:
python/tvm/backend/cuda/operator/tile_primitive/copy_async/dsmem.py.
What it accepts
Three predicates: a valid copy, a single-thread scope, and a shared → shared pair:
# register_dispatch(..., priority=10, when=[
    predicate("validate_copy_op", ...),
    predicate("single_thread", lambda op, sctx: (single_thread(op, sctx), "expected single thread")),
    predicate("is_shared_to_shared", lambda op, sctx: (_is_shared_to_shared(op), "not shared-to-shared")),
# ])

def _is_shared_to_shared(op_call):
    src_scope = op_call.src.buffer.scope()
    dst_scope = op_call.dst.buffer.scope()
    return src_scope.startswith("shared") and dst_scope.startswith("shared")
| Property | Requirement |
|---|---|
| target / priority | CUDA (per the source path) / priority 10 |
| scope | a single thread issues the copy (the source CTA elects one thread) |
| memory pair | both the source and destination buffer scopes start with "shared" |
| chunk size | the contiguous chunk must be ≥ 16 bytes and a multiple of 16 (otherwise the dispatch declines) |
| environment | a cluster launch (so a remote CTA’s shared memory exists), plus a caller mbarrier on the destination CTA |
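The predicate chain above can be illustrated outside TVM. The following is a minimal sketch, not the dispatcher's actual code: FakeBuffer, FakeRegion, FakeCopyOp, first_failure, and make_op are hypothetical stand-ins, and it assumes that predicates are evaluated in order and that the first failure supplies the decline reason. Only the scope test mirrors _is_shared_to_shared from the source.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class FakeBuffer:
    """Hypothetical stand-in for a buffer; only the scope string matters here."""
    scope_name: str

    def scope(self) -> str:
        return self.scope_name


@dataclass
class FakeRegion:
    buffer: FakeBuffer


@dataclass
class FakeCopyOp:
    src: FakeRegion
    dst: FakeRegion


def is_shared_to_shared(op: FakeCopyOp) -> bool:
    # Same test as _is_shared_to_shared: "shared" and "shared.dyn" both qualify.
    return (op.src.buffer.scope().startswith("shared")
            and op.dst.buffer.scope().startswith("shared"))


Predicate = Tuple[str, Callable[[FakeCopyOp], Tuple[bool, str]]]


def first_failure(op: FakeCopyOp, predicates: List[Predicate]) -> Optional[Tuple[str, str]]:
    """Return None if every predicate passes, else (name, reason) of the first failure."""
    for name, check in predicates:
        ok, reason = check(op)
        if not ok:
            return name, reason
    return None


# Assumption: only the scope predicate is modelled; the other two are omitted.
PREDICATES: List[Predicate] = [
    ("is_shared_to_shared",
     lambda op: (is_shared_to_shared(op), "not shared-to-shared")),
]


def make_op(src_scope: str, dst_scope: str) -> FakeCopyOp:
    return FakeCopyOp(FakeRegion(FakeBuffer(src_scope)),
                      FakeRegion(FakeBuffer(dst_scope)))


if __name__ == "__main__":
    print(first_failure(make_op("shared.dyn", "shared.dyn"), PREDICATES))
    print(first_failure(make_op("global", "shared.dyn"), PREDICATES))
```

Because the test is a prefix match, a static "shared" buffer and a dynamic "shared.dyn" buffer are treated alike, while a global source is rejected with the reason string attached to the predicate.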
Demonstration program
A 2-CTA cluster: CTA 0 stages a 128×64 float16 tile global → its shared,
then bulk-copies it into CTA 1’s shared via dsmem; CTA 1 waits on the
mbarrier and writes the result out (from test_dsmem.py):
from tvm.tirx.lang.pipeline import MBarrier

shape, dtype, CLUSTER_N = (128, 64), "float16", 2
src_layout = dst_layout = TileLayout(S[128, 64])
copy_bytes = 128 * 64 * 2
r = (slice(0, 128), slice(0, 64))

@T.prim_func
def dsmem_copy(A_ptr: T.handle, B_ptr: T.handle):
    A = T.match_buffer(A_ptr, shape, dtype); B = T.match_buffer(B_ptr, shape, dtype)
    T.device_entry()
    cbx = T.cta_id_in_cluster([CLUSTER_N]); T.cta_id([CLUSTER_N]); tid = T.thread_id([1])
    pool = T.SMEMPool()
    src_smem = T.decl_buffer(list(shape), dtype, pool.alloc([8192], dtype, align=128).data,
                             elem_offset=0, scope="shared.dyn", layout=src_layout)
    dst_smem = T.decl_buffer(list(shape), dtype, pool.alloc([8192], dtype, align=128).data,
                             elem_offset=0, scope="shared.dyn", layout=dst_layout)
    mbar = MBarrier(pool, 1); pool.commit()
    mbar.init(1); T.ptx.fence.mbarrier_init(); T.cuda.cluster_sync()
    if tid == 0:
        if cbx == 0:  # source CTA
            Tx.copy(src_smem[r], A[r])  # global -> local shared
            T.ptx.fence.proxy_async("shared::cta")
            Tx.copy_async(dst_smem[r], src_smem[r], dispatch="dsmem",
                          mbar=mbar.ptr_to([0]), remote_cta_id=T.int32(1))  # -> CTA 1
        else:  # destination CTA
            T.ptx.mbarrier.arrive.expect_tx(mbar.ptr_to([0]), copy_bytes)
            mbar.wait(0, 0)
            Tx.copy(B[r], dst_smem[r])  # remote shared -> global
    T.cuda.cluster_sync()
Algorithm
1. Find the contiguous chunk. The dispatch slices and groups both layouts to the
copy region, walks inward to the longest matching contiguous stride-1 shard chain,
and multiplies those extents into chunk_elements; chunk_bytes must be ≥ 16
and a multiple of 16 (a cp.async.bulk constraint), else it declines:
chunk_bytes = chunk_elements * dtype_bytes
if chunk_bytes < 16 or chunk_bytes % 16 != 0:
    fail(...)
2. Map the remote address. map_shared_rank (PTX mapa) translates a local
shared pointer into the destination CTA’s window — applied to both the destination
buffer pointer and the mbarrier:
remote_mbar = T.ptx.map_shared_rank(mbar, remote_cta_id)
cluster_dst = T.ptx.map_shared_rank(dst_buf.ptr_to(dst_st), remote_cta_id)
3. Issue one bulk copy per chunk. Fully contiguous → a single instruction; a strided region loops over the outer (non-contiguous) extents, re-deriving the chunk’s offsets each step:
if not outer_extents:  # one contiguous chunk
    T.ptx.cp_async.bulk.s2c(cluster_dst, src_buf.ptr_to(src_st), chunk_bytes, remote_mbar)
else:
    for loop_vars in T.grid(*outer_extents):  # one chunk per outer coord
        ...  # re-decl src/dst views at the per-chunk offset
        T.ptx.cp_async.bulk.s2c(cluster_dst, src_ptr, chunk_bytes, remote_mbar)
The complete_tx::bytes form makes the hardware decrement the pending byte count
(tx-count) of remote_mbar by chunk_bytes each time a chunk completes. The dispatch
emits no wait: the caller arms the mbarrier with the total byte count
(arrive.expect_tx) and waits on it.
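Step 1 can be sketched in plain Python for simple strided layouts. This is an illustration under stated assumptions, not the dispatch's implementation: each layout is reduced to per-dimension extents and element strides (the real code works on sliced and grouped TileLayout shards), and contiguous_chunk and checked_chunk_bytes are hypothetical helper names.

```python
from typing import List, Optional, Sequence, Tuple


def contiguous_chunk(extents: Sequence[int],
                     src_strides: Sequence[int],
                     dst_strides: Sequence[int]) -> Tuple[int, List[int]]:
    """Walk inward (innermost dimension first) while BOTH layouts stay dense.

    Returns (chunk_elements, outer_extents). A dimension joins the chunk only
    if its stride on both sides equals the number of elements gathered so far,
    i.e. it extends a stride-1 contiguous run.
    """
    chunk_elements = 1
    split = len(extents)
    for dim in reversed(range(len(extents))):
        if src_strides[dim] != chunk_elements or dst_strides[dim] != chunk_elements:
            break
        chunk_elements *= extents[dim]
        split = dim
    return chunk_elements, list(extents[:split])


def checked_chunk_bytes(chunk_elements: int, dtype_bytes: int) -> Optional[int]:
    """Apply the 16-byte rule; None models the dispatch declining."""
    chunk_bytes = chunk_elements * dtype_bytes
    if chunk_bytes < 16 or chunk_bytes % 16 != 0:
        return None
    return chunk_bytes


if __name__ == "__main__":
    FP16 = 2  # bytes per float16 element
    cases = {
        "row-major both sides": ((128, 64), (64, 1), (64, 1)),
        "padded source rows": ((128, 64), (72, 1), (64, 1)),
        "row-major vs column-major": ((128, 64), (64, 1), (1, 128)),
    }
    for name, (extents, src, dst) in cases.items():
        elems, outer = contiguous_chunk(extents, src, dst)
        print(name, elems, outer, checked_chunk_bytes(elems, FP16))
```

With matching row-major layouts the whole 128×64 tile forms one chunk of 8192 elements (16384 bytes) and no outer extents remain. Padding the source rows to a stride of 72 stops the walk after the inner dimension, leaving 64-element (128-byte) chunks and an outer extent of 128. A row-major source against a column-major destination yields a 1-element chunk, 2 bytes for float16, which fails the 16-byte rule.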
Generated TIRx IR
The fully contiguous 128×64 fp16 tile (16384 bytes) is a single chunk:
T.ptx.cp_async.bulk.s2c(cluster_dst[0], src_ptr[0], 16384, remote_mbar[0])
Generated CUDA
// map local shared addresses into CTA 1's window (mapa)
remote_mbar = tvm_builtin_ptx_mapa_u64(&mbar, /*rank=*/1); // asm: mapa.u64
cluster_dst = tvm_builtin_ptx_mapa_u64(&dst_smem, /*rank=*/1);
// bulk-copy 16384 bytes local shared -> CTA 1 shared, signalling its mbarrier
"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes ..."
One thread on CTA 0 launches the whole 16 KB transfer; CTA 1’s mbarrier fires when it lands.
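The handshake between the caller's arrive.expect_tx and the hardware's complete_tx can be modelled with a small counter class. This is a toy model under stated assumptions, not TVM or CUDA code: ToyMBarrier and its method names are hypothetical, and it encodes only the rule that a phase completes once the pending arrival count and the pending byte count (tx-count) have both reached zero.

```python
class ToyMBarrier:
    """Toy model of an mbarrier with an arrival count and a byte (tx) count."""

    def __init__(self, arrival_count: int) -> None:
        self.arrival_count = arrival_count
        self.pending = arrival_count  # arrivals still expected in this phase
        self.tx_count = 0             # bytes still expected in this phase
        self.phase = 0                # number of completed phases

    def _maybe_complete(self) -> None:
        # A phase completes only when both counters are at zero.
        if self.pending == 0 and self.tx_count == 0:
            self.phase += 1
            self.pending = self.arrival_count

    def arrive_expect_tx(self, nbytes: int) -> None:
        """Caller side: announce nbytes of incoming data, then arrive once."""
        self.tx_count += nbytes
        self.pending -= 1
        self._maybe_complete()

    def complete_tx(self, nbytes: int) -> None:
        """Hardware side: nbytes of a bulk copy have landed."""
        self.tx_count -= nbytes
        self._maybe_complete()

    def phase_done(self, parity: int) -> bool:
        """True once the phase with the given parity has completed."""
        return self.phase % 2 != parity


if __name__ == "__main__":
    bar = ToyMBarrier(1)           # mirrors mbar.init(1) in the demonstration
    bar.arrive_expect_tx(16384)    # CTA 1 arms the barrier with copy_bytes
    print(bar.phase_done(0))       # False: the bytes have not landed yet
    bar.complete_tx(16384)         # the bulk copy from CTA 0 completes
    print(bar.phase_done(0))       # True: the waiter may proceed
```

In this model the order of the two calls does not matter: if the bytes land first, the byte count goes negative and the phase still waits for the arrival. A strided copy behaves the same way as long as the armed byte total equals the sum of all chunk sizes.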
How inputs change the algorithm
| input | effect |
|---|---|
| layout contiguity | fully contiguous (matching row-major both sides) → one cp.async.bulk instruction |
| dtype / chunk size | sets chunk_bytes (chunk_elements * dtype_bytes) |
| … | the … |
| incompatible layouts | e.g. row-major source vs column-major destination → no matching contiguous chain → the dispatch declines |
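The effect of layout contiguity on the number of bulk copies can be made concrete by enumerating the per-chunk offsets. The sketch below is an assumption-laden illustration rather than the generated code: chunk_offsets is a hypothetical helper, offsets are in elements, and the outer strides are supplied directly instead of being derived from a TileLayout.

```python
from itertools import product
from typing import Iterator, Sequence, Tuple


def chunk_offsets(outer_extents: Sequence[int],
                  src_outer_strides: Sequence[int],
                  dst_outer_strides: Sequence[int]) -> Iterator[Tuple[int, int]]:
    """Yield one (src_offset, dst_offset) pair per bulk copy."""
    if not outer_extents:
        # Fully contiguous region: a single copy starting at offset 0.
        yield (0, 0)
        return
    # Strided region: one copy per coordinate of the outer (non-contiguous) grid.
    for coord in product(*(range(extent) for extent in outer_extents)):
        src_off = sum(c * s for c, s in zip(coord, src_outer_strides))
        dst_off = sum(c * s for c, s in zip(coord, dst_outer_strides))
        yield (src_off, dst_off)


if __name__ == "__main__":
    # Contiguous 128x64 tile: one instruction moving all 16384 bytes.
    print(list(chunk_offsets([], [], [])))

    # Source rows padded to a stride of 72 elements, dense destination rows:
    # 128 copies of 64 float16 elements (128 bytes) each.
    copies = list(chunk_offsets([128], [72], [64]))
    print(len(copies), copies[:3], len(copies) * 128)
```

The contiguous case produces a single pair, matching the one-instruction IR shown earlier, while the padded case produces 128 pairs whose chunk sizes again total 16384 bytes.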