TIRx lowering pipeline
tvm.compile(mod, target, tir_pipeline="tirx") runs an authored TIRx module
through the tirx pipeline — an ordered sequence of TIR passes that turns the
high-level constructs you write (tile primitives, TileLayout-typed buffers,
execution-scope ids) into split host + device functions, which the CUDA
backend then renders to source. The pipeline is defined in
python/tvm/tirx/compilation_pipeline.py (tirx_pipeline); this page walks the
passes in order.
Where it sits
tvm.compile first binds the target, runs the tirx pipeline (the module-level
passes below), then applies finalization passes separately to the host and
device functions, and finally hands each device function to the CUDA code
generator:
authored TIRx ──BindTarget──▶ tirx_pipeline ──┬──▶ host func ──host finalize──▶ C/LLVM
                                              │
                                              └──▶ device func ──device finalize──▶ CUDA
The passes
The tirx_pipeline module pass applies this exact sequence (a few are gated by
PassContext config):
| # | Pass | What it does |
|---|---|---|
| 1 | LowerTIRx | the core lowering; see Inside LowerTIRx below |
| 2 | … | merges equivalent thread-axis bindings so each … |
| 3 | … | statement-level arithmetic simplification (the arith analyzer) |
| 4 | … | lowers remaining opaque TIRx constructs to plain TIR |
| 5 | … | flattens multi-dimensional … |
| 6 | … | rewrites … |
| 7 | … | narrows index/loop … |
| 8 | … | turns … |
| 9 | … | unrolls loops marked … |
| 10 | … | simplifies again, now that vectorization and unrolling have exposed constants |
| 11 | … | hoists repeated subexpressions into temporaries (skipped if …) |
| 12 | … | rewrites … |
| 13 | … | checks that no host-side code directly dereferences device memory (a safety gate) |
| 14 | … | marks the single PrimFunc as the module entry point |
| 15 | SplitHostDevice | splits each kernel into a host function and a device function at the … |
| 16 | MakePackedAPI | rewrites the host function to the packed-func ABI (the launcher TVM calls) |
| 17 | … | legalizes … |
| 18 | … | legalizes … |
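The ordering and config gating in the table can be pictured with a toy model. It is a sketch, not TVM's API: ToyPass, run_sequence, the short pass labels and the config key demo.disable_cse are all hypothetical; in TVM the gating reads PassContext config, as noted above.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

Module = List[str]  # toy "module": the names of the passes that have run on it


@dataclass(frozen=True)
class ToyPass:
    """Hypothetical stand-in for a TIR pass: a name plus an optional config gate."""
    name: str
    skip_if: Optional[str] = None  # config key that disables this pass when True

    def __call__(self, mod: Module) -> Module:
        return mod + [self.name]  # "transform" the module by recording this pass


def run_sequence(mod: Module, passes: List[ToyPass], config: Dict[str, bool]) -> Module:
    """Apply passes strictly in list order, skipping any whose gate is set."""
    for p in passes:
        if p.skip_if is not None and config.get(p.skip_if, False):
            continue
        mod = p(mod)
    return mod


# Illustrative subset of the table; "demo.disable_cse" is a made-up config key.
PIPELINE = [
    ToyPass("LowerTIRx"),
    ToyPass("Simplify"),
    ToyPass("CSE", skip_if="demo.disable_cse"),
    ToyPass("SplitHostDevice"),
    ToyPass("MakePackedAPI"),
]

print(run_sequence([], PIPELINE, {}))
print(run_sequence([], PIPELINE, {"demo.disable_cse": True}))
```

Gating at run time rather than editing the list keeps one canonical ordering, so a disabled pass never shifts the position of the passes after it.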
Finalization then runs per function kind:
host: LowerTVMBuiltin (lowers tvm_* builtins), LowerIntrin (target-specific intrinsics)
device: LowerWarpMemory (warp-scoped buffers → shuffles), StmtSimplify, LowerIntrin
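Which finalization list and which code generator a function gets depends only on its kind. The sketch below models just that routing decision; finalize_and_route and the "host"/"device" string tags are hypothetical, and it records pass names rather than running anything.

```python
from typing import Dict, List

# Finalization lists as described above (names only; this toy does not run them).
HOST_FINALIZE: List[str] = ["LowerTVMBuiltin", "LowerIntrin"]
DEVICE_FINALIZE: List[str] = ["LowerWarpMemory", "StmtSimplify", "LowerIntrin"]


def finalize_and_route(funcs: Dict[str, str]) -> Dict[str, Dict[str, object]]:
    """Toy model: choose a finalization list and a backend from each function's kind.

    `funcs` maps a function name to a hypothetical kind tag, "host" or "device".
    """
    plan: Dict[str, Dict[str, object]] = {}
    for name, kind in funcs.items():
        if kind == "host":
            plan[name] = {"finalize": HOST_FINALIZE, "backend": "C/LLVM"}
        elif kind == "device":
            plan[name] = {"finalize": DEVICE_FINALIZE, "backend": "CUDA"}
        else:
            raise ValueError(f"unknown function kind: {kind!r}")
    return plan


plan = finalize_and_route({"main": "host", "scale_kernel": "device"})
for name, step in plan.items():
    print(name, step["finalize"], "->", step["backend"])
```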
Inside LowerTIRx
LowerTIRx is itself a small sequence (src/tirx/transform/lower_tirx.cc):
LowerTIRx = Sequential([ TilePrimitiveDispatch, LowerTIRxCleanup ])
TilePrimitiveDispatch replaces every TilePrimitiveCall (copy, gemm, reduction, …) with the body emitted by its selected backend dispatch: the variant selection and code generation described in Tile Primitives.
LowerTIRxCleanup runs the LayoutApplier: it resolves every TileLayout-typed buffer access into concrete physical address arithmetic (addr = data + elem_offset + layout.apply(coord)), flattens the buffers, and lowers the execution-scope ids (T.cta_id / T.thread_id / … → blockIdx / threadIdx via launch_thread).
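The address formula can be checked by hand with a plain strided layout. This sketch assumes StridedLayout and physical_address are hypothetical stand-ins with offsets counted in elements; the real TileLayout and its apply are not reproduced here. Two layouts over the same 4×8 logical shape send the same coordinate to different addresses, which is the indirection the LayoutApplier turns into concrete arithmetic.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class StridedLayout:
    """Hypothetical stand-in for a TileLayout: logical coordinate -> element offset."""
    shape: Tuple[int, ...]
    strides: Tuple[int, ...]

    def apply(self, coord: Sequence[int]) -> int:
        assert len(coord) == len(self.shape), "rank mismatch"
        for c, n in zip(coord, self.shape):
            assert 0 <= c < n, "coordinate out of bounds"
        return sum(c * s for c, s in zip(coord, self.strides))


def physical_address(data: int, elem_offset: int, layout: StridedLayout, coord) -> int:
    """addr = data + elem_offset + layout.apply(coord), all in elements."""
    return data + elem_offset + layout.apply(coord)


row_major = StridedLayout(shape=(4, 8), strides=(8, 1))
col_major = StridedLayout(shape=(4, 8), strides=(1, 4))

print(physical_address(1000, 4, row_major, (2, 3)))  # 1000 + 4 + 2*8 + 3*1 = 1023
print(physical_address(1000, 4, col_major, (2, 3)))  # 1000 + 4 + 2*1 + 3*4 = 1018
```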
So after LowerTIRx the module is plain TIR: no tile primitives, no
TileLayout indirection, scope ids resolved to thread axes.
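The two-stage structure and the invariant it leaves behind can be illustrated with a toy. Everything in this sketch is hypothetical: statements are plain tuples, TILE_BACKENDS stands in for backend dispatch with a single made-up copy expansion, and the cleanup stage only rewrites scope ids (it does not model the LayoutApplier). sequential mimics how a Sequential runs its passes left to right.

```python
from typing import Callable, List, Tuple

Stmt = Tuple  # toy statement: a tuple whose first element names its kind
Pass = Callable[[List[Stmt]], List[Stmt]]

# Hypothetical registry standing in for per-primitive backend dispatch.
TILE_BACKENDS = {
    "copy": lambda dst, src: [("store", dst, ("load", src))],
}

# Hypothetical mapping from execution-scope ids to CUDA thread axes.
SCOPE_TO_AXIS = {"cta_id": "blockIdx.x", "thread_id": "threadIdx.x"}


def tile_primitive_dispatch(body: List[Stmt]) -> List[Stmt]:
    """Replace each ("tile_call", op, *args) with the statements its backend emits."""
    out: List[Stmt] = []
    for stmt in body:
        if stmt[0] == "tile_call":
            out.extend(TILE_BACKENDS[stmt[1]](*stmt[2:]))
        else:
            out.append(stmt)
    return out


def lower_tirx_cleanup(body: List[Stmt]) -> List[Stmt]:
    """Turn ("scope_id", kind, extent) into ("launch_thread", axis, extent)."""
    return [
        ("launch_thread", SCOPE_TO_AXIS[s[1]], s[2]) if s[0] == "scope_id" else s
        for s in body
    ]


def sequential(*passes: Pass) -> Pass:
    """Compose passes so they run left to right."""
    def run(body: List[Stmt]) -> List[Stmt]:
        for p in passes:
            body = p(body)
        return body
    return run


lower_tirx = sequential(tile_primitive_dispatch, lower_tirx_cleanup)

before = [("scope_id", "cta_id", 1), ("scope_id", "thread_id", 256),
          ("tile_call", "copy", "B", "A")]
after = lower_tirx(before)
print(after)
```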
A worked example
Take a one-line scale kernel:
@T.prim_func
def scale(A_ptr: T.handle, B_ptr: T.handle):
    A = T.match_buffer(A_ptr, (256,), "float32")
    B = T.match_buffer(B_ptr, (256,), "float32")
    T.device_entry()
    bx = T.cta_id([1])       # one CTA
    tx = T.thread_id([256])  # 256 threads in that CTA
    B[tx] = A[tx] * T.float32(2.0)
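Before reading the IR, it helps to pin down what the kernel computes. A minimal sequential reference in pure Python, assuming one CTA of 256 threads as declared above, is handy for checking the compiled kernel's output. scale_reference is a hypothetical helper, and Python floats are 64-bit, so comparisons against real float32 results should use a tolerance.

```python
from typing import List


def scale_reference(A: List[float], n_ctas: int = 1, n_threads: int = 256) -> List[float]:
    """Sequential emulation of the scale kernel.

    Each (bx, tx) pair plays the role of one GPU thread; only tx indexes
    the data, exactly as in B[tx] = A[tx] * 2.0.
    """
    assert len(A) == n_threads, "the kernel assumes exactly one element per thread"
    B = [0.0] * len(A)
    for bx in range(n_ctas):         # T.cta_id([1])
        for tx in range(n_threads):  # T.thread_id([256])
            B[tx] = A[tx] * 2.0
    return B


A = [float(i) for i in range(256)]
B = scale_reference(A)
print(B[:4])  # [0.0, 2.0, 4.0, 6.0]
```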
After LowerTIRx, the scope ids are real thread axes and the layouts are applied
(A_1 and B_1 are the flattened 1-D views):
with T.launch_thread("blockIdx.x", 1) as blockIdx_x:
    threadIdx_x = T.launch_thread("threadIdx.x", 256)
    bx: T.let = blockIdx_x
    tx: T.let = threadIdx_x
    B_1[threadIdx_x] = A_1[threadIdx_x] * T.float32(2.0)
After SplitHostDevice and MakePackedAPI, the single function has become two, a host launcher and a device kernel:
@I.ir_module
class Module:
    def main(...):  # host: packed-API launcher (computes the grid/block, launches)
        ...

    def scale_kernel(...):  # device: the __global__ body, run on the GPU
        ...
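The division of labour between the two functions can be mimicked in plain Python. In this sketch main, launch and scale_kernel are hypothetical stand-ins: the host side validates its arguments (loosely like the checks a packed-API wrapper performs) and fixes the launch shape, while the device side only knows one thread's work. launch runs the threads serially, which is valid here only because no thread reads another thread's output.

```python
from typing import Callable, List


def scale_kernel(bx: int, tx: int, A: List[float], B: List[float]) -> None:
    """Device side: the work of one GPU thread (bx is unused, as in the kernel)."""
    B[tx] = A[tx] * 2.0


def launch(kernel: Callable[..., None], grid: int, block: int, *args) -> None:
    """Stand-in for a GPU launch: run `kernel` once per (block, thread) pair, serially."""
    for bx in range(grid):
        for tx in range(block):
            kernel(bx, tx, *args)


def main(A: List[float], B: List[float]) -> None:
    """Host side: validate arguments, choose the launch shape, then launch."""
    if len(A) != 256 or len(B) != 256:
        raise ValueError("scale expects two float buffers of shape (256,)")
    launch(scale_kernel, 1, 256, A, B)  # grid = 1 CTA, block = 256 threads


A = [1.5] * 256
B = [0.0] * 256
main(A, B)
print(B[0], B[255])  # 3.0 3.0
```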
The CUDA backend then renders scale_kernel to the __global__ function
(B_ptr[threadIdx.x] = A_ptr[threadIdx.x] * 2.0f).
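At its core, rendering means printing each TIR node as CUDA C text. The toy printer below handles just the nodes in this kernel; Load, FloatImm, Mul, render and render_store are hypothetical miniatures rather than TVM's IR classes, and the real CUDA code generator covers many more node kinds.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Load:
    buffer: str
    index: str  # index already rendered as CUDA text, e.g. "threadIdx.x"


@dataclass(frozen=True)
class FloatImm:
    value: float  # a float32 constant


@dataclass(frozen=True)
class Mul:
    a: "Expr"
    b: "Expr"


Expr = Union[Load, FloatImm, Mul]


def render(e: Expr) -> str:
    """Print a toy expression as CUDA C text."""
    if isinstance(e, Load):
        return f"{e.buffer}[{e.index}]"
    if isinstance(e, FloatImm):
        return f"{e.value!r}f"  # 'f' suffix keeps the literal float32 rather than double
    if isinstance(e, Mul):
        # Parenthesize nested products so the printed grouping matches the tree.
        def operand(x: Expr) -> str:
            return f"({render(x)})" if isinstance(x, Mul) else render(x)
        return f"{operand(e.a)} * {operand(e.b)}"
    raise TypeError(f"unsupported node: {e!r}")


def render_store(buffer: str, index: str, value: Expr) -> str:
    return f"{buffer}[{index}] = {render(value)};"


line = render_store("B_ptr", "threadIdx.x", Mul(Load("A_ptr", "threadIdx.x"), FloatImm(2.0)))
print(line)  # B_ptr[threadIdx.x] = A_ptr[threadIdx.x] * 2.0f;
```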
Reproduce it yourself
You can run any prefix of the pipeline by hand to inspect a stage — this is how the IR snippets across these docs were produced:
import tvm
from tvm.tirx import transform as TT

# `scale` is the prim_func from the worked example above.
target = tvm.target.Target("cuda")
mod = TT.BindTarget(target.with_host("llvm"))(tvm.IRModule({"main": scale}))
mod = TT.LowerTIRx()(mod)  # tile primitives dispatched, layouts applied
print(mod.script())        # inspect the lowered TIRx IR
Or compile the whole module and read the generated CUDA:
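# Reuses `target` and `scale` from the previous snippet.
exe = tvm.compile(tvm.IRModule({"main": scale}), target=target, tir_pipeline="tirx")
print(exe.mod.imports[0].inspect_source())  # the imported device module holds the CUDA source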
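The by-hand prefix run in this section generalizes to a small helper that stops after any named pass and keeps every intermediate result. This is a pure-Python sketch: run_prefix, the string "IR" and the lambda passes are hypothetical and only mimic the shape of the workflow, not TVM's API.

```python
from typing import Callable, Dict, List, Tuple

ToyPass = Tuple[str, Callable[[str], str]]  # (name, transform on a toy "IR" string)


def run_prefix(mod: str, passes: List[ToyPass], stop_after: str) -> Dict[str, str]:
    """Run passes in order up to and including `stop_after`, keeping every stage."""
    if stop_after not in [name for name, _ in passes]:
        raise KeyError(f"no pass named {stop_after!r}")
    stages = {"input": mod}
    for name, fn in passes:
        mod = fn(mod)
        stages[name] = mod  # snapshot after this pass
        if name == stop_after:
            break
    return stages


# Hypothetical passes acting on a string; they only mimic the shape of real ones.
PASSES: List[ToyPass] = [
    ("BindTarget", lambda ir: ir + " @cuda"),
    ("LowerTIRx", lambda ir: ir.replace("tile_copy", "loop_copy")),
    ("SplitHostDevice", lambda ir: "host{launch} device{" + ir + "}"),
]

stages = run_prefix("tile_copy", PASSES, stop_after="LowerTIRx")
for name, ir in stages.items():
    print(f"{name:>10}: {ir}")
```

Keeping every snapshot, rather than only the last one, makes it easy to diff consecutive stages when a pass does something unexpected.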