DLight: Rule-Based GPU Scheduling
TIR functions produced by Relax legalization need GPU-specific scheduling — thread binding, loop tiling, shared memory usage — before they can run efficiently on a GPU. There are two main approaches in TVM:
- MetaSchedule: explores a search space to find the best schedule. High quality, but compilation takes minutes to hours.
- DLight: applies pre-defined scheduling rules deterministically. No tuning is required and compilation completes in seconds. Performance is excellent for well-known patterns (e.g., GEMM and GEMV in LLM workloads) and fair for the rest.
This tutorial covers how DLight works, what rules are available, how to diagnose scheduling quality, and how to write custom rules.
Prepare a Model
We build a small model with nn.Module that is rich enough to trigger multiple DLight
rules: Linear layers produce GEMM (matrix multiplication) kernels, LayerNorm
produces a general-reduction kernel, and ReLU is a simple elementwise op.
import tvm
from tvm import relax, tirx
from tvm.relax.frontend import nn
from tvm.s_tir import dlight as dl
class DemoModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(768, 768)
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(768)
        self.fc2 = nn.Linear(768, 256)

    def forward(self, x):
        x = self.norm(self.relu(self.fc1(x)))
        return self.fc2(x)

mod, params = DemoModel().export_tvm({"forward": {"x": nn.spec.Tensor((1, 768), "float32")}})
Legalize Relax operators into TIR functions so that DLight has concrete kernels to schedule.
device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
with target:
    mod = relax.get_pipeline("zero")(mod)
At this point every TIR function in mod is unscheduled — it has no thread bindings
and would not run efficiently on a GPU. Let’s see what functions we have:
for gv, func in mod.functions_items():
    if isinstance(func, tirx.PrimFunc):
        print(f" {gv.name_hint}")
fused_matmul1_add1
fused_matmul_add_relu
layer_norm
transpose
transpose1
Basic Usage: ApplyDefaultSchedule
ApplyDefaultSchedule is an IRModule pass. It iterates over every TIR function in the
module and tries the given rules in order. For each function the first rule whose
apply() returns a non-None schedule wins; subsequent rules are skipped.
After scheduling, the function is marked with tirx.is_scheduled so it won’t be
scheduled again by a later ApplyDefaultSchedule call.
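Conceptually, the per-function dispatch can be sketched as follows (illustrative pseudocode only, not the actual pass source):

```python
# Illustrative sketch of ApplyDefaultSchedule's first-match dispatch.
def apply_first_matching_rule(func, target, rules):
    for rule in rules:
        sch = rule.apply(func, target, tunable=False)  # None means "rule does not match"
        if sch is not None:
            return sch  # first match wins; remaining rules are skipped
    return None  # unmatched: stays unscheduled unless a Fallback rule is in the chain
```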
Here we use a common subset of rules. The full catalog (including LowBatchGEMV,
Transpose, RMSNorm) is listed in the next section.
with target:
    scheduled_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),            # GEMM: dense matrix multiplication
        dl.gpu.GEMV(),              # matrix-vector products
        dl.gpu.Reduction(),         # simple reductions (sum, max, ...)
        dl.gpu.GeneralReduction(),  # compound reductions (softmax, layer norm, ...)
        dl.gpu.Fallback(),          # catch-all for anything unmatched above
    )(mod)
scheduled_mod.show()
# from tvm.script import ir as I
# from tvm.script import tirx as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def fused_matmul1_add1(layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims1: T.Buffer((T.int64(768), T.int64(256)), "float32"), fc2_bias: T.Buffer((T.int64(256),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(256)), scope="local")
matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(256)), scope="local")
for ax0_fused_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
with T.sblock("matmul_rf_init"):
vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
T.reads()
T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
for ax1_fused_0, u in T.grid(T.int64(48), 1):
with T.sblock("matmul_rf_update"):
vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
with T.sblock("matmul"):
vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax1_fused)
T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
T.writes(matmul_intermediate_local[T.int64(0), v0])
with T.init():
matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax0_fused_1 in range(T.int64(1)):
with T.sblock("T_add"):
v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
T.reads(matmul_intermediate_local[T.int64(0), v0], fc2_bias[v0])
T.writes(T_add_intermediate[T.int64(0), v0])
T_add_intermediate[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + fc2_bias[v0]
@T.prim_func(private=True)
def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims: T.Buffer((T.int64(768), T.int64(768)), "float32"), fc1_bias: T.Buffer((T.int64(768),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(768)), "float32")):
T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(768)), scope="local")
matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(768)), scope="local")
for ax0_fused_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
with T.sblock("matmul_rf_init"):
vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
T.reads()
T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
for ax1_fused_0, u in T.grid(T.int64(48), 1):
with T.sblock("matmul_rf_update"):
vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
with T.sblock("matmul"):
vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax1_fused)
T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
T.writes(matmul_intermediate_local[T.int64(0), v0])
with T.init():
matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax0_fused_1 in range(T.int64(1)):
with T.sblock("compute"):
v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
T.reads(matmul_intermediate_local[T.int64(0), v0], fc1_bias[v0])
T.writes(compute_intermediate[T.int64(0), v0])
compute_intermediate[T.int64(0), v0] = T.max(matmul_intermediate_local[T.int64(0), v0] + fc1_bias[v0], T.float32(0.0))
@T.prim_func(private=True)
def layer_norm(relu: T.Buffer((T.int64(1), T.int64(768)), "float32"), norm_weight: T.Buffer((T.int64(768),), "float32"), norm_bias: T.Buffer((T.int64(768),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32")):
T.func_attr({"op_pattern": 4, "tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
relu_red_temp_v0_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
relu_red_temp_v1_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
for ax0 in range(T.int64(1)):
for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
for ax1_fused_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.sblock("relu_red_temp"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.reduce(T.int64(768), ax1_fused_0 * T.int64(256) + ax1_fused_1)
T.reads(relu[T.int64(0), v1])
T.writes(relu_red_temp_v0_shared[T.int64(0)], relu_red_temp_v1_shared[T.int64(0)])
with T.init():
relu_red_temp_v0_shared[T.int64(0)] = T.float32(0.0)
relu_red_temp_v1_shared[T.int64(0)] = T.float32(0.0)
v_relu_red_temp_v0: T.float32 = relu_red_temp_v0_shared[T.int64(0)] + relu[T.int64(0), v1]
v_relu_red_temp_v1: T.float32 = relu_red_temp_v1_shared[T.int64(0)] + relu[T.int64(0), v1] * relu[T.int64(0), v1]
relu_red_temp_v0_shared[T.int64(0)] = v_relu_red_temp_v0
relu_red_temp_v1_shared[T.int64(0)] = v_relu_red_temp_v1
for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
for ax1_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.sblock("T_layer_norm"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(768), ax1_0 * T.int64(256) + ax1_1)
T.reads(relu[T.int64(0), v1], relu_red_temp_v0_shared[T.int64(0)], relu_red_temp_v1_shared[T.int64(0)], norm_weight[v1], norm_bias[v1])
T.writes(T_layer_norm[T.int64(0), v1])
T_layer_norm[T.int64(0), v1] = (relu[T.int64(0), v1] - relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0)) * T.rsqrt(relu_red_temp_v1_shared[T.int64(0)] / T.float32(768.0) - relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0) * (relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0)) + T.float32(1.0000000000000001e-05)) * norm_weight[v1] + norm_bias[v1]
@T.prim_func(private=True)
def transpose(fc1_weight: T.Buffer((T.int64(768), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(768)), "float32")):
T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
for ax0_ax1_fused_0 in T.thread_binding(T.int64(576), thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
with T.sblock("T_transpose"):
v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(768))
v1 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(768))
T.reads(fc1_weight[v1, v0])
T.writes(T_transpose[v0, v1])
T_transpose[v0, v1] = fc1_weight[v1, v0]
@T.prim_func(private=True)
def transpose1(fc2_weight: T.Buffer((T.int64(256), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(256)), "float32")):
T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
for ax0_ax1_fused_0 in T.thread_binding(T.int64(192), thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
with T.sblock("T_transpose"):
v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(256))
v1 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(256))
T.reads(fc2_weight[v1, v0])
T.writes(T_transpose[v0, v1])
T_transpose[v0, v1] = fc2_weight[v1, v0]
@R.function
def forward(x: R.Tensor((1, 768), dtype="float32"), fc1_weight: R.Tensor((768, 768), dtype="float32"), fc1_bias: R.Tensor((768,), dtype="float32"), norm_weight: R.Tensor((768,), dtype="float32"), norm_bias: R.Tensor((768,), dtype="float32"), fc2_weight: R.Tensor((256, 768), dtype="float32"), fc2_bias: R.Tensor((256,), dtype="float32")) -> R.Tensor((1, 256), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((768, 768), dtype="float32"))
lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((1, 768), dtype="float32"))
layer_norm = R.call_tir(cls.layer_norm, (lv, norm_weight, norm_bias), out_sinfo=R.Tensor((1, 768), dtype="float32"))
permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((768, 256), dtype="float32"))
gv = R.call_tir(cls.fused_matmul1_add1, (layer_norm, permute_dims1, fc2_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
R.output(gv)
return gv
Compared with the unscheduled IR, you can now see thread bindings
(blockIdx.x, threadIdx.x, …) and loop transformations in each TIR function.
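You can also confirm programmatically that every function was claimed by checking the tirx.is_scheduled attribute the pass sets:

```python
# Verify that each TIR function now carries the tirx.is_scheduled marker.
for gv, func in scheduled_mod.functions_items():
    if isinstance(func, tirx.PrimFunc):
        print(f"  {gv.name_hint}: scheduled={'tirx.is_scheduled' in func.attrs}")
```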
Rule Catalog
DLight ships a set of GPU scheduling rules. Each rule is a subclass of
ScheduleRule and implements an apply(func, target, tunable) method that returns
a Schedule if the rule matches, or None to pass.
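For example, you can probe a single rule by hand; a non-None result means the rule matched (an illustrative check, using the layer_norm kernel from above while it is still unscheduled):

```python
# Probe one rule against one PrimFunc directly (illustrative).
sch = dl.gpu.GeneralReduction().apply(mod["layer_norm"], target, tunable=False)
print("matched" if sch is not None else "passed (returned None)")
```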
The built-in GPU rules, roughly from most specific to most general:
| Rule | Pattern | Typical operators |
|---|---|---|
| `Matmul` | GEMM index pattern | dense matrix multiplication |
| `GEMV` | Matrix-vector multiply (one dimension is 1) | single-batch decode in attention |
| `LowBatchGEMV` | Low-batch GEMM scheduled with a GEMV strategy | small-batch decode |
| `Reduction` | Simple accumulation | sum, max, argmax |
| `GeneralReduction` | Spatial dims followed by reduction dims | softmax, layer norm, RMS norm |
| `Transpose` | Read/write indices are permutations of each other | 2-D transpose |
| `RMSNorm` | Contains an `rsqrt` computation | RMS normalization |
| `Fallback` | Any function (always matches) | generic catch-all |
Rule order matters. `ApplyDefaultSchedule` stops at the first match, so:

- Put specialized rules first (`Matmul`, `GEMV`): they have strict matching conditions but produce high-quality schedules.
- Put general rules later (`GeneralReduction`, `Fallback`): they match broadly but produce less optimal schedules.
- If you put `Fallback` first, it would "steal" every function and no specialized rule would ever run, as the sketch below illustrates.
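As a quick illustration (`fresh_mod` stands in for a freshly legalized, unscheduled module, since functions already marked scheduled are skipped):

```python
# Anti-pattern: Fallback first claims every function, so Matmul never runs.
with target:
    bad_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Fallback(),  # always matches; every function stops here
        dl.gpu.Matmul(),    # never reached
    )(fresh_mod)
```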
Diagnosing Schedule Quality
A common question is: which rule scheduled which function? ApplyDefaultSchedule
does not log this directly, but you can figure it out by applying rules one at a time.
Step 1: Apply each rule individually and record which functions it claims.
from collections import OrderedDict

rules = OrderedDict(
    [
        ("Matmul", dl.gpu.Matmul()),
        ("GEMV", dl.gpu.GEMV()),
        ("LowBatchGEMV", dl.gpu.LowBatchGEMV()),
        ("Reduction", dl.gpu.Reduction()),
        ("GeneralReduction", dl.gpu.GeneralReduction()),
        ("Transpose", dl.gpu.Transpose()),
        ("RMSNorm", dl.gpu.RMSNorm()),
    ]
)

rule_assignment = {}
for rule_name, rule in rules.items():
    with target:
        test_mod = dl.ApplyDefaultSchedule(rule)(mod)
    for gv, func in test_mod.functions_items():
        if isinstance(func, tirx.PrimFunc) and gv.name_hint not in rule_assignment:
            if "tirx.is_scheduled" in func.attrs and func.attrs["tirx.is_scheduled"] == 1:
                rule_assignment[gv.name_hint] = rule_name
Step 2: Functions not claimed by any specialized rule will fall through to Fallback.
all_tir_funcs = [
    gv.name_hint for gv, func in mod.functions_items() if isinstance(func, tirx.PrimFunc)
]
fallback_funcs = [name for name in all_tir_funcs if name not in rule_assignment]

print("Rule assignments:")
for name, rule_name in sorted(rule_assignment.items()):
    print(f"  {name:40s} -> {rule_name}")
if fallback_funcs:
    print("Handled by Fallback (may have suboptimal performance):")
    for name in sorted(fallback_funcs):
        print(f"  {name}")
Rule assignments:
  fused_matmul1_add1                       -> Matmul
  fused_matmul_add_relu                    -> Matmul
  layer_norm                               -> Matmul
  transpose                                -> Matmul
  transpose1                               -> Matmul
If an important kernel lands in the Fallback bucket, you have three options:

- Write a custom DLight rule for it (see below).
- Use MetaSchedule to auto-tune that specific function.
- Manually schedule it with the `tvm.s_tir.Schedule` API (a sketch follows this list).
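A minimal sketch of the third option, hand-scheduling the transpose kernel from this module (assuming an unscheduled copy; the split factor of 256 is illustrative):

```python
# Manually tile and bind the transpose kernel with the s_tir Schedule API.
from tvm import s_tir

sch = s_tir.Schedule(mod["transpose"])
block = sch.get_block("T_transpose")
fused = sch.fuse(*sch.get_loops(block))         # flatten the 2-D loop nest
bx, tx = sch.split(fused, factors=[None, 256])  # 256 threads per block
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
```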
DLight vs MetaSchedule
The two systems are complementary, not competing:
|  | DLight | MetaSchedule |
|---|---|---|
| Mechanism | Deterministic rule matching | Search-space exploration |
| Compile time | Seconds | Minutes to hours |
| Performance | Excellent on known patterns, fair otherwise | Near-optimal with sufficient search budget |
| Best for | Default path, rapid iteration, CI | Hot-spot tuning in production |
A practical workflow:

1. Run `ApplyDefaultSchedule` with the full rule set to cover all functions.
2. Profile the compiled model to identify hot-spot kernels.
3. Use `MetaScheduleTuneTIR` to auto-tune only those kernels (see the sketch after the note below).
Note that MetaScheduleTuneTIR does not automatically skip functions already
scheduled by DLight — it processes every PrimFunc in the module. In practice this
is harmless (tuning an already-scheduled function simply re-explores its space), but if
you want to avoid the extra search cost, filter the module or use MetaScheduleTuneIRMod
with op_names to target specific functions.
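In sketch form, the tuning step might look like this (the work directory and trial budget below are illustrative values):

```python
# Auto-tune the TIR functions with MetaSchedule. Per the note above, this
# searches over every PrimFunc in the module it receives.
with target:
    tuned_mod = relax.transform.MetaScheduleTuneTIR(
        work_dir="./tuning_logs",  # illustrative path for tuning records
        max_trials_global=64,      # illustrative (small) search budget
    )(scheduled_mod)
```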
Writing a Custom Rule
You can extend DLight by writing your own ScheduleRule. The simplest way is
ScheduleRule.from_callable, which wraps a plain function into a rule instance.
from tvm import s_tir
from tvm.s_tir.dlight.analysis import normalize_prim_func
from tvm.s_tir.dlight.base.schedule_rule import ScheduleRule


@ScheduleRule.from_callable("MyTileAndBind")
def my_tile_and_bind(func: tirx.PrimFunc, target: tvm.target.Target, tunable: bool):
    """A minimal rule: for single-block injective functions, tile and bind to GPU threads."""
    if not isinstance(func, tirx.PrimFunc):
        return None
    sch = s_tir.Schedule(func)
    # Use normalize_prim_func to get block info with correct spatial/reduction classification.
    # This is the same analysis used by built-in DLight rules.
    block_infos = normalize_prim_func(sch)
    if block_infos is None or len(block_infos) != 1:
        return None  # only handle single-block functions
    info = block_infos[0]
    if not info.is_injective():
        return None  # skip reductions: dom_kind() uses iter_type, not loop kind
    loops = sch.get_loops(info.block_rv)
    if len(loops) == 0:
        return None
    fused = sch.fuse(*loops)
    bx, tx = sch.split(fused, factors=[None, 256])
    sch.bind(bx, "blockIdx.x")
    sch.bind(tx, "threadIdx.x")
    return sch
Insert the custom rule into the rule chain. Note that from_callable returns an
instance, so pass it directly — do not call my_tile_and_bind() again.
with target:
    custom_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.GeneralReduction(),
        my_tile_and_bind,  # our custom rule, tried before Fallback
        dl.gpu.Fallback(),
    )(mod)
custom_mod.show()
# from tvm.script import ir as I
# from tvm.script import tirx as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def fused_matmul1_add1(layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims1: T.Buffer((T.int64(768), T.int64(256)), "float32"), fc2_bias: T.Buffer((T.int64(256),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(256)), scope="local")
matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(256)), scope="local")
for ax0_fused_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
with T.sblock("matmul_rf_init"):
vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
T.reads()
T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
for ax1_fused_0, u in T.grid(T.int64(48), 1):
with T.sblock("matmul_rf_update"):
vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
with T.sblock("matmul"):
vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax1_fused)
T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
T.writes(matmul_intermediate_local[T.int64(0), v0])
with T.init():
matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax0_fused_1 in range(T.int64(1)):
with T.sblock("T_add"):
v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
T.reads(matmul_intermediate_local[T.int64(0), v0], fc2_bias[v0])
T.writes(T_add_intermediate[T.int64(0), v0])
T_add_intermediate[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + fc2_bias[v0]
@T.prim_func(private=True)
def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims: T.Buffer((T.int64(768), T.int64(768)), "float32"), fc1_bias: T.Buffer((T.int64(768),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(768)), "float32")):
T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(768)), scope="local")
matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(768)), scope="local")
for ax0_fused_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
with T.sblock("matmul_rf_init"):
vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
T.reads()
T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
for ax1_fused_0, u in T.grid(T.int64(48), 1):
with T.sblock("matmul_rf_update"):
vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
with T.sblock("matmul"):
vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax1_fused)
T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
T.writes(matmul_intermediate_local[T.int64(0), v0])
with T.init():
matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for ax0_fused_1 in range(T.int64(1)):
with T.sblock("compute"):
v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
T.reads(matmul_intermediate_local[T.int64(0), v0], fc1_bias[v0])
T.writes(compute_intermediate[T.int64(0), v0])
compute_intermediate[T.int64(0), v0] = T.max(matmul_intermediate_local[T.int64(0), v0] + fc1_bias[v0], T.float32(0.0))
@T.prim_func(private=True)
def layer_norm(relu: T.Buffer((T.int64(1), T.int64(768)), "float32"), norm_weight: T.Buffer((T.int64(768),), "float32"), norm_bias: T.Buffer((T.int64(768),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32")):
T.func_attr({"op_pattern": 4, "tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
relu_red_temp_v0_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
relu_red_temp_v1_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
for ax0 in range(T.int64(1)):
for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
for ax1_fused_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.sblock("relu_red_temp"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.reduce(T.int64(768), ax1_fused_0 * T.int64(256) + ax1_fused_1)
T.reads(relu[T.int64(0), v1])
T.writes(relu_red_temp_v0_shared[T.int64(0)], relu_red_temp_v1_shared[T.int64(0)])
with T.init():
relu_red_temp_v0_shared[T.int64(0)] = T.float32(0.0)
relu_red_temp_v1_shared[T.int64(0)] = T.float32(0.0)
v_relu_red_temp_v0: T.float32 = relu_red_temp_v0_shared[T.int64(0)] + relu[T.int64(0), v1]
v_relu_red_temp_v1: T.float32 = relu_red_temp_v1_shared[T.int64(0)] + relu[T.int64(0), v1] * relu[T.int64(0), v1]
relu_red_temp_v0_shared[T.int64(0)] = v_relu_red_temp_v0
relu_red_temp_v1_shared[T.int64(0)] = v_relu_red_temp_v1
for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
for ax1_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.sblock("T_layer_norm"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(768), ax1_0 * T.int64(256) + ax1_1)
T.reads(relu[T.int64(0), v1], relu_red_temp_v0_shared[T.int64(0)], relu_red_temp_v1_shared[T.int64(0)], norm_weight[v1], norm_bias[v1])
T.writes(T_layer_norm[T.int64(0), v1])
T_layer_norm[T.int64(0), v1] = (relu[T.int64(0), v1] - relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0)) * T.rsqrt(relu_red_temp_v1_shared[T.int64(0)] / T.float32(768.0) - relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0) * (relu_red_temp_v0_shared[T.int64(0)] / T.float32(768.0)) + T.float32(1.0000000000000001e-05)) * norm_weight[v1] + norm_bias[v1]
@T.prim_func(private=True)
def transpose(fc1_weight: T.Buffer((T.int64(768), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(768)), "float32")):
T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
for ax0_ax1_fused_0 in T.thread_binding(T.int64(576), thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
with T.sblock("T_transpose"):
v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(768))
v1 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(768))
T.reads(fc1_weight[v1, v0])
T.writes(T_transpose[v0, v1])
T_transpose[v0, v1] = fc1_weight[v1, v0]
@T.prim_func(private=True)
def transpose1(fc2_weight: T.Buffer((T.int64(256), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(256)), "float32")):
T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
# with T.sblock("root"):
for ax0_ax1_fused_0 in T.thread_binding(T.int64(192), thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
with T.sblock("T_transpose"):
v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(256))
v1 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(256))
T.reads(fc2_weight[v1, v0])
T.writes(T_transpose[v0, v1])
T_transpose[v0, v1] = fc2_weight[v1, v0]
@R.function
def forward(x: R.Tensor((1, 768), dtype="float32"), fc1_weight: R.Tensor((768, 768), dtype="float32"), fc1_bias: R.Tensor((768,), dtype="float32"), norm_weight: R.Tensor((768,), dtype="float32"), norm_bias: R.Tensor((768,), dtype="float32"), fc2_weight: R.Tensor((256, 768), dtype="float32"), fc2_bias: R.Tensor((256,), dtype="float32")) -> R.Tensor((1, 256), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((768, 768), dtype="float32"))
lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((1, 768), dtype="float32"))
layer_norm = R.call_tir(cls.layer_norm, (lv, norm_weight, norm_bias), out_sinfo=R.Tensor((1, 768), dtype="float32"))
permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((768, 256), dtype="float32"))
gv = R.call_tir(cls.fused_matmul1_add1, (layer_norm, permute_dims1, fc2_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
R.output(gv)
return gv
To build a production-quality rule, subclass ScheduleRule directly and implement
apply() with full analysis logic (see tvm.s_tir.dlight.gpu.Matmul for an example).
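A skeleton of that shape might look like the following (a hedged sketch with the analysis elided; `MyProductionRule` is a hypothetical name):

```python
# Subclass-based rule skeleton; see tvm.s_tir.dlight.gpu.Matmul for a real one.
class MyProductionRule(ScheduleRule):
    def apply(self, func: tirx.PrimFunc, target: tvm.target.Target, tunable: bool):
        if not isinstance(func, tirx.PrimFunc):
            return None
        sch = s_tir.Schedule(func)
        # ... analyze block structure; return None if the target pattern is absent ...
        # ... otherwise tile loops, bind threads, and add caching stages ...
        return sch
```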
Summary
- DLight provides fast, deterministic GPU scheduling via rule matching.
- Rules are tried in order; the first match wins. Put specialized rules before general ones.
- Use the single-rule probing technique to diagnose which rule handles each function.
- Combine DLight with MetaSchedule: DLight for baseline coverage, MetaSchedule for hot-spot tuning.
- Extend DLight by writing custom `ScheduleRule` implementations.
For DLight’s role in the broader optimization pipeline, see Customize Optimization.