Tile Primitives
Note
This page documents the tile-primitive surface and dispatch as it exists in the source today; signatures and variants may change.
Tile primitives are the dispatchable, hardware-level operations a TIRx kernel
issues — data movement (copy, copy_async), matrix multiply (gemm,
gemm_async), reductions, elementwise math, and a few fused/compose forms.
A primitive call is recorded as an unresolved TilePrimitiveCall IR node;
the compiler later dispatches it — selecting a concrete lowering from the
primitive, the execution scope, the operand layouts, the target, and an optional
explicit hint — and replaces it with native IR (loops, address arithmetic,
synchronization, and backend intrinsics).
Calling convention
Tile primitives are called from TVMScript (with from tvm.script import tirx as T) on the injected Tx namespace. The namespace prefix selects the cooperation
scope:
- Tx.<name>(...): unqualified, runs at thread scope.
- Tx.warp.<name> / Tx.wg.<name> (alias Tx.warpgroup) / Tx.cta.<name> / Tx.cluster.<name> / Tx.thread.<name>: bind an explicit scope (Tx.thread.<name> names thread scope explicitly; the others bind a wider one).
Besides its operands, every primitive accepts scope (usually set by the
namespace), workspace: dict[str, Buffer] | None, dispatch: str | None
(forces a named lowering variant), and **kwargs, which are collected into a config
dict that tunes the chosen lowering. Operands are Buffer / BufferRegion
values, each carrying a TileLayout that dispatch reads.
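The following is a minimal pure-Python mock of the calling convention. It is not the real TVMScript builder. It shows how a scope prefix and the keyword tail could be folded into one recorded call.

- RecordedCall, ScopedNamespace, the scope strings, and the some_variant/hint names are all hypothetical.
- Real primitives also take named parameters, such as gemm's transpose_A, which a real builder would consume first. The mock instead routes every extra keyword into config.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


@dataclass
class RecordedCall:
    """Hypothetical stand-in for an unresolved TilePrimitiveCall node."""
    op: str
    operands: Tuple[Any, ...]
    scope: str
    workspace: Optional[Dict[str, Any]] = None
    dispatch: Optional[str] = None
    config: Dict[str, Any] = field(default_factory=dict)


class ScopedNamespace:
    """Mock of the Tx namespace: a scope prefix returns a re-scoped namespace."""

    _SCOPE_NAMES = {"thread": "thread", "warp": "warp", "wg": "wg",
                    "warpgroup": "wg", "cta": "cta", "cluster": "cluster"}

    def __init__(self, scope: str = "thread") -> None:
        self._scope = scope  # unqualified calls default to thread scope

    def __getattr__(self, name: str) -> Any:
        if name in self._SCOPE_NAMES:
            return ScopedNamespace(self._SCOPE_NAMES[name])

        def builder(*operands: Any, scope: Optional[str] = None,
                    workspace: Optional[Dict[str, Any]] = None,
                    dispatch: Optional[str] = None, **kwargs: Any) -> RecordedCall:
            return RecordedCall(
                op=f"tirx.tile.{name}",
                operands=operands,
                scope=scope or self._scope,   # an explicit scope= wins over the prefix
                workspace=workspace,
                dispatch=dispatch,            # forced variant name, or None
                config=dict(kwargs),          # every other keyword lands in config
            )

        return builder


Tx = ScopedNamespace()
call = Tx.warpgroup.copy("dst_tile", "src_tile", dispatch="some_variant", hint=4)
print(call.scope, call.op, call.config)  # wg tirx.tile.copy {'hint': 4}
```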
Wiring (three layers): the authoritative op list is the C++ registry
(src/tirx/op/tirx.cc, 29 ops named tirx.tile.<name>); the IR wrapper
classes are in python/tvm/tirx/operator/tile_primitive/ops.py; the
user-facing Tx.* builders are in python/tvm/tirx/script/builder/tirx.py.
Primitive catalog
The 29 primitives, grouped. Signatures show the operands plus the common
workspace/dispatch/scope/**kwargs tail (abbreviated ...).
Data movement
copy(dst, src, ...) # synchronous element copy src -> dst
copy_async(dst, src, ...) # asynchronous copy (caller commits/waits)
permute_layout(dst, src, ...) # rearrange under a different layout (may alias)
Matrix multiply
gemm(D, A, B, C, transpose_A=False, transpose_B=False,
alpha=1.0, beta=0.0, ...) # D = alpha*A*B + beta*C (register mma)
gemm_async(C, A, B, SFA=None, SFB=None,
transA=False, transB=False, accum=False, ...) # async / block-scaled
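The formula in the gemm comment can be pinned down with a small plain-Python reference on nested lists. This is a sketch of the math only. It assumes transpose_A/transpose_B transpose the operand before the product. It says nothing about how a lowering tiles the work or which MMA instructions it emits. gemm_reference is a hypothetical name.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def gemm_reference(A: Matrix, B: Matrix, C: Matrix,
                   transpose_A: bool = False, transpose_B: bool = False,
                   alpha: float = 1.0, beta: float = 0.0) -> Matrix:
    """D = alpha * op(A) @ op(B) + beta * C, where op() optionally transposes."""
    a = transpose(A) if transpose_A else A
    b = transpose(B) if transpose_B else B
    m, k = len(a), len(a[0])
    if len(b) != k:
        raise ValueError("inner dimensions of A and B do not match")
    n = len(b[0])
    return [
        [alpha * sum(a[i][p] * b[p][j] for p in range(k)) + beta * C[i][j]
         for j in range(n)]
        for i in range(m)
    ]
```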
Fill / memset / zero
fill(dst, value, ...) # fill region with a scalar
memset(dst, value, ...) # set all elements to a value
zero(dst, src=None, ...) # zero out (in place if src omitted)
Cast and elementwise
cast(dst, src=None, ...) # dtype cast (buffer form)
sqrt / exp / exp2(dst, src=None, bias=None, scale=None, ...)
reciprocal(dst, src=None, ...) # dst = 1/src
silu(dst, src, ...) # dst = src*sigmoid(src)
add / sub / mul / fdiv(dst, src1, src2, ...) # element-wise arithmetic
maximum / minimum(dst, src1, src2, ...) # element-wise max / min
fma(dst, src, scale, bias, ...) # dst = src*scale + bias
select(dst, true_value, false_value, pred, scope=None) # dst = pred ? t : f
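The elementwise formulas in the comments above can be checked against a plain-Python reference. This sketch assumes every operand is a flat sequence of the same length. The real primitives work on tiled regions and may accept scalar scale/bias or broadcasting, which is not modeled here. The *_ref names are hypothetical.

```python
import math
from typing import List, Sequence


def fma_ref(src: Sequence[float], scale: Sequence[float],
            bias: Sequence[float]) -> List[float]:
    # dst = src * scale + bias, element by element
    return [s * k + b for s, k, b in zip(src, scale, bias)]


def select_ref(true_value: Sequence[float], false_value: Sequence[float],
               pred: Sequence[bool]) -> List[float]:
    # dst = pred ? true_value : false_value
    return [t if p else f for t, f, p in zip(true_value, false_value, pred)]


def silu_ref(src: Sequence[float]) -> List[float]:
    # dst = src * sigmoid(src); branch on sign so exp() never overflows
    out = []
    for x in src:
        if x >= 0:
            out.append(x / (1.0 + math.exp(-x)))
        else:
            e = math.exp(x)
            out.append(x * e / (1.0 + e))
    return out


def reciprocal_ref(src: Sequence[float]) -> List[float]:
    # dst = 1 / src
    return [1.0 / x for x in src]
```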
Reductions
sum / max / min(dst, src, axes=-1, accum=False, ...) # reduce over axes
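Here is a minimal reference for the reduction signature, restricted to a 2-D source reduced over axes=-1. The accum=True behavior shown here is an assumption, not something this page specifies: it folds the fresh result into the existing dst values with the same operator. reduce_last_axis is a hypothetical helper.

```python
from functools import reduce
from typing import List, Optional

Row = List[float]

# Binary combine operator for each reduction kind.
_COMBINE = {"sum": lambda a, b: a + b, "max": max, "min": min}


def reduce_last_axis(kind: str, src: List[Row], dst: Optional[Row] = None,
                     accum: bool = False) -> Row:
    """Reduce a 2-D src over axis -1; with accum=True, also fold in dst."""
    combine = _COMBINE[kind]
    out = [reduce(combine, row) for row in src]
    if accum:
        if dst is None:
            raise ValueError("accum=True needs an existing dst to accumulate into")
        out = [combine(old, new) for old, new in zip(dst, out)]
    return out
```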
Fused / compose
binary_reduce(...) # binary op then reduce, fused
unary_reduce(...) # unary (with bias/scale) then reduce
binary_chain(...) # chain two binary ops
reduce_negate(...) # reduce then negate
compose_op(...) # frame/context manager to group primitives
Dispatch config
A call is materialized as a TilePrimitiveCall node whose fields carry
everything dispatch needs (python/tvm/tirx/stmt.py):
| Field | Type | Meaning |
|---|---|---|
| | | primitive identity, e.g. … |
| | | operands (regions / scalars), in the order shown above |
| | | pre-allocated scratch buffers |
| | | open-ended tuning bag (table below) |
| | | forced variant name; … |
| | | cooperation scope (default …) |
config has no central schema: each key is read only by the dispatch
variant(s) that need it (via config.get(...)), and a key meant for another
primitive is simply ignored. Only dispatch is generic. The keys observed in
the CUDA backend, by consumer:
| Key | Used by | Type / values | Meaning |
|---|---|---|---|
| | any primitive | variant name (str) | force a lowering variant (also settable via the …) |
| | | int \| None | vectorization width for the copy |
| | | mbarrier handle | completion barrier |
| | | | CTA-group; … |
| | | int \| None | multicast CTA mask |
| | | | L2 cache eviction hint |
| | | | out-of-bounds fill policy (…) |
| | | str (e.g. …) | TMA store-with-reduction mode |
| | | bool | prefetch the tensor map at kernel entry |
| | | int \| PrimExpr | target CTA for a cross-CTA shared→shared copy |
| | | uint32 \| None | pre-encoded MMA instruction descriptor |
| | | bool | per-thread shuffle reduction |
| | | | FP rounding mode for the packed form |
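Because config has no schema, each variant simply pulls the keys it needs and ignores the rest. The sketch below shows that pattern with two hypothetical variants sharing one bag. The key names vector_width, evict_hint, and instr_desc are illustrative stand-ins, not the backend's actual keys.

```python
from typing import Any, Dict, Optional


def copy_variant_settings(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical copy variant: pulls only the keys it understands."""
    width: Optional[int] = config.get("vector_width")  # None: variant picks a width
    evict = config.get("evict_hint", "default")         # fall back when unset
    return {"vector_width": width, "evict_hint": evict}


def gemm_variant_settings(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical gemm variant: consults a different subset of the same bag."""
    return {"instr_desc": config.get("instr_desc")}


shared = {"vector_width": 8, "instr_desc": 0x1234, "unrelated_key": True}
print(copy_variant_settings(shared))  # {'vector_width': 8, 'evict_hint': 'default'}
print(gemm_variant_settings(shared))  # {'instr_desc': 4660}
```

Neither variant fails on unrelated_key. A key nobody reads is silently dropped, which is convenient but also means a misspelled key produces no error.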
Three dispatch inputs are implicit, not config keys: the execution scope
(set by the namespace, then refined against the active thread set tracked through
control flow into inter/intra maps and a scope_kind), the operand
layouts (each Buffer.layout), and the target (the dispatch table is
keyed by its kind, e.g. "cuda").
Dispatch mechanism
Pipeline
Dispatch runs in the tirx.TilePrimitiveDispatch pass — the sole pass inside
LowerTIRx(), the first stage of the compile pipeline. The C++ mutator
TilePrimitiveDispatcher walks the IR and, per call:
1. resolves the (inter, intra) execution split for the call's scope from the active set tracked through control flow (if wg_id == ..., warp_id, T.ptx.elect_sync());
2. builds a DispatchContext carrying target, scope, launch params, value ranges, and the encoded inter/intra + scope_kind;
3. invokes the global FFI hook tirx.f_op_dispatcher (Python) with the call and context, which returns a PrimFunc;
4. splices that PrimFunc body in place of the call and drains side-effect callbacks (private allocs, device/host init statements).
If any TilePrimitiveCall survives lowering, a verifier makes it a fatal error.
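The per-call loop of the pass can be sketched as a pure-Python mock over a flat statement list. TileCall, Lowered, and dispatch_pass are hypothetical stand-ins, and the mock simplifies the real pass in three ways:

- It skips the (inter, intra) split.
- It keeps only target and scope from the DispatchContext.
- It simply prepends side-effect statements, whereas the real pass decides where each allocation or init statement goes.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Union


@dataclass
class TileCall:
    """Stand-in for an unresolved TilePrimitiveCall."""
    op: str
    scope: str


Stmt = Union[str, TileCall]  # plain strings play the role of already-native IR


@dataclass
class Lowered:
    """Stand-in for the PrimFunc returned by the dispatcher hook."""
    body: List[Stmt]
    side_effects: List[str] = field(default_factory=list)


def dispatch_pass(stmts: List[Stmt],
                  dispatcher: Callable[[TileCall, Dict[str, str]], Lowered],
                  target: str = "cuda") -> List[Stmt]:
    hoisted: List[str] = []
    out: List[Stmt] = []
    for stmt in stmts:
        if not isinstance(stmt, TileCall):
            out.append(stmt)
            continue
        ctx = {"target": target, "scope": stmt.scope}  # a much-reduced DispatchContext
        lowered = dispatcher(stmt, ctx)
        out.extend(lowered.body)              # splice the body in place of the call
        hoisted.extend(lowered.side_effects)  # drain side effects (here: just prepend)
    result: List[Stmt] = hoisted + out
    # Mirror of the post-lowering verifier: a surviving tile call is fatal.
    if any(isinstance(s, TileCall) for s in result):
        raise RuntimeError("TilePrimitiveCall survived lowering")
    return result
```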
Selection (run_dispatch)
The Python dispatcher holds a table _DISPATCH_TABLE keyed by
(Op, target_kind). Each entry is a list of cases, registered by backends
via @register_dispatch(op_name, target_kind, variant=..., priority=...,
when=[preds]). run_dispatch(op_call, sctx):
1. key = (op_call.op, sctx.target.kind.name); look up the cases. If none are registered, raise an error.
2. If op_call.dispatch is set, filter to that variant (error if unknown).
3. Sort the cases by (-priority, variant): highest priority first, ties broken by variant name.
4. For each case, evaluate its predicates; if any fails, record the reason and continue. If all pass, run the impl; on success, return its PrimFunc.
5. An impl may still decline by raising DispatchFail (e.g. on a hardware constraint found while emitting); the search then continues.
6. If every variant is rejected, raise a RuntimeError listing each variant's rejection reason.
So dispatch is keyed by (primitive, target), then a priority-ordered,
predicate-guarded case list, with an optional dispatch= override.
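The selection procedure can be sketched in plain Python. This mirrors the numbered steps under a few simplifying assumptions:

- ops are plain strings;
- the context exposes target_kind directly instead of sctx.target.kind.name;
- each predicate returns an (ok, reason) pair.

It reuses the names _DISPATCH_TABLE, register_dispatch, run_dispatch, and DispatchFail from the description, but it is not the actual implementation. The variants fast and generic are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Sequence, Tuple

# Assumption: a predicate takes (op_call, sctx) and returns (ok, reason).
Predicate = Callable[[Any, Any], Tuple[bool, str]]


class DispatchFail(Exception):
    """Raised by an impl that finds, while emitting, that it cannot lower the call."""


_DISPATCH_TABLE: Dict[Tuple[str, str], List[Dict[str, Any]]] = {}


def register_dispatch(op_name: str, target_kind: str, variant: str,
                      priority: int = 0, when: Sequence[Predicate] = ()):
    def decorator(impl):
        _DISPATCH_TABLE.setdefault((op_name, target_kind), []).append(
            {"variant": variant, "priority": priority, "when": list(when), "impl": impl})
        return impl
    return decorator


def run_dispatch(op_call: Any, sctx: Any) -> Any:
    key = (op_call.op, sctx.target_kind)
    cases = _DISPATCH_TABLE.get(key)
    if not cases:
        raise RuntimeError(f"no dispatch cases registered for {key}")
    if op_call.dispatch is not None:  # dispatch= override
        cases = [c for c in cases if c["variant"] == op_call.dispatch]
        if not cases:
            raise RuntimeError(f"unknown variant {op_call.dispatch!r} for {key}")
    rejections: List[str] = []
    for case in sorted(cases, key=lambda c: (-c["priority"], c["variant"])):
        failure = None
        for pred in case["when"]:
            ok, why = pred(op_call, sctx)
            if not ok:
                failure = why
                break
        if failure is not None:
            rejections.append(f"{case['variant']}: {failure}")
            continue
        try:
            return case["impl"](op_call, sctx)  # success: the lowered body
        except DispatchFail as exc:
            rejections.append(f"{case['variant']}: {exc}")  # declined; keep searching
    raise RuntimeError("every variant rejected:\n  " + "\n  ".join(rejections))


# Usage: two hypothetical copy variants for a "cuda" target.
@register_dispatch("copy", "cuda", variant="fast", priority=10,
                   when=[lambda call, ctx: (call.aligned, "operands not aligned")])
def _fast_copy(call, ctx):
    return "fast-copy body"


@register_dispatch("copy", "cuda", variant="generic")
def _generic_copy(call, ctx):
    return "generic-copy body"


ctx = SimpleNamespace(target_kind="cuda")
print(run_dispatch(SimpleNamespace(op="copy", dispatch=None, aligned=True), ctx))   # fast-copy body
print(run_dispatch(SimpleNamespace(op="copy", dispatch=None, aligned=False), ctx))  # generic-copy body
```

Sorting on (-priority, variant) gives a deterministic order even when two variants share a priority, so the same call always lowers the same way.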
Two predicate helpers recur: validate_copy_op (both operands have a
layout, equal dtype, and equal non-unit extents) and _all_threads_active (the
execution scope is full, e.g. laneid spans all 32 lanes, with no thread axis
narrowed by an enclosing if), so a partial-warp copy is rejected rather than mis-lowered.
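The two predicate helpers can be approximated as follows. Operand, validate_copy_op_sketch, and all_threads_active_sketch are hypothetical. The sketch makes two interpretive choices:

- "Equal non-unit extents" is read as comparing shapes with size-1 dimensions dropped.
- The active thread set is modeled as a per-axis extent map rather than the real inter/intra encoding.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass
class Operand:
    """Simplified stand-in for a BufferRegion operand."""
    dtype: str
    extents: Tuple[int, ...]
    layout: Optional[object] = None  # None: no TileLayout attached


def _non_unit(extents: Tuple[int, ...]) -> Tuple[int, ...]:
    # Drop size-1 dimensions so (1, 64) and (64,) compare equal.
    return tuple(e for e in extents if e != 1)


def validate_copy_op_sketch(dst: Operand, src: Operand) -> Tuple[bool, str]:
    if dst.layout is None or src.layout is None:
        return False, "operand has no layout"
    if dst.dtype != src.dtype:
        return False, f"dtype mismatch: {dst.dtype} vs {src.dtype}"
    if _non_unit(dst.extents) != _non_unit(src.extents):
        return False, f"extent mismatch: {dst.extents} vs {src.extents}"
    return True, ""


def all_threads_active_sketch(active: Dict[str, int],
                              full: Dict[str, int]) -> Tuple[bool, str]:
    # Every thread axis of the scope must still span its full extent,
    # i.e. no enclosing `if` narrowed it (e.g. laneid must cover all 32 lanes).
    for axis, extent in full.items():
        seen = active.get(axis, extent)
        if seen != extent:
            return False, f"{axis} narrowed to {seen} of {extent}"
    return True, ""
```

Returning a reason string alongside the boolean is what lets a dispatcher report why each variant was rejected.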
Dispatch by primitive
Each page below documents one primitive’s dispatch in detail — the variants, how each is selected, the algorithm it runs, the IR it emits, and when it declines.
See also
- Tensor Layout: the TileLayout model that dispatch reads from operands.
- Overview: execution scope, tensor layout, and tile primitive dispatch as the three core constructs.