DLight: Rule-Based GPU Scheduling

TIR functions produced by Relax legalization need GPU-specific scheduling — thread binding, loop tiling, shared memory usage — before they can run efficiently on a GPU. There are two main approaches in TVM:

  • MetaSchedule: explores a search space to find the best schedule. High quality, but compilation takes minutes to hours.

  • DLight: applies pre-defined scheduling rules deterministically. No tuning required, compilation completes in seconds. Performance is excellent for well-known patterns (e.g., GEMM, GEMV in LLM workloads) and fair for the rest.

This tutorial covers how DLight works, what rules are available, how to diagnose scheduling quality, and how to write custom rules.

Prepare a Model

We build a small model with nn.Module that is rich enough to trigger multiple DLight rules: Linear layers produce GEMM (matrix multiplication) kernels, LayerNorm produces a general-reduction kernel, and ReLU is a simple elementwise op.

import tvm
from tvm import relax, tirx
from tvm.relax.frontend import nn
from tvm.s_tir import dlight as dl


class DemoModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(768, 768)
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(768)
        self.fc2 = nn.Linear(768, 256)

    def forward(self, x):
        x = self.norm(self.relu(self.fc1(x)))
        return self.fc2(x)


mod, params = DemoModel().export_tvm({"forward": {"x": nn.spec.Tensor((1, 768), "float32")}})

Legalize Relax operators into TIR functions so that DLight has concrete kernels to schedule.

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
with target:
    mod = relax.get_pipeline("zero")(mod)

At this point every TIR function in mod is unscheduled — it has no thread bindings and would not run efficiently on a GPU. Let’s see what functions we have:

for gv, func in mod.functions_items():
    if isinstance(func, tirx.PrimFunc):
        print(f"  {gv.name_hint}")
fused_matmul1_add1
fused_matmul_add_relu
layer_norm
transpose
transpose1

Basic Usage: ApplyDefaultSchedule

ApplyDefaultSchedule is an IRModule pass. It iterates over every TIR function in the module and tries the given rules in order. For each function the first rule whose apply() returns a non-None schedule wins; subsequent rules are skipped. After scheduling, the function is marked with tirx.is_scheduled so it won’t be scheduled again by a later ApplyDefaultSchedule call.

Here we use a common subset of rules. The full catalog (including LowBatchGEMV, Transpose, RMSNorm) is listed in the next section.

with target:
    scheduled_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),  # GEMM: dense matrix multiplication
        dl.gpu.GEMV(),  # matrix-vector products
        dl.gpu.Reduction(),  # simple reductions (sum, max, ...)
        dl.gpu.GeneralReduction(),  # compound reductions (softmax, layer norm, ...)
        dl.gpu.Fallback(),  # catch-all for anything unmatched above
    )(mod)

