
.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "deep_dive/tensor_ir/tutorials/dlight_gpu_scheduling.py"

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        You can click :ref:`here <sphx_glr_download_deep_dive_tensor_ir_tutorials_dlight_gpu_scheduling.py>` to run the Jupyter notebook locally.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_deep_dive_tensor_ir_tutorials_dlight_gpu_scheduling.py:


.. _dlight_gpu_scheduling:

DLight: Rule-Based GPU Scheduling
==================================
TIR functions produced by Relax legalization need GPU-specific scheduling (thread binding,
loop tiling, shared-memory usage) before they can run on a GPU: without thread bindings they
cannot be built for a GPU target at all, and without the rest they run slowly. There are two
main approaches in TVM:

- **MetaSchedule**: explores a search space to find the best schedule. High quality, but
  compilation takes minutes to hours.
- **DLight**: applies pre-defined scheduling rules deterministically. No tuning required,
  compilation completes in seconds. Performance is excellent for well-known patterns
  (e.g., GEMM, GEMV in LLM workloads) and fair for the rest.

This tutorial covers how DLight works, what rules are available, how to diagnose scheduling
quality, and how to write custom rules.

.. contents:: Table of Contents
    :local:
    :depth: 1

.. GENERATED FROM PYTHON SOURCE LINES 43-48

Prepare a Model
---------------
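We build a small model with ``nn.Module`` that is rich enough to trigger multiple DLight
rules: the ``Linear`` layers produce matrix-multiplication kernels (with a batch size of 1,
each has a single row, so it is effectively a matrix-vector product), ``LayerNorm`` produces
a general-reduction kernel, and ``ReLU`` is a simple elementwise op that gets fused into the
first matrix multiplication.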

.. GENERATED FROM PYTHON SOURCE LINES 48-70

.. code-block:: Python


    import tvm
    from tvm import relax, tirx
    from tvm.relax.frontend import nn
    from tvm.s_tir import dlight as dl


    class DemoModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(768, 768)
            self.relu = nn.ReLU()
            self.norm = nn.LayerNorm(768)
            self.fc2 = nn.Linear(768, 256)

        def forward(self, x):
            x = self.norm(self.relu(self.fc1(x)))
            return self.fc2(x)


    mod, params = DemoModel().export_tvm({"forward": {"x": nn.spec.Tensor((1, 768), "float32")}})








.. GENERATED FROM PYTHON SOURCE LINES 71-72

Run the ``zero`` pipeline, which legalizes Relax operators into TIR functions and fuses them
where possible, so that DLight has concrete kernels to schedule.

.. GENERATED FROM PYTHON SOURCE LINES 72-78

.. code-block:: Python


    device = tvm.cuda(0)
    target = tvm.target.Target.from_device(device)
    with target:
        mod = relax.get_pipeline("zero")(mod)








.. GENERATED FROM PYTHON SOURCE LINES 79-81

At this point every TIR function in ``mod`` is **unscheduled**: it has no thread bindings,
so it cannot yet be compiled for the GPU. Let's see which functions we have:

.. GENERATED FROM PYTHON SOURCE LINES 81-85

.. code-block:: Python

    for gv, func in mod.functions_items():
        if isinstance(func, tirx.PrimFunc):
            print(f"  {gv.name_hint}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      fused_matmul1_add1
      fused_matmul_add_relu
      layer_norm
      transpose
      transpose1




.. GENERATED FROM PYTHON SOURCE LINES 86-93

Basic Usage: ApplyDefaultSchedule
---------------------------------
``ApplyDefaultSchedule`` is an ``IRModule`` pass. It iterates over every TIR function in the
module and tries the given rules **in order**. For each function, the first rule whose
``apply()`` returns a non-``None`` schedule wins; subsequent rules are skipped, and a function
that no rule matches is left unchanged. After scheduling, the function is marked with
``tirx.is_scheduled``, and marked functions are skipped, so a function won't be
scheduled again by a later ``ApplyDefaultSchedule`` call.
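
The dispatch logic can be modeled in plain Python. This is a simplified sketch, not DLight's
implementation: the rules are hypothetical callables that return a schedule (a string here)
or ``None``, and each function is a dict whose ``attrs`` entry stands in for ``PrimFunc``
attributes.

.. code-block:: Python

    from typing import Callable, Dict, List, Optional

    # A hypothetical rule: returns a schedule description, or None if it does not match.
    Rule = Callable[[dict], Optional[str]]


    def apply_default_schedule(funcs: List[dict], rules: List[Rule]) -> Dict[str, str]:
        """Model of the pass: first matching rule wins; scheduled functions are skipped."""
        applied = {}
        for func in funcs:
            if func["attrs"].get("tirx.is_scheduled"):
                continue  # already scheduled by an earlier pass
            for rule in rules:
                sch = rule(func)
                if sch is not None:
                    func["attrs"]["tirx.is_scheduled"] = True  # mark it
                    applied[func["name"]] = sch
                    break  # later rules are never consulted for this function
            # A function that no rule matches stays unscheduled.
        return applied


    def matmul_rule(func: dict) -> Optional[str]:
        return "matmul schedule" if func["pattern"] == "gemm" else None


    def fallback_rule(func: dict) -> Optional[str]:
        return "fallback schedule"  # matches anything


    funcs = [
        {"name": "fused_matmul", "pattern": "gemm", "attrs": {}},
        {"name": "transpose", "pattern": "injective", "attrs": {}},
    ]
    first = apply_default_schedule(funcs, [matmul_rule, fallback_rule])
    second = apply_default_schedule(funcs, [fallback_rule])
    print(first)   # {'fused_matmul': 'matmul schedule', 'transpose': 'fallback schedule'}
    print(second)  # {}: both functions were already marked by the first call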

.. GENERATED FROM PYTHON SOURCE LINES 95-97

Here we use a common subset of rules. The full catalog (including ``LowBatchGEMV``,
``Transpose``, ``RMSNorm``) is listed in the next section.

.. GENERATED FROM PYTHON SOURCE LINES 97-109

.. code-block:: Python


    with target:
        scheduled_mod = dl.ApplyDefaultSchedule(
            dl.gpu.Matmul(),  # GEMM: dense matrix multiplication
            dl.gpu.GEMV(),  # matrix-vector products
            dl.gpu.Reduction(),  # simple reductions (sum, max, ...)
            dl.gpu.GeneralReduction(),  # compound reductions (softmax, layer norm, ...)
            dl.gpu.Fallback(),  # catch-all for anything unmatched above
        )(mod)

    scheduled_mod.show()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func(private=True, s_tir=True)
        def fused_matmul1_add1(layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims1: T.Buffer((T.int64(768), T.int64(256)), "float32"), fc2_bias: T.Buffer((T.int64(256),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
            T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(256)), scope="local")
            matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(256)), scope="local")
            for ax0_fused_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
                for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                        with T.sblock("matmul_rf_init"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            T.reads()
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                        for ax1_fused_0, u in T.grid(T.int64(48), 1):
                            with T.sblock("matmul_rf_update"):
                                vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                                v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                                vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
                                T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                                T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                                matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
                for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                        with T.sblock("matmul"):
                            vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax1_fused)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            T.writes(matmul_intermediate_local[T.int64(0), v0])
                            with T.init():
                                matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
                            matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
                for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax0_fused_1 in range(T.int64(1)):
                        with T.sblock("T_add"):
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                            T.reads(matmul_intermediate_local[T.int64(0), v0], fc2_bias[v0])
                            T.writes(T_add_intermediate[T.int64(0), v0])
                            T_add_intermediate[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + fc2_bias[v0]

        @T.prim_func(private=True, s_tir=True)
        def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims: T.Buffer((T.int64(768), T.int64(768)), "float32"), fc1_bias: T.Buffer((T.int64(768),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(768)), "float32")):
            T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(768)), scope="local")
            matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(768)), scope="local")
            for ax0_fused_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
                for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                        with T.sblock("matmul_rf_init"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            T.reads()
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                        for ax1_fused_0, u in T.grid(T.int64(48), 1):
                            with T.sblock("matmul_rf_update"):
                                vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                                v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                                vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
                                T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                                T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                                matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
                for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                        with T.sblock("matmul"):
                            vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                            v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax1_fused)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            T.writes(matmul_intermediate_local[T.int64(0), v0])
                            with T.init():
                                matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
                            matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
                for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax0_fused_1 in range(T.int64(1)):
                        with T.sblock("compute"):
                            v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                            T.reads(matmul_intermediate_local[T.int64(0), v0], fc1_bias[v0])
                            T.writes(compute_intermediate[T.int64(0), v0])
                            compute_intermediate[T.int64(0), v0] = T.max(matmul_intermediate_local[T.int64(0), v0] + fc1_bias[v0], T.float32(0.0))

        @T.prim_func(private=True, s_tir=True)
        def layer_norm(relu: T.Buffer((T.int64(1), T.int64(768)), "float32"), norm_weight: T.Buffer((T.int64(768),), "float32"), norm_bias: T.Buffer((T.int64(768),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32")):
            T.func_attr({"op_pattern": 4, "tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            relu_sum_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
            relu_var_sum_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
            for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                for ax0 in range(T.int64(1)):
                    for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                        for ax1_fused_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                            with T.sblock("relu_sum"):
                                v0 = T.axis.spatial(T.int64(1), ax0)
                                v1 = T.axis.reduce(T.int64(768), ax1_fused_0 * T.int64(256) + ax1_fused_1)
                                T.reads(relu[T.int64(0), v1])
                                T.writes(relu_sum_shared[T.int64(0)])
                                with T.init():
                                    relu_sum_shared[T.int64(0)] = T.float32(0.0)
                                relu_sum_shared[T.int64(0)] = relu_sum_shared[T.int64(0)] + relu[T.int64(0), v1]
                for ax0 in range(T.int64(1)):
                    for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                        for ax1_fused_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                            with T.sblock("relu_var_sum"):
                                v0 = T.axis.spatial(T.int64(1), ax0)
                                v1 = T.axis.reduce(T.int64(768), ax1_fused_0 * T.int64(256) + ax1_fused_1)
                                T.reads(relu[T.int64(0), v1], relu_sum_shared[T.int64(0)])
                                T.writes(relu_var_sum_shared[T.int64(0)])
                                with T.init():
                                    relu_var_sum_shared[T.int64(0)] = T.float32(0.0)
                                relu_var_sum_shared[T.int64(0)] = relu_var_sum_shared[T.int64(0)] + (relu[T.int64(0), v1] - relu_sum_shared[T.int64(0)] / T.float32(768.0)) * (relu[T.int64(0), v1] - relu_sum_shared[T.int64(0)] / T.float32(768.0))
                for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                    for ax1_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        with T.sblock("T_layer_norm"):
                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                            v1 = T.axis.spatial(T.int64(768), ax1_0 * T.int64(256) + ax1_1)
                            T.reads(relu[T.int64(0), v1], relu_sum_shared[T.int64(0)], relu_var_sum_shared[T.int64(0)], norm_weight[v1], norm_bias[v1])
                            T.writes(T_layer_norm[T.int64(0), v1])
                            T_layer_norm[T.int64(0), v1] = (relu[T.int64(0), v1] - relu_sum_shared[T.int64(0)] / T.float32(768.0)) * T.rsqrt(relu_var_sum_shared[T.int64(0)] / T.float32(768.0) + T.float32(1.0000000000000001e-05)) * norm_weight[v1] + norm_bias[v1]

        @T.prim_func(private=True, s_tir=True)
        def transpose(fc1_weight: T.Buffer((T.int64(768), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(768)), "float32")):
            T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            for ax0_ax1_fused_0 in T.thread_binding(T.int64(576), thread="blockIdx.x"):
                for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                    with T.sblock("T_transpose"):
                        v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(768))
                        v1 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(768))
                        T.reads(fc1_weight[v1, v0])
                        T.writes(T_transpose[v0, v1])
                        T_transpose[v0, v1] = fc1_weight[v1, v0]

        @T.prim_func(private=True, s_tir=True)
        def transpose1(fc2_weight: T.Buffer((T.int64(256), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(256)), "float32")):
            T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            for ax0_ax1_fused_0 in T.thread_binding(T.int64(192), thread="blockIdx.x"):
                for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                    with T.sblock("T_transpose"):
                        v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(256))
                        v1 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(256))
                        T.reads(fc2_weight[v1, v0])
                        T.writes(T_transpose[v0, v1])
                        T_transpose[v0, v1] = fc2_weight[v1, v0]

        @R.function
        def forward(x: R.Tensor((1, 768), dtype="float32"), fc1_weight: R.Tensor((768, 768), dtype="float32"), fc1_bias: R.Tensor((768,), dtype="float32"), norm_weight: R.Tensor((768,), dtype="float32"), norm_bias: R.Tensor((768,), dtype="float32"), fc2_weight: R.Tensor((256, 768), dtype="float32"), fc2_bias: R.Tensor((256,), dtype="float32")) -> R.Tensor((1, 256), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_ty=R.Tensor((768, 768), dtype="float32"))
                lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_ty=R.Tensor((1, 768), dtype="float32"))
                layer_norm = R.call_tir(cls.layer_norm, (lv, norm_weight, norm_bias), out_ty=R.Tensor((1, 768), dtype="float32"))
                permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_ty=R.Tensor((768, 256), dtype="float32"))
                gv = R.call_tir(cls.fused_matmul1_add1, (layer_norm, permute_dims1, fc2_bias), out_ty=R.Tensor((1, 256), dtype="float32"))
                R.output(gv)
            return gv





.. GENERATED FROM PYTHON SOURCE LINES 110-112

Compared with the unscheduled functions listed earlier, you can now see thread bindings
(``blockIdx.x``, ``threadIdx.x``, ...) and loop transformations in each TIR function.
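
The split reduction in the fused matmul kernels can be checked with a few lines of plain
Python. The sketch below replays the index arithmetic of ``fused_matmul1_add1`` (output
column ``n = bx * 16 + tx``, reduction index ``k = k0 * 16 + ty``) on small random integers
and compares the two-step result with a direct dot product. It models only the arithmetic,
not how the GPU executes the threads.

.. code-block:: Python

    import random

    K_SERIAL, K_LANES = 48, 16  # ax1_fused_0 (serial) x ax1_fused_1 (threadIdx.y): K = 768
    N_BLOCKS, N_THREADS = 16, 16  # ax0_fused_0 (blockIdx.x) x ax0_fused_1 (threadIdx.x): N = 256
    K, N = K_SERIAL * K_LANES, N_BLOCKS * N_THREADS

    random.seed(0)
    a = [random.randint(-3, 3) for _ in range(K)]  # the (1, K) input row
    b = [[random.randint(-3, 3) for _ in range(N)] for _ in range(K)]  # the (K, N) weight

    # Step 1 ("matmul_rf_update"): each (column, lane) pair sums a strided slice of K.
    rf = [[0] * N for _ in range(K_LANES)]
    for bx in range(N_BLOCKS):
        for tx in range(N_THREADS):
            n = bx * N_THREADS + tx
            for ty in range(K_LANES):
                for k0 in range(K_SERIAL):
                    k = k0 * K_LANES + ty
                    rf[ty][n] += a[k] * b[k][n]

    # Step 2 ("matmul"): add up the 16 partial sums of each column.
    out = [sum(rf[ty][n] for ty in range(K_LANES)) for n in range(N)]

    # Reference: one plain dot product per output column.
    ref = [sum(a[k] * b[k][n] for k in range(K)) for n in range(N)]
    print(out == ref)  # True

Because the inputs are integers, the comparison is exact; with floating-point data the two
summation orders can differ in the last few bits.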

.. GENERATED FROM PYTHON SOURCE LINES 114-162

Rule Catalog
------------
DLight ships a set of GPU scheduling rules. Each rule is a subclass of
``ScheduleRule`` and implements an ``apply(func, target, tunable)`` method that returns
a ``Schedule`` if the rule matches, or ``None`` to decline so that the next rule can try.

The built-in GPU rules, roughly from most specific to most general:

.. list-table::
   :header-rows: 1
   :widths: 20 40 40

   * - Rule
     - Pattern
     - Typical operators
   * - ``Matmul``
     - GEMM index pattern ``C[S,I,J] += A[S,I,K] * B[S,J,K]``
     - ``nn.Linear``, batched matmul
   * - ``GEMV``
     - Matrix-vector multiply (one dimension is 1)
     - single-batch decode in attention
   * - ``LowBatchGEMV``
     - Low-batch GEMM scheduled with a GEMV strategy
     - small-batch decode
   * - ``Reduction``
     - Simple accumulation ``X[...] += Y[...]``
     - sum, max, argmax
   * - ``GeneralReduction``
     - Spatial dims followed by reduction dims (``S* R*``)
     - softmax, layer norm, RMS norm
   * - ``Transpose``
     - Read/write indices are permutations of each other
     - 2-D transpose
   * - ``RMSNorm``
     - Contains an ``rsqrt`` operation
     - RMS normalization
   * - ``Fallback``
     - Any function (always matches)
     - generic catch-all
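
The ``S* R*`` shape in the ``GeneralReduction`` row refers to the order of a block's iterator
kinds: any number of spatial (``S``) iterators followed by any number of reduction (``R``)
iterators. A hypothetical helper that checks a kind string against that shape is shown
below; it illustrates the notation only, and the real rule checks more than iterator order.

.. code-block:: Python

    import re

    S_STAR_R_STAR = re.compile(r"S*R*")  # "S" = spatial iterator, "R" = reduction iterator


    def spatial_then_reduction(kinds: str) -> bool:
        """True if no spatial iterator appears after a reduction iterator."""
        return S_STAR_R_STAR.fullmatch(kinds) is not None


    # The two summation blocks of ``layer_norm`` above iterate (row, feature) as "SR".
    for kinds in ["SR", "SSRR", "RS", "SRS"]:
        print(kinds, spatial_then_reduction(kinds))
    # SR True / SSRR True / RS False / SRS False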

**Rule order matters.** ``ApplyDefaultSchedule`` stops at the first match, so:

- Put **specialized** rules first (``Matmul``, ``GEMV``): they have strict matching
  conditions but produce high-quality schedules.
- Put **general** rules later (``GeneralReduction``, ``Fallback``): they match broadly
  but produce less optimized schedules.
- If you put ``Fallback`` first, it would "steal" every function and no specialized
  rule would ever run.
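
The effect of rule order can be shown with a toy model in plain Python. The matchers below are
hypothetical stand-ins keyed on a pattern label, not DLight's real matching logic; the point is
only that a rule that always matches hides every rule listed after it.

.. code-block:: Python

    from typing import Callable, Dict, List, Optional, Tuple

    # (rule name, matcher) pairs; a matcher returns True if the rule would schedule the function.
    NamedRule = Tuple[str, Callable[[str], bool]]

    MATMUL: NamedRule = ("Matmul", lambda pattern: pattern == "gemm")
    GENERAL_REDUCTION: NamedRule = ("GeneralReduction", lambda pattern: pattern == "reduction")
    FALLBACK: NamedRule = ("Fallback", lambda pattern: True)  # always matches


    def assign(patterns: Dict[str, str], rules: List[NamedRule]) -> Dict[str, Optional[str]]:
        """Map each function to the first rule, in list order, that matches it."""
        return {
            func: next((name for name, matches in rules if matches(pattern)), None)
            for func, pattern in patterns.items()
        }


    patterns = {"fused_matmul": "gemm", "layer_norm": "reduction", "transpose": "injective"}
    print(assign(patterns, [MATMUL, GENERAL_REDUCTION, FALLBACK]))
    # {'fused_matmul': 'Matmul', 'layer_norm': 'GeneralReduction', 'transpose': 'Fallback'}
    print(assign(patterns, [FALLBACK, MATMUL, GENERAL_REDUCTION]))
    # {'fused_matmul': 'Fallback', 'layer_norm': 'Fallback', 'transpose': 'Fallback'}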

.. GENERATED FROM PYTHON SOURCE LINES 164-170

Diagnosing Schedule Quality
---------------------------
A common question is: *which rule scheduled which function?* ``ApplyDefaultSchedule``
does not log this directly, but you can figure it out by applying rules one at a time.

**Step 1**: Apply each rule individually and record which functions it claims.

.. GENERATED FROM PYTHON SOURCE LINES 170-194

.. code-block:: Python


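    from collections import OrderedDict

    # Start from freshly legalized TIR. ``ApplyDefaultSchedule`` skips functions already marked
    # ``tirx.is_scheduled``, and the call above may have updated ``mod`` in place; reusing it
    # would make the first rule tried appear to claim every function.
    fresh_mod, _ = DemoModel().export_tvm({"forward": {"x": nn.spec.Tensor((1, 768), "float32")}})
    with target:
        fresh_mod = relax.get_pipeline("zero")(fresh_mod)

    rules = OrderedDict(
        [
            ("Matmul", dl.gpu.Matmul()),
            ("GEMV", dl.gpu.GEMV()),
            ("LowBatchGEMV", dl.gpu.LowBatchGEMV()),
            ("Reduction", dl.gpu.Reduction()),
            ("GeneralReduction", dl.gpu.GeneralReduction()),
            ("Transpose", dl.gpu.Transpose()),
            ("RMSNorm", dl.gpu.RMSNorm()),
        ]
    )

    # Credit each function to the first rule, in priority order, that schedules it.
    rule_assignment = {}
    for rule_name, rule in rules.items():
        with target:
            test_mod = dl.ApplyDefaultSchedule(rule)(fresh_mod)
        for gv, func in test_mod.functions_items():
            if not isinstance(func, tirx.PrimFunc) or gv.name_hint in rule_assignment:
                continue  # not TIR, or already claimed by a higher-priority rule
            if func.attrs is not None and "tirx.is_scheduled" in func.attrs:
                rule_assignment[gv.name_hint] = rule_name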








.. GENERATED FROM PYTHON SOURCE LINES 195-196

**Step 2**: Functions not claimed by any specialized rule will fall through to ``Fallback``.

.. GENERATED FROM PYTHON SOURCE LINES 196-210

.. code-block:: Python


    all_tir_funcs = [
        gv.name_hint for gv, func in mod.functions_items() if isinstance(func, tirx.PrimFunc)
    ]
    fallback_funcs = [name for name in all_tir_funcs if name not in rule_assignment]

    print("Rule assignments:")
    for name, rule_name in sorted(rule_assignment.items()):
        print(f"  {name:40s} -> {rule_name}")
    if fallback_funcs:
        print("Handled by Fallback (may have suboptimal performance):")
        for name in sorted(fallback_funcs):
            print(f"  {name}")









.. GENERATED FROM PYTHON SOURCE LINES 211-216

If an important kernel lands in the Fallback bucket, you have three options:

1. Write a **custom DLight rule** for it (see below).
2. Use **MetaSchedule** to auto-tune that specific function.
3. Manually schedule it with the ``tvm.s_tir.Schedule`` API.

.. GENERATED FROM PYTHON SOURCE LINES 218-253

DLight vs MetaSchedule
----------------------
The two systems are complementary, not competing:

.. list-table::
   :header-rows: 1
   :widths: 20 40 40

   * -
     - DLight
     - MetaSchedule
   * - Mechanism
     - Deterministic rule matching
     - Search-space exploration
   * - Compile time
     - Seconds
     - Minutes to hours
   * - Performance
     - Excellent on known patterns, fair otherwise
     - Near-optimal with sufficient search budget
   * - Best for
     - Default path, rapid iteration, CI
     - Hot-spot tuning in production

A practical workflow:

1. Run ``ApplyDefaultSchedule`` with the full rule set to cover all functions.
2. Profile the compiled model to identify hot-spot kernels.
3. Use ``MetaScheduleTuneTIR`` to auto-tune only those kernels.

Note that ``MetaScheduleTuneTIR`` does **not** automatically skip functions already
scheduled by DLight: it processes every ``PrimFunc`` in the module, so running it after
DLight spends search budget on functions you did not intend to tune. To restrict tuning to
the hot spots, filter the module first or use ``MetaScheduleTuneIRMod`` with ``op_names``
to target specific functions.

.. GENERATED FROM PYTHON SOURCE LINES 255-259

Writing a Custom Rule
---------------------
You can extend DLight by writing your own ``ScheduleRule``. The simplest way is
``ScheduleRule.from_callable``, which wraps a plain function into a rule **instance**.
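
Because the decorator returns an *instance*, the decorated name is the rule itself, not a
factory for it. The plain-Python sketch below illustrates the general pattern with simplified
stand-ins (``Rule`` and its ``from_callable`` are written for illustration, not copied from
TVM): the decorator builds a subclass whose ``apply`` forwards to the wrapped function and
returns a single instance of that subclass.

.. code-block:: Python

    from typing import Any, Callable, Optional


    class Rule:
        """Simplified stand-in for a schedule-rule base class."""

        def apply(self, func: Any, target: Any, tunable: bool) -> Optional[Any]:
            raise NotImplementedError

        @staticmethod
        def from_callable(name: str) -> Callable[[Callable[..., Optional[Any]]], "Rule"]:
            def decorator(fn: Callable[..., Optional[Any]]) -> "Rule":
                class _FunctionRule(Rule):
                    def apply(self, func, target, tunable):
                        return fn(func, target, tunable)  # forward to the wrapped function

                _FunctionRule.__name__ = name
                return _FunctionRule()  # an instance, not the class

            return decorator


    @Rule.from_callable("Doubler")
    def doubler(func, target, tunable):
        return None if func is None else func * 2


    print(type(doubler).__name__)  # Doubler
    print(doubler.apply(21, None, False))  # 42
    print(callable(doubler))  # False: calling doubler() would raise TypeError

DLight's ``ScheduleRule.from_callable`` is applied the same way, as a decorator on a function
with the ``(func, target, tunable)`` signature: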

.. GENERATED FROM PYTHON SOURCE LINES 259-289

.. code-block:: Python


    from tvm import s_tir
    from tvm.s_tir.dlight.analysis import normalize_prim_func
    from tvm.s_tir.dlight.base.schedule_rule import ScheduleRule


    @ScheduleRule.from_callable("MyTileAndBind")
    def my_tile_and_bind(func: tirx.PrimFunc, target: tvm.target.Target, tunable: bool):
        """A minimal rule: for single-block injective functions, tile and bind to GPU threads."""
        if not isinstance(func, tirx.PrimFunc):
            return None
        sch = s_tir.Schedule(func)
        # Use normalize_prim_func to get block info with correct spatial/reduction classification.
        # This is the same analysis used by built-in DLight rules.
        block_infos = normalize_prim_func(sch)
        if block_infos is None or len(block_infos) != 1:
            return None  # only handle single-block functions
        info = block_infos[0]
        if not info.is_injective():
            return None  # skip blocks with reduction iterators (not purely spatial)
        loops = sch.get_loops(info.block_rv)
        if len(loops) == 0:
            return None
        fused = sch.fuse(*loops)
        # 256 threads per block; split adds a bounds predicate if the extent is not a multiple of 256
        bx, tx = sch.split(fused, factors=[None, 256])
        sch.bind(bx, "blockIdx.x")
        sch.bind(tx, "threadIdx.x")
        return sch
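
The rule fuses the block's loops into one, splits that loop with ``factors=[None, 256]``,
and binds the two parts to ``blockIdx.x`` and ``threadIdx.x``. The transposes in the first
scheduled module follow the same idea with 1024 threads per block
(``576 * 1024 == 768 * 768``). The plain-Python sketch below replays the resulting index
arithmetic for a given loop shape; ``launch_indices`` is a hypothetical helper for
illustration, and its ``fused >= extent`` check plays the role of the predicate that
``split`` inserts when the extent is not a multiple of the factor.

.. code-block:: Python

    from math import prod
    from typing import List, Tuple


    def launch_indices(shape: Tuple[int, ...], threads: int = 256) -> List[Tuple[int, ...]]:
        """Loop indices visited by a fuse -> split([None, threads]) -> bind mapping."""
        extent = prod(shape)  # extent of the fused loop
        num_blocks = (extent + threads - 1) // threads  # the ``None`` factor split infers
        visited = []
        for bx in range(num_blocks):  # blockIdx.x
            for tx in range(threads):  # threadIdx.x
                fused = bx * threads + tx
                if fused >= extent:  # out of range: this thread does nothing
                    continue
                # Unravel the fused index back into row-major loop indices.
                idx = []
                for dim in reversed(shape):
                    idx.append(fused % dim)
                    fused //= dim
                visited.append(tuple(reversed(idx)))
        return visited


    even = launch_indices((48, 32))  # 1536 elements -> 6 blocks of 256 threads
    print(len(even), len(set(even)))  # 1536 1536: each index is visited exactly once
    ragged = launch_indices((10, 30))  # 300 elements -> 2 blocks, 212 idle threads
    print(len(ragged), 2 * 256 - len(ragged))  # 300 212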









.. GENERATED FROM PYTHON SOURCE LINES 290-292

Insert the custom rule into the rule chain. Note that ``from_callable`` returns an
**instance**, so pass it directly — do not call ``my_tile_and_bind()`` again.

.. GENERATED FROM PYTHON SOURCE LINES 292-303

.. code-block:: Python


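    # ``mod`` may already have been scheduled in place by the earlier ``ApplyDefaultSchedule``
    # call, which would turn this pass into a no-op. Start again from freshly legalized TIR
    # so that the rule chain below, including the custom rule, actually runs.
    fresh_mod, _ = DemoModel().export_tvm({"forward": {"x": nn.spec.Tensor((1, 768), "float32")}})
    with target:
        fresh_mod = relax.get_pipeline("zero")(fresh_mod)
        custom_mod = dl.ApplyDefaultSchedule(
            dl.gpu.Matmul(),
            dl.gpu.GeneralReduction(),
            my_tile_and_bind,  # our custom rule, tried before Fallback
            dl.gpu.Fallback(),
        )(fresh_mod)

    custom_mod.show()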





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    # from tvm.script import ir as I
    # from tvm.script import tirx as T
    # from tvm.tirx.layout import Axis
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func(private=True, s_tir=True)
        def fused_matmul1_add1(layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims1: T.Buffer((T.int64(768), T.int64(256)), "float32"), fc2_bias: T.Buffer((T.int64(256),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
            T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(256)), scope="local")
            matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(256)), scope="local")
            for ax0_fused_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
                for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                        with T.sblock("matmul_rf_init"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            T.reads()
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                        for ax1_fused_0, u in T.grid(T.int64(48), 1):
                            with T.sblock("matmul_rf_update"):
                                vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                                v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                                vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
                                T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                                T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                                matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + layer_norm[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
                for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                        with T.sblock("matmul"):
                            vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax1_fused)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            T.writes(matmul_intermediate_local[T.int64(0), v0])
                            with T.init():
                                matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
                            matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
                for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax0_fused_1 in range(T.int64(1)):
                        with T.sblock("T_add"):
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                            T.reads(matmul_intermediate_local[T.int64(0), v0], fc2_bias[v0])
                            T.writes(T_add_intermediate[T.int64(0), v0])
                            T_add_intermediate[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + fc2_bias[v0]

        @T.prim_func(private=True, s_tir=True)
        def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(768)), "float32"), permute_dims: T.Buffer((T.int64(768), T.int64(768)), "float32"), fc1_bias: T.Buffer((T.int64(768),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(768)), "float32")):
            T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(768)), scope="local")
            matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(16), T.int64(1), T.int64(768)), scope="local")
            for ax0_fused_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
                for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                        with T.sblock("matmul_rf_init"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            T.reads()
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                        for ax1_fused_0, u in T.grid(T.int64(48), 1):
                            with T.sblock("matmul_rf_update"):
                                vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                                v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                                vax1_fused_0 = T.axis.reduce(T.int64(48), ax1_fused_0)
                                T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                                T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                                matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + x[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
                for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                        with T.sblock("matmul"):
                            vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                            v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax1_fused)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            T.writes(matmul_intermediate_local[T.int64(0), v0])
                            with T.init():
                                matmul_intermediate_local[T.int64(0), v0] = T.float32(0.0)
                            matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
                for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax0_fused_1 in range(T.int64(1)):
                        with T.sblock("compute"):
                            v0 = T.axis.spatial(T.int64(768), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                            T.reads(matmul_intermediate_local[T.int64(0), v0], fc1_bias[v0])
                            T.writes(compute_intermediate[T.int64(0), v0])
                            compute_intermediate[T.int64(0), v0] = T.max(matmul_intermediate_local[T.int64(0), v0] + fc1_bias[v0], T.float32(0.0))

        @T.prim_func(private=True, s_tir=True)
        def layer_norm(relu: T.Buffer((T.int64(1), T.int64(768)), "float32"), norm_weight: T.Buffer((T.int64(768),), "float32"), norm_bias: T.Buffer((T.int64(768),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(768)), "float32")):
            T.func_attr({"op_pattern": 4, "tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            relu_sum_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
            relu_var_sum_shared = T.sblock_alloc_buffer((T.int64(1),), scope="shared")
            for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                for ax0 in range(T.int64(1)):
                    for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                        for ax1_fused_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                            with T.sblock("relu_sum"):
                                v0 = T.axis.spatial(T.int64(1), ax0)
                                v1 = T.axis.reduce(T.int64(768), ax1_fused_0 * T.int64(256) + ax1_fused_1)
                                T.reads(relu[T.int64(0), v1])
                                T.writes(relu_sum_shared[T.int64(0)])
                                with T.init():
                                    relu_sum_shared[T.int64(0)] = T.float32(0.0)
                                relu_sum_shared[T.int64(0)] = relu_sum_shared[T.int64(0)] + relu[T.int64(0), v1]
                for ax0 in range(T.int64(1)):
                    for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                        for ax1_fused_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                            with T.sblock("relu_var_sum"):
                                v0 = T.axis.spatial(T.int64(1), ax0)
                                v1 = T.axis.reduce(T.int64(768), ax1_fused_0 * T.int64(256) + ax1_fused_1)
                                T.reads(relu[T.int64(0), v1], relu_sum_shared[T.int64(0)])
                                T.writes(relu_var_sum_shared[T.int64(0)])
                                with T.init():
                                    relu_var_sum_shared[T.int64(0)] = T.float32(0.0)
                                relu_var_sum_shared[T.int64(0)] = relu_var_sum_shared[T.int64(0)] + (relu[T.int64(0), v1] - relu_sum_shared[T.int64(0)] / T.float32(768.0)) * (relu[T.int64(0), v1] - relu_sum_shared[T.int64(0)] / T.float32(768.0))
                for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                    for ax1_0 in T.serial(T.int64(3), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        with T.sblock("T_layer_norm"):
                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                            v1 = T.axis.spatial(T.int64(768), ax1_0 * T.int64(256) + ax1_1)
                            T.reads(relu[T.int64(0), v1], relu_sum_shared[T.int64(0)], relu_var_sum_shared[T.int64(0)], norm_weight[v1], norm_bias[v1])
                            T.writes(T_layer_norm[T.int64(0), v1])
                            T_layer_norm[T.int64(0), v1] = (relu[T.int64(0), v1] - relu_sum_shared[T.int64(0)] / T.float32(768.0)) * T.rsqrt(relu_var_sum_shared[T.int64(0)] / T.float32(768.0) + T.float32(1.0000000000000001e-05)) * norm_weight[v1] + norm_bias[v1]

        @T.prim_func(private=True, s_tir=True)
        def transpose(fc1_weight: T.Buffer((T.int64(768), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(768)), "float32")):
            T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            for ax0_ax1_fused_0 in T.thread_binding(T.int64(576), thread="blockIdx.x"):
                for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                    with T.sblock("T_transpose"):
                        v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(768))
                        v1 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(768))
                        T.reads(fc1_weight[v1, v0])
                        T.writes(T_transpose[v0, v1])
                        T_transpose[v0, v1] = fc1_weight[v1, v0]

        @T.prim_func(private=True, s_tir=True)
        def transpose1(fc2_weight: T.Buffer((T.int64(256), T.int64(768)), "float32"), T_transpose: T.Buffer((T.int64(768), T.int64(256)), "float32")):
            T.func_attr({"op_pattern": 2, "tirx.is_scheduled": True, "tirx.noalias": True})
            # with T.sblock("root"):
            for ax0_ax1_fused_0 in T.thread_binding(T.int64(192), thread="blockIdx.x"):
                for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                    with T.sblock("T_transpose"):
                        v0 = T.axis.spatial(T.int64(768), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(256))
                        v1 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(256))
                        T.reads(fc2_weight[v1, v0])
                        T.writes(T_transpose[v0, v1])
                        T_transpose[v0, v1] = fc2_weight[v1, v0]

        @R.function
        def forward(x: R.Tensor((1, 768), dtype="float32"), fc1_weight: R.Tensor((768, 768), dtype="float32"), fc1_bias: R.Tensor((768,), dtype="float32"), norm_weight: R.Tensor((768,), dtype="float32"), norm_bias: R.Tensor((768,), dtype="float32"), fc2_weight: R.Tensor((256, 768), dtype="float32"), fc2_bias: R.Tensor((256,), dtype="float32")) -> R.Tensor((1, 256), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_ty=R.Tensor((768, 768), dtype="float32"))
                lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_ty=R.Tensor((1, 768), dtype="float32"))
                layer_norm = R.call_tir(cls.layer_norm, (lv, norm_weight, norm_bias), out_ty=R.Tensor((1, 768), dtype="float32"))
                permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_ty=R.Tensor((768, 256), dtype="float32"))
                gv = R.call_tir(cls.fused_matmul1_add1, (layer_norm, permute_dims1, fc2_bias), out_ty=R.Tensor((1, 256), dtype="float32"))
                R.output(gv)
            return gv





.. GENERATED FROM PYTHON SOURCE LINES 304-306

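To build a production-quality rule, subclass ``ScheduleRule`` directly and implement
``apply()`` with the analysis logic the rule needs. That logic checks whether the function
matches the pattern the rule targets, declines by returning ``None`` if it does not, and
schedules the function otherwise. See ``tvm.s_tir.dlight.gpu.Matmul`` for a complete example.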
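
The sketch below models the shape of such a rule without depending on TVM. It is a
hypothetical, pure-Python stand-in, not the DLight API. ``ToyFunc``, ``ToyRule``, and
``ToyGemmRule`` are invented for illustration. A function is reduced to a list of iterator
kinds (``"S"`` for spatial, ``"R"`` for reduction), and the returned "schedule" is only a
string. What carries over is the structure of ``apply()``: analyze first, decline by returning
``None`` when the pattern is not recognized, and only then transform. The ``None`` convention
is assumed here to mirror how a DLight rule signals that it does not apply.

.. code-block:: Python

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class ToyFunc:
        """Hypothetical stand-in for a TIR function: a name plus its iterator kinds."""

        name: str
        iter_kinds: List[str]  # "S" = spatial axis, "R" = reduction axis


    class ToyRule:
        """Hypothetical base class shaped like a rule with an ``apply()`` method."""

        def apply(self, func: ToyFunc) -> Optional[str]:
            raise NotImplementedError


    class ToyGemmRule(ToyRule):
        """Claims functions with at least two spatial axes and exactly one reduction."""

        def apply(self, func: ToyFunc) -> Optional[str]:
            # Analysis: classify the iterators before touching the function.
            num_spatial = func.iter_kinds.count("S")
            num_reduce = func.iter_kinds.count("R")
            if num_spatial < 2 or num_reduce != 1:
                return None  # decline, so that a later rule gets a chance
            # A real rule would tile loops and bind threads here; the toy
            # version only returns a label describing the chosen schedule.
            return f"gemm-schedule({func.name})"


    rule = ToyGemmRule()
    print(rule.apply(ToyFunc("matmul", ["S", "S", "R"])))  # gemm-schedule(matmul)
    print(rule.apply(ToyFunc("relu", ["S", "S"])))  # None

Finishing every check before any transformation keeps a declined rule free of side effects,
which matters when the next rule has to start from the untouched function.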

.. GENERATED FROM PYTHON SOURCE LINES 308-317

Summary
-------
- **DLight** provides fast, deterministic GPU scheduling via rule matching.
- Rules are tried in order; the first match wins. Put specialized rules before general ones.
- Use the **single-rule probing** technique to diagnose which rule handles each function.
- Combine DLight with MetaSchedule: DLight for baseline coverage, MetaSchedule for hot-spot tuning.
- Extend DLight by writing custom ``ScheduleRule`` implementations.
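
The ordering and probing advice above can be illustrated with a small pure-Python model.
This is not DLight code. The rule names (``ToyMatmul``, ``ToyReduction``, ``ToyFallback``),
their predicates, and the iterator-kind encoding are hypothetical. They exist only to show
why a general rule placed first shadows a specialized one, and how running one rule at a
time reveals which functions each rule would claim.

.. code-block:: Python

    from typing import Callable, Dict, List, Optional, Tuple

    # Hypothetical toy model, not DLight code: a "function" is a list of
    # iterator kinds ("S" = spatial, "R" = reduction) and a "rule" is a
    # (name, predicate) pair that decides whether it can handle the function.
    Rule = Tuple[str, Callable[[List[str]], bool]]

    RULES: List[Rule] = [
        ("ToyMatmul", lambda kinds: kinds.count("S") >= 2 and kinds.count("R") == 1),
        ("ToyReduction", lambda kinds: "R" in kinds),
        ("ToyFallback", lambda kinds: True),  # most general: matches anything
    ]

    FUNCS: Dict[str, List[str]] = {
        "matmul": ["S", "S", "R"],
        "layer_norm": ["S", "R"],
        "transpose": ["S", "S"],
    }


    def dispatch(kinds: List[str], rules: List[Rule]) -> Optional[str]:
        """Return the first rule that matches; later rules are never consulted."""
        for name, matches in rules:
            if matches(kinds):
                return name
        return None


    def probe(funcs: Dict[str, List[str]], rules: List[Rule]) -> Dict[str, List[str]]:
        """Single-rule probing: run each rule on its own and list the functions it claims."""
        return {
            name: [f for f, kinds in funcs.items() if dispatch(kinds, [(name, pred)]) is not None]
            for name, pred in rules
        }


    for f, kinds in FUNCS.items():
        print(f, "->", dispatch(kinds, RULES))
    # matmul -> ToyMatmul
    # layer_norm -> ToyReduction
    # transpose -> ToyFallback

    # Putting the general rule first shadows the specialized one.
    print(dispatch(FUNCS["matmul"], RULES[::-1]))  # ToyFallback

    print(probe(FUNCS, RULES)["ToyReduction"])  # ['matmul', 'layer_norm']

Probing shows that ``ToyReduction`` alone would also accept ``matmul``. That overlap is
harmless only because ``ToyMatmul`` is tried first in the combined list.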

For DLight's role in the broader optimization pipeline, see :ref:`customize_opt`.


.. _sphx_glr_download_deep_dive_tensor_ir_tutorials_dlight_gpu_scheduling.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: dlight_gpu_scheduling.ipynb <dlight_gpu_scheduling.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: dlight_gpu_scheduling.py <dlight_gpu_scheduling.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: dlight_gpu_scheduling.zip <dlight_gpu_scheduling.zip>`
