Defining a function
A kernel is a @T.prim_func (like scale in Your first kernel), or a
@T.jit when it has compile-time parameters (see "@T.prim_func vs @T.jit" below). This
chapter covers the parameter list — how to declare buffers, what types you can
pass, symbolic shapes, and the prim_func / jit distinction.
Declaring buffer parameters
There are two equivalent ways to take a tensor parameter:
Handle + match_buffer. Take a T.handle (an opaque data pointer) and bind it in the body with T.match_buffer. This is the explicit form and the one that exposes every descriptor field (layout, elem_offset, scope, align, and symbolic shapes):
@T.prim_func
def f(A_ptr: T.handle, B_ptr: T.handle):
    A = T.match_buffer(A_ptr, (256,), "float32", align=16)
    B = T.match_buffer(B_ptr, (256,), "float32")
    ...
T.Buffer annotation. Annotate the parameter directly. This is the concise form, equivalent to a handle bound with match_buffer using the defaults:
@T.prim_func
def f(A: T.Buffer((256,), "float32"), B: T.Buffer((256,), "float32")):
    ...
Both give you a Buffer you index with A[i] / A[i, j]. Use T.Buffer
for the common case; drop to T.handle + match_buffer when you need a custom
layout/offset/scope/alignment or a symbolic shape.
What the parameter list accepts
A PrimFunc parameter is one of the following. The third column is what you
pass on the Python side when you call the compiled Executable:
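| Annotation | Is | Pass at call time |
|---|---|---|
| T.Buffer(shape, dtype) | a tensor parameter (shape + dtype fixed) | a tensor on the right device |
| T.handle | an opaque data pointer (bind with T.match_buffer) | a tensor |
| a scalar dtype, e.g. T.float32 or T.int32 | a runtime scalar | a Python int / float |
| T.constexpr | a compile-time constant | supplied to .specialize(...) |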
Tensors may be CUDA torch tensors (zero-copy via DLPack) or
tvm.runtime.tensor(...). Arguments are positional and match the parameter
order. For example, a kernel with a scalar parameter:
@T.prim_func
def scal(A_ptr: T.handle, B_ptr: T.handle, s: T.float32):
    A = T.match_buffer(A_ptr, (256,), "float32")
    B = T.match_buffer(B_ptr, (256,), "float32")
    T.device_entry(); bx = T.cta_id([1]); tx = T.thread_id([256])
    B[tx] = A[tx] * s

exe(a, b, 3.0) # pass the scalar as a Python float
Symbolic shapes
For a size that varies at run time, declare a free symbolic extent with
T.int32() and use it in the buffer shape. Its value is inferred from the
passed tensor at run time, so a single compiled kernel handles any size:
@T.prim_func
def scale_dyn(a: T.handle, b: T.handle):
    n = T.int32() # free symbolic extent
    A = T.match_buffer(a, (n,), "float32")
    B = T.match_buffer(b, (n,), "float32")
    T.device_entry()
    bx = T.cta_id([1]); tx = T.thread_id([1])
    for i in range(n): # loop / launch bounds may use n
        B[i] = A[i] * T.float32(2.0)

exe = tvm.compile(tvm.IRModule({"main": scale_dyn}),
                  target=tvm.target.Target("cuda"), tir_pipeline="tirx")
exe(torch.rand(100, device="cuda"), torch.empty(100, device="cuda")) # n = 100
exe(torch.rand(200, device="cuda"), torch.empty(200, device="cuda")) # n = 200, same kernel
Both match_buffer calls share n, so the two shapes are constrained equal;
n is never passed explicitly — it comes from the tensor.
In the generated CUDA, n is just a runtime kernel argument; the host launcher
reads it from the tensor’s shape and passes it, and the loop bound uses it
(boilerplate elided):
extern "C" __global__ void
scale_dyn_kernel(float* __restrict__ A_ptr, float* __restrict__ B_ptr, int n) {
  for (int i = 0; i < n; ++i) {
    B_ptr[i] = A_ptr[i] * 2.0f;
  }
}
Note
You passed only two tensors, yet the kernel takes a third argument n — who
supplies it? A compiled Executable has two halves: a host launcher and
the device kernel above. When you call exe(a, b), the host launcher
unpacks the two tensors, reads n from a’s shape (a was matched as
(n,)), checks that b agrees, computes the launch configuration, and then
invokes the device kernel — forwarding the data pointers and the resolved
n as explicit arguments. Nothing passes n by hand; the host side derives
it from the tensor metadata. The pass that does this is
tirx.transform.SplitHostDevice (followed by tirx.transform.MakePackedAPI).
You can see it in the IR. Before the split, the lowered module is a single merged function (trimmed):
@T.prim_func
def main(a: T.handle, b: T.handle):
    n = T.int32()
    A = T.match_buffer(a, (n,))
    B = T.match_buffer(b, (n,))
    with T.launch_thread("blockIdx.x", 1), T.launch_thread("threadIdx.x", 1):
        for i in range(n):
            B[i] = A[i] * T.float32(2.0)
After SplitHostDevice, it is two functions — a device kernel that takes
n as a parameter, and a host main that calls it, forwarding n (the
trailing 1, 1 are the grid/block launch dims):
@T.prim_func # device
def scale_dyn_kernel(A_ptr: T.handle("float32"), B_ptr: T.handle("float32"), n: T.int32):
    ...
    for i in range(n):
        B[i] = A[i] * T.float32(2.0)

@T.prim_func # host
def main(a: T.handle, b: T.handle):
    n = T.int32()
    A = T.match_buffer(a, (n,))
    B = T.match_buffer(b, (n,))
    T.call_packed("scale_dyn_kernel", A.data, B.data, n, 1, 1) # n forwarded
MakePackedAPI then fills in where n comes from — reading it from the
argument’s shape (essentially n = a.shape[0]) — and adds the dtype / shape /
device checks (e.g. asserting B.shape[0] == n):
n = T.Cast("int32", T.tvm_struct_get(a_shape, 0, 17, "int64")) # = a.shape[0]
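The host-side resolution described in the note (read n from the first tensor that mentions it, then check that every other use agrees) can be modelled in a few lines of plain Python. This is only a sketch of the idea, not TVM's implementation: bind_symbolic_shapes and its string-named extents are hypothetical names introduced for illustration.

```python
def bind_symbolic_shapes(declared, actual):
    """Resolve symbolic extents by matching declared shapes against real ones.

    declared: one shape tuple per argument; entries are ints (fixed extents)
              or strings (hypothetical names for symbolic extents, e.g. "n").
    actual:   the concrete shape tuple of each argument passed at call time.
    Returns a dict mapping every symbolic name to its resolved value.
    """
    bindings = {}
    for arg, (decl, shape) in enumerate(zip(declared, actual)):
        if len(decl) != len(shape):
            raise ValueError(f"arg {arg}: expected {len(decl)} dims, got {len(shape)}")
        for dim, (want, got) in enumerate(zip(decl, shape)):
            if isinstance(want, str):
                # The first sighting binds the symbol; later ones must agree.
                if bindings.setdefault(want, got) != got:
                    raise ValueError(
                        f"arg {arg} dim {dim}: {want}={got} conflicts with "
                        f"{want}={bindings[want]}"
                    )
            elif want != got:
                raise ValueError(f"arg {arg} dim {dim}: expected {want}, got {got}")
    return bindings


# scale_dyn declares both A and B as (n,), so the two lengths must match.
print(bind_symbolic_shapes([("n",), ("n",)], [(100,), (100,)]))  # {'n': 100}
```

Passing shapes (100,) and (101,) to the same call raises ValueError, which corresponds to the B.shape[0] == n assertion mentioned above.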
@T.prim_func vs @T.jit
@T.prim_func parses the function immediately into a PrimFunc. Sizes are whatever you wrote: concrete ints, or runtime-symbolic vars (above).
@T.jit defers parsing until you call .specialize(**constexpr): parameters annotated T.constexpr are baked in as compile-time constants and the result is an ordinary PrimFunc. Use it when you want sizes/flags fixed at compile time (so the compiler can unroll, statically size shared memory, etc.). Referencing a constexpr inside an annotation (e.g. T.Buffer((N,), ...)) requires from __future__ import annotations at the top of the file.
@T.jit
def add(A: T.Buffer((N,), "float32"), B: T.Buffer((N,), "float32"),
        C: T.Buffer((N,), "float32"), *, N: T.constexpr):
    T.device_entry(); bx = T.cta_id([1]); tx = T.thread_id([N])
    C[tx] = A[tx] + B[tx]
kernel = add.specialize(N=256) # -> a PrimFunc with N = 256 baked in
So: a symbolic shape is one kernel whose size is resolved at run time; a constexpr + jit produces a specialized kernel per value, resolved at compile time.
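As a rough plain-Python analogy (this is not TIRx code), the sketch below mimics the shape of that workflow: a wrapper that does nothing until specialize is called, then bakes the given constants in and returns an ordinary callable. ToyJit is a hypothetical helper written for this illustration; treating keyword-only parameters as the compile-time ones is an assumption chosen to echo the "*, N: T.constexpr" signature above.

```python
import functools
import inspect


class ToyJit:
    """Plain-Python stand-in for a deferred, specializable function."""

    def __init__(self, fn):
        self.fn = fn
        # Assumption: keyword-only parameters play the role of constexprs.
        self.constexpr_names = [
            name
            for name, param in inspect.signature(fn).parameters.items()
            if param.kind is inspect.Parameter.KEYWORD_ONLY
        ]

    def specialize(self, **constexpr):
        missing = set(self.constexpr_names) - set(constexpr)
        if missing:
            raise TypeError(f"missing constexpr values: {sorted(missing)}")
        # Bake the constants in; the result takes only the runtime arguments.
        return functools.partial(self.fn, **constexpr)


@ToyJit
def add(a, b, *, N):
    return [a[i] + b[i] for i in range(N)]


add_4 = add.specialize(N=4)  # one specialized callable per value of N
print(add_4([1, 2, 3, 4], [10, 20, 30, 40]))  # [11, 22, 33, 44]
```

Each distinct value passed to specialize yields a separate callable, which is the analogue of getting one specialized kernel per constexpr value.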
Launch parameters
T.device_entry()
T.device_entry() is a flat marker (no with) that splits the function:
everything before it is host code — the T.match_buffer parameter binding
and any shape reads — and everything after it is the device kernel body. It
lowers to an AttrStmt("tirx.device_entry", ...) and is exactly the boundary the
host/device split cuts along (the merged-vs-split modules shown above are split
here).
Scope ids
After device_entry you declare the thread hierarchy with scope-id intrinsics
— each takes its launch extent as a list:
T.device_entry()
bx, by = T.cta_id([GM, GN]) # blockIdx.x / .y (grid extents)
warp_id = T.warp_id([4]) # cta -> warp
lane_id = T.lane_id([32]) # warp -> thread
tx = T.thread_id([128]) # cta -> flat thread id
Available ids include cta_id, thread_id, warp_id, warpgroup_id,
warp_id_in_wg, lane_id, cluster_id, cta_id_in_cluster. (The legacy
T.launch_thread exists but native TIRx uses device_entry + scope-ids.)
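To make the relationship between these ids concrete, the following plain-Python sketch splits a flat thread index into a warp index and a lane index using the usual CUDA convention of 32 lanes per warp. It assumes a 1-D block of 128 threads and that T.warp_id([4]) and T.lane_id([32]) follow that convention; decompose_thread is a hypothetical helper, not part of TIRx.

```python
WARP_SIZE = 32  # lanes per warp on NVIDIA GPUs


def decompose_thread(tx, block_threads=128):
    """Split a flat in-block thread index into (warp_id, lane_id)."""
    if not 0 <= tx < block_threads:
        raise ValueError(f"thread index {tx} outside [0, {block_threads})")
    return tx // WARP_SIZE, tx % WARP_SIZE


print(decompose_thread(0))    # (0, 0)
print(decompose_thread(37))   # (1, 5)
print(decompose_thread(127))  # (3, 31)
```

With 128 threads the warp index ranges over 0..3 and the lane index over 0..31, matching the extents [4] and [32] in the snippet above.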
Thread-block clusters (Hopper/Blackwell) are declared with cluster_id
(kernel → cluster) and cta_id_in_cluster (cluster → cta). The
cta_id_in_cluster extent is the cluster’s CTA dimension; its preferred=
argument sets the preferred cluster dimension (CUDA 12.8+):
cid = T.cluster_id([NUM_CLUSTERS]) # kernel -> cluster (grid of clusters)
rank = T.cta_id_in_cluster([CLUSTER_SIZE], # cluster -> cta
preferred=[CLUSTER_SIZE])
# -> cluster_dim = CLUSTER_SIZE, preferred_cluster_dim = CLUSTER_SIZE
These become the CLUSTER_DIMENSION / PREFERRED_CLUSTER_DIMENSION launch
attributes in the config below. (cta_id and cta_id_in_cluster also take an
optional preferred=.) In the device code they lower to reads of the cluster
PTX special registers:
int cid = ...; // mov.u32 %0, %clusterid.x; (cluster index)
int rank = ...; // mov.u32 %0, %cluster_ctarank; (CTA rank within the cluster)
The cluster dimensions themselves are not in the device code — they are set at launch time via the attributes above.
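The two values read from those registers can be related to the flat CTA index with simple integer arithmetic. The sketch below assumes a 1-D grid and a 1-D cluster, and that the grid contains NUM_CLUSTERS * CLUSTER_SIZE CTAs in total; cluster_coords is a hypothetical helper used only to illustrate the mapping.

```python
def cluster_coords(cta_index, cluster_size):
    """Map a flat CTA index to (cluster id, CTA rank within the cluster)."""
    if cluster_size <= 0:
        raise ValueError("cluster_size must be positive")
    return cta_index // cluster_size, cta_index % cluster_size


NUM_CLUSTERS, CLUSTER_SIZE = 4, 2
grid_ctas = NUM_CLUSTERS * CLUSTER_SIZE  # the grid is still counted in CTAs

for cta in range(grid_ctas):
    cid, rank = cluster_coords(cta, CLUSTER_SIZE)
    print(f"cta {cta}: cluster {cid}, rank {rank}")
```

Under these assumptions CTA 5 lands in cluster 2 with rank 1, and every cluster holds exactly CLUSTER_SIZE consecutive CTAs.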
Launching the kernel
During lowering the compiler extracts every launch parameter the kernel uses —
the grid and block dimensions, plus the dynamic shared-memory size if any — into
the device function’s tirx.kernel_launch_params attribute. For the scale
kernel that list is ["blockIdx.x", "threadIdx.x"]; the host launcher computes
each one’s extent (from the scope-id extents and any symbolic shapes) and supplies
them alongside the kernel arguments.
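A minimal sketch of that host-side step, in plain Python: each launch parameter has either a fixed extent or one that depends on the resolved symbolic shapes. Everything here is hypothetical (resolve_launch_params, the extents table, and a kernel that launches ceil(n / 256) blocks); it only illustrates how an ordered name list plus symbolic bindings could turn into concrete launch dimensions.

```python
def resolve_launch_params(names, extents, bindings):
    """Evaluate each launch parameter's extent, in the order given by names.

    extents maps a parameter name to an int (fixed extent) or to a callable
    that computes the extent from the symbolic-shape bindings.
    """
    resolved = []
    for name in names:
        extent = extents[name]
        resolved.append(extent(bindings) if callable(extent) else extent)
    return resolved


names = ["blockIdx.x", "threadIdx.x"]
extents = {
    "blockIdx.x": lambda sym: (sym["n"] + 255) // 256,  # ceil(n / 256) blocks
    "threadIdx.x": 256,                                 # fixed block size
}

print(resolve_launch_params(names, extents, {"n": 1000}))  # [4, 256]
```

For the scale kernel in the text both extents are constants, so the same routine would simply return them unchanged.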
The block size also drives the kernel’s __launch_bounds__. The first argument
(max threads per block) is set automatically from the thread extent. To also set
the second argument — the minimum blocks per SM, an occupancy hint — add
T.attr({"tirx.launch_bounds_min_blocks_per_sm": N}) in the device region (note:
T.attr, not func_attr):
T.device_entry()
T.attr({"tirx.launch_bounds_min_blocks_per_sm": 2}) # second launch-bounds arg
bx = T.cta_id([1]); tx = T.thread_id([256])
...
extern "C" __global__ void __launch_bounds__(256, 2) scale_kernel(...) { ... }
Without the attr the second argument is omitted (just __launch_bounds__(256)).
At run time the kernel is launched through the CUDA Driver API. TVM’s CUDA
runtime loads the module (cuModuleLoadData), fetches the function
(cuModuleGetFunction, cached), and calls cuLaunchKernelEx with a
CUlaunchConfig. Besides the grid/block dims, dynamic shared size, and stream,
the config carries a list of launch attributes — the thread-block cluster
dimension and preferred cluster dimension (Hopper/Blackwell), plus optional
programmatic-dependent-launch and cooperative-launch flags. From
src/backend/cuda/runtime/cuda_module.cc:
std::vector<CUlaunchAttribute> attrs;
// 1) thread-block cluster dimension
if (wl.cluster_dim(0) != 1 || wl.cluster_dim(1) != 1 || wl.cluster_dim(2) != 1) {
  CUlaunchAttribute attr{};
  attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  attr.value.clusterDim.x = wl.cluster_dim(0);
  attr.value.clusterDim.y = wl.cluster_dim(1);
  attr.value.clusterDim.z = wl.cluster_dim(2);
  attrs.push_back(attr);
}
// 1b) preferred cluster dimension (CUDA 12.8+); (2) programmatic stream
// serialization and (3) cooperative launch are appended the same way
if (wl.preferred_cluster_dim(0) != 1 || wl.preferred_cluster_dim(1) != 1 ||
    wl.preferred_cluster_dim(2) != 1) {
  CUlaunchAttribute attr{};
  attr.id = CU_LAUNCH_ATTRIBUTE_PREFERRED_CLUSTER_DIMENSION;
  attr.value.clusterDim.x = wl.preferred_cluster_dim(0);
  attr.value.clusterDim.y = wl.preferred_cluster_dim(1);
  attr.value.clusterDim.z = wl.preferred_cluster_dim(2);
  attrs.push_back(attr);
}
CUlaunchConfig config{};
config.gridDimX = wl.grid_dim(0);
config.gridDimY = wl.grid_dim(1);
config.gridDimZ = wl.grid_dim(2);
config.blockDimX = wl.block_dim(0);
config.blockDimY = wl.block_dim(1);
config.blockDimZ = wl.block_dim(2);
config.sharedMemBytes = wl.dyn_shmem_size;
config.hStream = strm;
config.attrs = attrs.empty() ? nullptr : attrs.data();
config.numAttrs = static_cast<unsigned int>(attrs.size());
CUresult result = cuLaunchKernelEx(&config, fcache_[device_id], void_args, nullptr);
Here wl is the resolved workload (the grid/block/cluster extents derived from
the launch parameters), fcache_[device_id] is the cached CUfunction, and
void_args are the kernel arguments — the data pointers plus scalars like the
symbolic n.