scheduled_mod.show()
# from tvm.script import ir as I
# from tvm.script import tirx as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims1: T.Buffer((T.int64(768), T.int64(256)), "float32"), fc2_bias: T.Buffer((T.int64(256),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(256)), scope="local")
        matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(256)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.sblock("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                    for ax1_fused_0, u in T.grid(T.int64(48), 1):
                        with T.sblock("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.sblock("matmul"):
                        vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax1_fused)
                        T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul_intermediate_local[T.int64(0), v0])
                        with T.init():
                            matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
                        matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
            for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0_fused_1 in range(T.int64(1)):
                    with T.sblock("T_add"):
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                        T.reads(matmul_intermediate_local[T.int64(0), v0], fc2_bias[v0])
                        T.writes(T_add_intermediate[T.int64(0), v0])
                        T_add_intermediate[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + fc2_bias[v0]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims: T.Buffer((T.int64(768), T.int64(768)), "float32"), fc1_bias: T.Buffer((T.int64(768),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(768)), "float32")):
        T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(768)), scope="local")
        matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(768)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.sblock("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                    for ax1_fused_0, u in T.grid(T.int64(48), 1):
                        with T.sblock("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.sblock("matmul"):
                        vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                        v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax1_fused)
                        T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul_intermediate_local[T.int64(0), v0])
                        with T.init():
                            matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
                        matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
            for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0_fused_1 in range(T.int64(1)):
                    with T.sblock("compute"):
                        v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                        T.reads(matmul_intermediate_local[T.int64(0), v0], fc1_bias[v0])
                        T.writes(compute_intermediate[T.int64(0), v0])
                        compute_intermediate[T.int64(0), v0] = T.max(matmul_intermediate_local[T.int64(0), v0] + fc1_bias[v0], T.float32(0.0))

    @T.prim_func(private=True)
    def layer_norm(relu: T.Buffer((T.int64(1), T.int64(768)), "float32"), norm_weight: T.Buffer((T.int64(768),), "float32"), norm_bias: T.Buffer((T.int64(768),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32")):
        T.func_attr({"op_pattern": 4, "tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        relu_red_temp_v0_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
        relu_red_temp_v1_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
        for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
            for ax0 in range(T.int64(1)):
                for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                    for ax1_fused_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        with T.sblock("relu_red_temp"):
                            v0 = T.axis.spatial(T.int64(1), ax0)
                            v1 = T.axis.reduce(T.int64(768), ax1_fused_0 * T.int64(256) + ax1_fused_1)
                            T.reads(relu[T.int64(0), v1])
                            T.writes(relu_red_temp_v0_shared[T.int64(0)], relu_red_temp_v1_shared[T.int64(0)])
                            with T.init():
                                relu_red_temp_v0_shared[T.int64(0)] = T.float32(0.0)
                                relu_red_temp_v1_shared[T.int64(0)] = T.float32(0.0)
                            v_relu_red_temp_v0: T.float32 = relu_red_temp_v0_shared[T.int64(0)] + relu[T.int64(0), v1]
                            v_relu_red_temp_v1: T.float32 = relu_red_temp_v1_shared[T.int64(0)] + relu[T.int64(0), v1] * relu[T.int64(0), v1]
                            relu_red_temp_v0_shared[T.int64(0)] = v_relu_red_temp_v0
                            relu_red_temp_v1_shared[T.int64(0)] = v_relu_red_temp_v1
            for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                for ax1_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                    with T.sblock("T_layer_norm"):
                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                        v1 = T.axis.spatial(T.int64(768), ax1_0 * T.int64(256) + ax1_1)
                        T.reads(relu[T.int64(0), v1], relu_red_temp_v0_shared[T.int64(0)], relu_red_temp_v1_shared[T.int64(0)], norm_weight[v1], norm_bias[v1])
                        T.writes(T_layer_norm[T.int64(0), v1])
                        T_layer_norm[T.int64(0), v1] = (relu[T.int64(0), v1] - relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0)) * T.rsqrt(relu_red_temp_v1_shared[T.int64(0)] / T.float32(768.0) - relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0) * (relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0)) + T.float32(1.0000000000000001e-05)) * norm_weight[v1] + norm_bias[v1]

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(768), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(768)), "float32")):
        T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(576), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.sblock("T_transpose"):
                    v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(768))
                    v1 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(768))
                    T.reads(fc1_weight[v1, v0])
                    T.writes(T_transpose[v0, v1])
                    T_transpose[v0, v1] = fc1_weight[v1, v0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(256), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(256)), "float32")):
        T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(192), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.sblock("T_transpose"):
                    v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(256))
                    v1 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(256))
                    T.reads(fc2_weight[v1, v0])
                    T.writes(T_transpose[v0, v1])
                    T_transpose[v0, v1] = fc2_weight[v1, v0]

    @R.function
    def forward(x: R.Tensor((1, 768), dtype="float32"), fc1_weight: R.Tensor((768, 768), dtype="float32"), fc1_bias: R.Tensor((768,), dtype="float32"), norm_weight: R.Tensor((768,), dtype="float32"), norm_bias: R.Tensor((768,), dtype="float32"), fc2_weight: R.Tensor((256, 768), dtype="float32"), fc2_bias: R.Tensor((256,), dtype="float32")) -> R.Tensor((1, 256), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((768, 768), dtype="float32"))
            lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((1, 768), dtype="float32"))
            layer_norm = R.call_tir(cls.layer_norm, (lv, norm_weight, norm_bias), out_sinfo=R.Tensor((1, 768), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((768, 256), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (layer_norm, permute_dims1, fc2_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            R.output(gv)
        return gv

Compared with the unscheduled IR, you can now see thread bindings (blockIdx.x, threadIdx.x, …) and loop transformations in each TIR function.

Rule Catalog

DLight ships a set of GPU scheduling rules. Each rule is a subclass of ScheduleRule and implements an apply(func, target, tunable) method that returns a Schedule if the rule matches, or None to pass.

The built-in GPU rules, roughly from most specific to most general:

Rule

Pattern

Typical operators

Matmul

GEMM index pattern C[S,I,J] += A[S,I,K] * B[S,J,K]

nn.Linear, batched matmul

GEMV

Matrix-vector multiply (one dimension is 1)

single-batch decode in attention

LowBatchGEMV

Low-batch GEMM scheduled with a GEMV strategy

small-batch decode

Reduction

Simple accumulation X[...] += Y[...]

sum, max, argmax

GeneralReduction

Spatial dims followed by reduction dims (S* R*)

softmax, layer norm, RMS norm

Transpose

Read/write indices are permutations of each other

2-D transpose

RMSNorm

Contains an rsqrt operation

RMS normalization

Fallback

Any function (always matches)

generic catch-all

Rule order matters. ApplyDefaultSchedule stops at the first match, so:

  • Put specialized rules first (Matmul, GEMV) — they have strict matching conditions but produce high-quality schedules.

  • Put general rules later (GeneralReduction, Fallback) — they match broadly but with less optimal schedules.

  • If you put Fallback first, it would “steal” every function and no specialized rule would ever run.

Diagnosing Schedule Quality

A common question is: which rule scheduled which function? ApplyDefaultSchedule does not log this directly, but you can figure it out by applying rules one at a time.

Step 1: Apply each rule individually and record which functions it claims.

from collections import OrderedDict

rules = OrderedDict(
    [
        ("Matmul", dl.gpu.Matmul()),
        ("GEMV", dl.gpu.GEMV()),
        ("LowBatchGEMV", dl.gpu.LowBatchGEMV()),
        ("Reduction", dl.gpu.Reduction()),
        ("GeneralReduction", dl.gpu.GeneralReduction()),
        ("Transpose", dl.gpu.Transpose()),
        ("RMSNorm", dl.gpu.RMSNorm()),
    ]
)

rule_assignment = {}
for rule_name, rule in rules.items():
    with target:
        test_mod = dl.ApplyDefaultSchedule(rule)(mod)
    for gv, func in test_mod.functions_items():
        if isinstance(func, tirx.PrimFunc) and gv.name_hint not in rule_assignment:
            if "tirx.is_scheduled" in func.attrs and func.attrs["tirx.is_scheduled"] == 1:
                rule_assignment[gv.name_hint] = rule_name

Step 2: Functions not claimed by any specialized rule will fall through to Fallback.

all_tir_funcs = [
    gv.name_hint for gv, func in mod.functions_items() if isinstance(func, tirx.PrimFunc)
]
fallback_funcs = [name for name in all_tir_funcs if name not in rule_assignment]

print("Rule assignments:")
for name, rule_name in sorted(rule_assignment.items()):
    print(f"  {name:40s} -> {rule_name}")
if fallback_funcs:
    print("Handled by Fallback (may have suboptimal performance):")
    for name in sorted(fallback_funcs):
        print(f"  {name}")
Rule assignments:
  fused_matmul1_add1                       -> Matmul
  fused_matmul_add_relu                    -> Matmul
  layer_norm                               -> Matmul
  transpose                                -> Matmul
  transpose1                               -> Matmul

If an important kernel lands in the Fallback bucket, you have three options:

  1. Write a custom DLight rule for it (see below).

  2. Use MetaSchedule to auto-tune that specific function.

  3. Manually schedule it with the tvm.s_tir.Schedule API.

DLight vs MetaSchedule

The two systems are complementary, not competing:

DLight

MetaSchedule

Mechanism

Deterministic rule matching

Search-space exploration

Compile time

Seconds

Minutes to hours

Performance

Excellent on known patterns, fair otherwise

Near-optimal with sufficient search budget

Best for

Default path, rapid iteration, CI

Hot-spot tuning in production

A practical workflow:

  1. Run ApplyDefaultSchedule with the full rule set to cover all functions.

  2. Profile the compiled model to identify hot-spot kernels.

  3. Use MetaScheduleTuneTIR to auto-tune only those kernels.

Note that MetaScheduleTuneTIR does not automatically skip functions already scheduled by DLight — it processes every PrimFunc in the module. In practice this is harmless (tuning an already-scheduled function simply re-explores its space), but if you want to avoid the extra search cost, filter the module or use MetaScheduleTuneIRMod with op_names to target specific functions.

Writing a Custom Rule

You can extend DLight by writing your own ScheduleRule. The simplest way is ScheduleRule.from_callable, which wraps a plain function into a rule instance.

from tvm import s_tir
from tvm.s_tir.dlight.analysis import normalize_prim_func
from tvm.s_tir.dlight.base.schedule_rule import ScheduleRule


@ScheduleRule.from_callable("MyTileAndBind")
def my_tile_and_bind(func: tirx.PrimFunc, target: tvm.target.Target, tunable: bool):
    """A minimal rule: for single-block injective functions, tile and bind to GPU threads."""
    if not isinstance(func, tirx.PrimFunc):
        return None
    sch = s_tir.Schedule(func)
    # Use normalize_prim_func to get block info with correct spatial/reduction classification.
    # This is the same analysis used by built-in DLight rules.
    block_infos = normalize_prim_func(sch)
    if block_infos is None or len(block_infos) != 1:
        return None  # only handle single-block functions
    info = block_infos[0]
    if not info.is_injective():
        return None  # skip reductions — dom_kind() uses iter_type, not loop kind
    loops = sch.get_loops(info.block_rv)
    if len(loops) == 0:
        return None
    fused = sch.fuse(*loops)
    bx, tx = sch.split(fused, factors=[None, 256])
    sch.bind(bx, "blockIdx.x")
    sch.bind(tx, "threadIdx.x")
    return sch

Insert the custom rule into the rule chain. Note that from_callable returns an instance, so pass it directly — do not call my_tile_and_bind() again.

with target:
    custom_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.GeneralReduction(),
        my_tile_and_bind,  # our custom rule, tried before Fallback
        dl.gpu.Fallback(),
    )(mod)

custom_mod.show()
# from tvm.script import ir as I
# from tvm.script import tirx as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims1: T.Buffer((T.int64(768), T.int64(256)), "float32"), fc2_bias: T.Buffer((T.int64(256),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(256)), scope="local")
        matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(256)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.sblock("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                    for ax1_fused_0, u in T.grid(T.int64(48), 1):
                        with T.sblock("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.sblock("matmul"):
                        vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax1_fused)
                        T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul_intermediate_local[T.int64(0), v0])
                        with T.init():
                            matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
                        matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
            for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0_fused_1 in range(T.int64(1)):
                    with T.sblock("T_add"):
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                        T.reads(matmul_intermediate_local[T.int64(0), v0], fc2_bias[v0])
                        T.writes(T_add_intermediate[T.int64(0), v0])
                        T_add_intermediate[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + fc2_bias[v0]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims: T.Buffer((T.int64(768), T.int64(768)), "float32"), fc1_bias: T.Buffer((T.int64(768),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(768)), "float32")):
        T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(768)), scope="local")
        matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(768)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.sblock("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                    for ax1_fused_0, u in T.grid(T.int64(48), 1):
                        with T.sblock("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.sblock("matmul"):
                        vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                        v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax1_fused)
                        T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul_intermediate_local[T.int64(0), v0])
                        with T.init():
                            matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
                        matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
            for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0_fused_1 in range(T.int64(1)):
                    with T.sblock("compute"):
                        v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                        T.reads(matmul_intermediate_local[T.int64(0), v0], fc1_bias[v0])
                        T.writes(compute_intermediate[T.int64(0), v0])
                        compute_intermediate[T.int64(0), v0] = T.max(matmul_intermediate_local[T.int64(0), v0] + fc1_bias[v0], T.float32(0.0))

    @T.prim_func(private=True)
    def layer_norm(relu: T.Buffer((T.int64(1), T.int64(768)), "float32"), norm_weight: T.Buffer((T.int64(768),), "float32"), norm_bias: T.Buffer((T.int64(768),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32")):
        T.func_attr({"op_pattern": 4, "tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        relu_red_temp_v0_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
        relu_red_temp_v1_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
        for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
            for ax0 in range(T.int64(1)):
                for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                    for ax1_fused_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        with T.sblock("relu_red_temp"):
                            v0 = T.axis.spatial(T.int64(1), ax0)
                            v1 = T.axis.reduce(T.int64(768), ax1_fused_0 * T.int64(256) + ax1_fused_1)
                            T.reads(relu[T.int64(0), v1])
                            T.writes(relu_red_temp_v0_shared[T.int64(0)], relu_red_temp_v1_shared[T.int64(0)])
                            with T.init():
                                relu_red_temp_v0_shared[T.int64(0)] = T.float32(0.0)
                                relu_red_temp_v1_shared[T.int64(0)] = T.float32(0.0)
                            v_relu_red_temp_v0: T.float32 = relu_red_temp_v0_shared[T.int64(0)] + relu[T.int64(0), v1]
                            v_relu_red_temp_v1: T.float32 = relu_red_temp_v1_shared[T.int64(0)] + relu[T.int64(0), v1] * relu[T.int64(0), v1]
                            relu_red_temp_v0_shared[T.int64(0)] = v_relu_red_temp_v0
                            relu_red_temp_v1_shared[T.int64(0)] = v_relu_red_temp_v1
            for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                for ax1_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                    with T.sblock("T_layer_norm"):
                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                        v1 = T.axis.spatial(T.int64(768), ax1_0 * T.int64(256) + ax1_1)
                        T.reads(relu[T.int64(0), v1], relu_red_temp_v0_shared[T.int64(0)], relu_red_temp_v1_shared[T.int64(0)], norm_weight[v1], norm_bias[v1])
                        T.writes(T_layer_norm[T.int64(0), v1])
                        T_layer_norm[T.int64(0), v1] = (relu[T.int64(0), v1] - relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0)) * T.rsqrt(relu_red_temp_v1_shared[T.int64(0)] / T.float32(768.0) - relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0) * (relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0)) + T.float32(1.0000000000000001e-05)) * norm_weight[v1] + norm_bias[v1]

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(768), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(768)), "float32")):
        T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(576), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.sblock("T_transpose"):
                    v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(768))
                    v1 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(768))
                    T.reads(fc1_weight[v1, v0])
                    T.writes(T_transpose[v0, v1])
                    T_transpose[v0, v1] = fc1_weight[v1, v0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(256), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(256)), "float32")):
        T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
        # with T.sblock("root"):
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(192), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.sblock("T_transpose"):
                    v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(256))
                    v1 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(256))
                    T.reads(fc2_weight[v1, v0])
                    T.writes(T_transpose[v0, v1])
                    T_transpose[v0, v1] = fc2_weight[v1, v0]

    @R.function
    def forward(x: R.Tensor((1, 768), dtype="float32"), fc1_weight: R.Tensor((768, 768), dtype="float32"), fc1_bias: R.Tensor((768,), dtype="float32"), norm_weight: R.Tensor((768,), dtype="float32"), norm_bias: R.Tensor((768,), dtype="float32"), fc2_weight: R.Tensor((256, 768), dtype="float32"), fc2_bias: R.Tensor((256,), dtype="float32")) -> R.Tensor((1, 256), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((768, 768), dtype="float32"))
            lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((1, 768), dtype="float32"))
            layer_norm = R.call_tir(cls.layer_norm, (lv, norm_weight, norm_bias), out_sinfo=R.Tensor((1, 768), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((768, 256), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (layer_norm, permute_dims1, fc2_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            R.output(gv)
        return gv

To build a production-quality rule, subclass ScheduleRule directly and implement apply() with full analysis logic (see tvm.s_tir.dlight.gpu.Matmul for an example).

Summary

  • DLight provides fast, deterministic GPU scheduling via rule matching.

  • Rules are tried in order; the first match wins. Put specialized rules before general ones.

  • Use the single-rule probing technique to diagnose which rule handles each function.

  • Combine DLight with MetaSchedule: DLight for baseline coverage, MetaSchedule for hot-spot tuning.

  • Extend DLight by writing custom ScheduleRule implementations.

For DLight’s role in the broader optimization pipeline, see Customize Optimization.

Gallery generated by Sphinx-Gallery