
.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/tutorials/optimize_llm.py"

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        You can click :ref:`here <sphx_glr_download_how_to_tutorials_optimize_llm.py>` to run the Jupyter notebook locally.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_tutorials_optimize_llm.py:


.. _opt_llm:

Optimize Large Language Model
=============================
As large language models (LLMs) have become a popular research topic in many different fields,
deploying them on cloud and edge devices has become a challenging task. In this tutorial, we will
demonstrate how to optimize a large language model using Apache TVM. We will use a pre-trained
TinyLlama model from Hugging Face and deploy it on various devices.

.. GENERATED FROM PYTHON SOURCE LINES 31-48

Review Overall Flow
-------------------
.. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg
   :align: center
   :width: 80%

The overall flow consists of the following steps:

- **Construct or Import a Model**: Construct a neural network model or import a pre-trained
  model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains
  all the information needed for compilation, including high-level Relax functions for the
  computational graph and low-level TensorIR functions for tensor programs.
- **Perform Composable Optimizations**: Perform a series of optimization transformations,
  such as graph optimizations, tensor program optimizations, and library dispatching.
- **Build and Universal Deployment**: Build the optimized model into a deployable module for the
  universal runtime, and execute it on different devices, such as CPUs, GPUs, or other accelerators.


.. GENERATED FROM PYTHON SOURCE LINES 51-57

Construct the model architecture
--------------------------------
We will use a pre-trained TinyLlama model from Hugging Face. However, we usually load only the
pre-trained weights from Hugging Face, not the model architecture, so we need to construct the
model architecture ourselves. Apache TVM provides a PyTorch-like API, ``tvm.relax.frontend.nn``,
for constructing model architectures, and we use it below.

.. GENERATED FROM PYTHON SOURCE LINES 57-74

.. code-block:: Python


    import dataclasses
    import enum
    import os
    from pathlib import Path
    from pprint import pprint

    from tvm_ffi import Shape

    import tvm
    from tvm import relax, te, tirx
    from tvm.relax import register_pipeline
    from tvm.relax.frontend import nn
    from tvm.relax.frontend.nn import Tensor, op
    from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache, TIRPagedKVCache
    from tvm.s_tir import dlight








.. GENERATED FROM PYTHON SOURCE LINES 75-78

First, we define the model configuration, which includes the key parameters of the model, such as
the hidden size and the intermediate size. For convenience, we define a constant config
specifically for the TinyLlama model.

.. GENERATED FROM PYTHON SOURCE LINES 78-99

.. code-block:: Python



    @dataclasses.dataclass
    class LlamaConfig:
        hidden_size: int = 2048
        intermediate_size: int = 5632
        num_attention_heads: int = 32
        num_hidden_layers: int = 22
        rms_norm_eps: float = 1e-05
        vocab_size: int = 32000
        rope_theta: int = 10000
        context_window_size: int = 2048
        prefill_chunk_size: int = 2048
        num_key_value_heads: int = 4
        head_dim: int = 64  # hidden_size // num_attention_heads


    dev = tvm.device("cuda", 0)
    target = tvm.target.Target.from_device(dev)









.. GENERATED FROM PYTHON SOURCE LINES 100-106

Next, we define the RoPE mode of the paged KV cache. The RoPE mode controls how Rotary Position
Embedding (RoPE) is applied to the query and key tensors, and can be set to ``NONE``, ``NORMAL``,
or ``INLINE``. With ``NONE``, the KV cache does not apply RoPE to the query and key tensors. With
``NORMAL``, RoPE is applied to the key tensor before the key is added to the cache. With
``INLINE``, RoPE is applied to the query and key tensors on the fly inside the attention kernel.

.. GENERATED FROM PYTHON SOURCE LINES 106-120

.. code-block:: Python



    class RopeMode(enum.IntEnum):
        """The RoPE mode of the Paged KV cache.
        If it is none, the KV cache will not apply RoPE to q and k.
        If it is normal, RoPE will be applied to k before adding k to cache.
        Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
        """

        NONE = 0
        NORMAL = 1
        INLINE = 2

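For intuition about what "applying RoPE" to a query or key means, here is a minimal NumPy sketch
of rotary position embedding. It assumes the "rotate-half" layout common in Llama-style models
(element ``i`` of each head vector is paired with element ``i + head_dim / 2``); it is a reference
for the math only, not TVM's kernel code, and ``apply_rope`` is a hypothetical name. ``theta``
plays the role of ``rope_theta`` in ``LlamaConfig``.

.. code-block:: python

    import numpy as np


    def apply_rope(x: np.ndarray, positions: np.ndarray, theta: float = 10000.0) -> np.ndarray:
        """Rotate x of shape (seq_len, num_heads, head_dim) by position-dependent angles."""
        head_dim = x.shape[-1]
        half = head_dim // 2
        # One frequency per rotated pair: theta ** (-2i / head_dim) for i in [0, head_dim / 2)
        inv_freq = theta ** (-2.0 * np.arange(half) / head_dim)
        angles = positions[:, None] * inv_freq[None, :]  # (seq_len, half)
        cos = np.cos(angles)[:, None, :]  # broadcast over heads
        sin = np.sin(angles)[:, None, :]
        x1, x2 = x[..., :half], x[..., half:]
        # Rotate each (x1[i], x2[i]) pair in 2D
        return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


    rng = np.random.default_rng(0)
    q = rng.standard_normal((1, 1, 64))
    k = rng.standard_normal((1, 1, 64))
    # The query-key score depends only on the distance between positions (5 - 3 == 2 - 0)
    score_5_3 = apply_rope(q, np.array([5.0]))[0, 0] @ apply_rope(k, np.array([3.0]))[0, 0]
    score_2_0 = apply_rope(q, np.array([2.0]))[0, 0] @ apply_rope(k, np.array([0.0]))[0, 0]
    print(np.isclose(score_5_3, score_2_0))  # True

Because both the query and the key are rotated, their dot product encodes only the relative
distance between positions, even though each rotation uses an absolute position.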








.. GENERATED FROM PYTHON SOURCE LINES 121-140

Now we define the model architecture, which consists of three parts:

- Embedding layer: The embedding layer converts the input token IDs to the hidden states.
- Decoder layers: The decoder layers are the core of the model. Each decoder layer consists of
  a self-attention layer and a feed-forward network (FFN) layer.
- Output layer: The output layer converts the hidden states to the logits.

First we define the FFN layer. Note that the following FFN layer is an optimized implementation
in which the gate and up projections are fused into one kernel.
The naive implementation of the FFN layer is ``FFN(x) = down_proj(silu(gate(x)) * up(x))``.
Since ``gate`` and ``up`` are both applied to the same input ``x``, we can combine them into one
projection for better performance. The optimized implementation is:

.. code-block:: python

  concat_x = gate_up(x)
  gate_x, up_x = split(concat_x, 2, axis=-1)
  ffn_x = down_proj(silu(gate_x) * up_x)


.. GENERATED FROM PYTHON SOURCE LINES 140-158

.. code-block:: Python



    class LlamaFFN(nn.Module):
        def __init__(self, config: LlamaConfig):
            super().__init__()
            self.gate_up_proj = nn.Linear(
                in_features=config.hidden_size,
                out_features=2 * config.intermediate_size,
                bias=False,
            )
            self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

        def forward(self, x: Tensor):
            concat_x1_x2 = self.gate_up_proj(x)
            x1, x2 = op.split(concat_x1_x2, 2, axis=-1)
            return self.down_proj(op.silu(x1) * x2)

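To check that the fused projection computes the same function as the naive one, here is a small
NumPy sketch with random weights and toy sizes (not TinyLlama's). It assumes ``nn.Linear``-style
weights of shape ``(out_features, in_features)`` and stacks the gate weight on top of the up
weight, so that the first half of the split output is the gate and the second half is up, as in
``LlamaFFN.forward``.

.. code-block:: python

    import numpy as np


    def silu(x):
        return x / (1.0 + np.exp(-x))


    rng = np.random.default_rng(0)
    hidden_size, intermediate_size, seq_len = 8, 16, 3  # toy sizes
    x = rng.standard_normal((seq_len, hidden_size))
    # nn.Linear-style weights: shape (out_features, in_features), y = x @ W.T
    w_gate = rng.standard_normal((intermediate_size, hidden_size))
    w_up = rng.standard_normal((intermediate_size, hidden_size))
    w_down = rng.standard_normal((hidden_size, intermediate_size))


    def ffn_naive(x):
        # Two separate projections: gate(x) and up(x)
        return (silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T


    # Stack gate on top of up along the output dimension: (2 * intermediate_size, hidden_size)
    w_gate_up = np.concatenate([w_gate, w_up], axis=0)


    def ffn_fused(x):
        # One projection, then split the output into its gate and up halves
        gate_x, up_x = np.split(x @ w_gate_up.T, 2, axis=-1)
        return (silu(gate_x) * up_x) @ w_down.T


    print(np.allclose(ffn_naive(x), ffn_fused(x)))  # True

The same stacking order is used later in this tutorial, when the weight-preparation step
concatenates ``gate_proj`` and ``up_proj`` along ``axis=0``.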








.. GENERATED FROM PYTHON SOURCE LINES 159-172

Then we define the self-attention layer. The self-attention layer consists of three parts:

- QKV projection: The QKV projection converts the input hidden states to the query, key, and
  value tensors.
- Attention: The attention layer computes the attention scores and applies the softmax
  operation.
- Output projection: The output projection converts the attention output to the hidden states.

We perform optimizations on the different parts of the self-attention layer:

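- QKV projection: We apply horizontal fusion to the Q, K, and V projections and fuse them into
  one kernel.
- Attention: We pass the fused QKV tensor directly to the paged KV cache's
  ``attention_with_fused_qkv``, so the model does not need to split it into separate Q, K, and V
  tensors before computing attention.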

.. GENERATED FROM PYTHON SOURCE LINES 172-204

.. code-block:: Python



    class LlamaAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
        def __init__(self, config: LlamaConfig):
            self.head_dim = config.head_dim
            self.num_q_heads = config.num_attention_heads
            self.num_kv_heads = config.num_key_value_heads
            # horizontal fusion on QKV projection
            self.qkv_proj = nn.Linear(
                in_features=config.hidden_size,
                out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,
                bias=False,
            )
            self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)

        def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
            d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads
            b, s, _ = hidden_states.shape
            # QKV Projection
            qkv = self.qkv_proj(hidden_states)
            qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))
            # Attention
            output = op.reshape(
                paged_kv_cache.attention_with_fused_qkv(
                    layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5
                ),
                (b, s, h_q * d),
            )
            # Output Projection
            return self.o_proj(output)

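The fused QKV tensor has ``num_q_heads + 2 * num_kv_heads`` heads along its third axis. For
TinyLlama that is (32 + 2 × 4) × 64 = 2560 output features, which matches the ``(2560, 2048)``
``qkv_proj`` weight in the exported parameters. The NumPy sketch below is a reference for causal
grouped-query attention over such a fused tensor for a single sequence. It only illustrates the
semantics: it assumes a ``[Q heads | K heads | V heads]`` layout (the order in which the
``q_proj``, ``k_proj``, and ``v_proj`` weights are concatenated), ignores RoPE and paging, and is
not how ``attention_with_fused_qkv`` is implemented. ``gqa_attention_fused_qkv`` is a
hypothetical name.

.. code-block:: python

    import numpy as np


    def gqa_attention_fused_qkv(qkv, num_q_heads, sm_scale):
        """Causal grouped-query attention for one sequence.

        qkv: (seq_len, num_q_heads + 2 * num_kv_heads, head_dim), laid out as [Q | K | V] heads.
        Returns an array of shape (seq_len, num_q_heads, head_dim).
        """
        seq_len, total_heads, _ = qkv.shape
        num_kv_heads = (total_heads - num_q_heads) // 2
        q = qkv[:, :num_q_heads]
        k = qkv[:, num_q_heads : num_q_heads + num_kv_heads]
        v = qkv[:, num_q_heads + num_kv_heads :]
        # Each KV head is shared by a group of consecutive query heads
        group = num_q_heads // num_kv_heads
        k = np.repeat(k, group, axis=1)
        v = np.repeat(v, group, axis=1)
        scores = np.einsum("qhd,khd->hqk", q, k) * sm_scale  # (heads, q_len, k_len)
        future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores = np.where(future[None], -np.inf, scores)  # causal mask
        probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
        probs /= probs.sum(axis=-1, keepdims=True)
        return np.einsum("hqk,khd->qhd", probs, v)


    # TinyLlama head layout: 32 query heads, 4 KV heads, head_dim 64
    print((32 + 2 * 4) * 64)  # 2560, the out_features of qkv_proj

    rng = np.random.default_rng(0)
    qkv = rng.standard_normal((5, 32 + 2 * 4, 64))
    out = gqa_attention_fused_qkv(qkv, num_q_heads=32, sm_scale=64**-0.5)
    print(out.shape)  # (5, 32, 64)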








.. GENERATED FROM PYTHON SOURCE LINES 205-206

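Finally, we assemble the full model from the FFN and self-attention layers. ``LlamaDecoderLayer``
wraps each sublayer with an RMSNorm and a residual connection, ``LlamaModel`` stacks the decoder
layers, and ``LlamaForCausalLM`` adds the output layer (``lm_head``), the ``prefill`` and
``decode`` entry points, a factory for the paged KV cache, and the export spec.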

.. GENERATED FROM PYTHON SOURCE LINES 206-355

.. code-block:: Python



    class LlamaDecoderLayer(nn.Module):
        def __init__(self, config: LlamaConfig):
            rms_norm_eps = config.rms_norm_eps
            self.self_attn = LlamaAttention(config)
            self.mlp = LlamaFFN(config)
            self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)
            self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)

        def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
            hidden_states += self.self_attn(
                self.input_layernorm(hidden_states), paged_kv_cache, layer_id
            )
            hidden_states += self.mlp(self.post_attention_layernorm(hidden_states))
            return hidden_states


    class LlamaModel(nn.Module):
        def __init__(self, config: LlamaConfig):
            assert config.hidden_size % config.num_attention_heads == 0
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
            self.layers = nn.ModuleList(
                [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
            )
            self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)

        def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
            hidden_states = input_embed
            for layer_id, layer in enumerate(self.layers):
                hidden_states = layer(hidden_states, paged_kv_cache, layer_id)
            hidden_states = self.norm(hidden_states)
            return hidden_states


    class LlamaForCausalLM(nn.Module):
        def __init__(self, config: LlamaConfig):
            self.model = LlamaModel(config)
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.num_hidden_layers = config.num_hidden_layers
            self.num_attention_heads = config.num_attention_heads
            self.num_key_value_heads = config.num_key_value_heads
            self.head_dim = config.head_dim
            self.hidden_size = config.hidden_size
            self.vocab_size = config.vocab_size
            self.rope_theta = config.rope_theta
            self.dtype = "float32"

        def to(self, dtype: str | None = None):
            super().to(dtype=dtype)
            if dtype is not None:
                self.dtype = dtype

        def embed(self, input_ids: Tensor):
            return self.model.embed_tokens(input_ids)

        def get_logits(self, hidden_states: Tensor):
            logits = self.lm_head(hidden_states)
            if logits.dtype != "float32":
                logits = logits.astype("float32")
            return logits

        def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
            def _index(x: te.Tensor):  # x[:, -1:, :], i.e. keep only the last token
                b, s, d = x.shape
                return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index")

            hidden_states = self.model(input_embed, paged_kv_cache)
            hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
            logits = self.get_logits(hidden_states)
            return logits, paged_kv_cache

        def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
            hidden_states = self.model(input_embed, paged_kv_cache)
            logits = self.get_logits(hidden_states)
            return logits, paged_kv_cache

        def create_tir_paged_kv_cache(
            self,
            max_batch_size: tirx.Var,
            max_total_seq_len: tirx.Var,
            prefill_chunk_size: tirx.Var,
            page_size: tirx.Var,
        ) -> PagedKVCache:
            return TIRPagedKVCache(
                attn_kind="mha",
                max_batch_size=max_batch_size,
                max_total_seq_len=max_total_seq_len,
                prefill_chunk_size=prefill_chunk_size,
                page_size=page_size,
                support_sliding_window=0,
                layer_partition=relax.ShapeExpr([0, self.num_hidden_layers]),
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                num_key_value_heads=self.num_key_value_heads,
                qk_head_dim=self.head_dim,
                v_head_dim=self.head_dim,
                mla_original_qk_head_dim=0,
                mla_original_v_head_dim=0,
                rope_mode=RopeMode.NORMAL,
                rope_scale=1,
                rope_theta=self.rope_theta,
                rope_scaling={},
                rope_ext_factors=tirx.IntImm("int64", 0),
                rotary_dim=self.head_dim,
                dtype=self.dtype,
                target=target,
                enable_disaggregation=False,
            )

        def get_default_spec(self):
            mod_spec = {
                "embed": {
                    "input_ids": nn.spec.Tensor(["seq_len"], "int32"),
                    "$": {
                        "param_mode": "packed",
                        "effect_mode": "none",
                    },
                },
                "prefill": {
                    "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                    "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
                    "$": {
                        "param_mode": "packed",
                        "effect_mode": "none",
                    },
                },
                "decode": {
                    "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
                    "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
                    "$": {
                        "param_mode": "packed",
                        "effect_mode": "none",
                    },
                },
                "create_tir_paged_kv_cache": {
                    "max_batch_size": int,
                    "max_total_seq_len": int,
                    "prefill_chunk_size": int,
                    "page_size": int,
                    "$": {
                        "param_mode": "none",
                        "effect_mode": "none",
                    },
                },
            }
            return nn.spec.ModuleSpec.from_raw(mod_spec, self)

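``LlamaDecoderLayer`` follows the pre-norm residual pattern: each sublayer sees an
RMS-normalized input, and its output is added back to the unnormalized residual stream. A minimal
NumPy sketch of that data flow, assuming the standard scale-only RMSNorm definition (no mean
subtraction, no bias) and treating the attention and FFN sublayers as arbitrary callables;
``rms_norm`` and ``decoder_layer`` are hypothetical helper names.

.. code-block:: python

    import numpy as np


    def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-5) -> np.ndarray:
        """Normalize over the last axis by the root mean square, then scale by weight."""
        rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
        return x / rms * weight


    def decoder_layer(h, attn, mlp, w_in, w_post, eps=1e-5):
        """Pre-norm residual block with the same structure as LlamaDecoderLayer.forward."""
        h = h + attn(rms_norm(h, w_in, eps))
        h = h + mlp(rms_norm(h, w_post, eps))
        return h


    rng = np.random.default_rng(0)
    hidden = rng.standard_normal((1, 3, 16)) * 5.0  # (batch, seq_len, hidden_size), toy sizes
    ones = np.ones(16)
    normed = rms_norm(hidden, ones)
    # The root mean square of every normalized token is close to 1
    print(np.sqrt(np.mean(normed**2, axis=-1)).round(3))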








.. GENERATED FROM PYTHON SOURCE LINES 356-360

Export the model to Relax IRModule
----------------------------------
After defining the model architecture, we can export the model to a Relax IRModule.
For demonstration, we only print part of the ``prefill`` function and the first five parameters.

.. GENERATED FROM PYTHON SOURCE LINES 360-372

.. code-block:: Python


    model_config = LlamaConfig()
    model = LlamaForCausalLM(model_config)
    model.to("float16")
    mod, named_params = model.export_tvm(spec=model.get_default_spec())
    prefill_str = mod["prefill"].script()
    print(*prefill_str.split("\n")[3:20], sep="\n")  # Only show part of the function for demonstration
    print("        ...")

    print("\nParameters:")
    pprint(named_params[:5])  # Only show the first 5 parameters for demonstration





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    @R.function
    def prefill(input_embed: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Any, packed_params: R.Tuple(R.Tensor((32000, 2048), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), 
dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((32000, 2048), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 32000), dtype="float32"), R.Any):
        seq_len = T.int64()
        R.func_attr({"num_input": 2})
        with R.dataflow():
            model_embed_tokens_weight1: R.Tensor((32000, 2048), dtype="float16") = packed_params[0]
            model_layers_0_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype="float16") = packed_params[2]
            model_layers_0_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype="float16") = packed_params[3]
            model_layers_0_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype="float16") = packed_params[4]
            model_layers_0_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[5]
            model_layers_0_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[6]
            model_layers_1_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype="float16") = packed_params[7]
            model_layers_1_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype="float16") = packed_params[8]
            model_layers_1_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype="float16") = packed_params[9]
            model_layers_1_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype="float16") = packed_params[10]
            ...

    Parameters:
    [('model.embed_tokens.weight', Tensor([32000, 2048], "float16")),
     ('model.layers.0.self_attn.qkv_proj.weight', Tensor([2560, 2048], "float16")),
     ('model.layers.0.self_attn.o_proj.weight', Tensor([2048, 2048], "float16")),
     ('model.layers.0.mlp.gate_up_proj.weight', Tensor([11264, 2048], "float16")),
     ('model.layers.0.mlp.down_proj.weight', Tensor([2048, 5632], "float16"))]


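As a sanity check on the exported shapes, the total parameter count can be derived from the
configuration. The plain-Python sketch below sums the shapes printed above: the fused
``qkv_proj`` and ``gate_up_proj`` weights, ``o_proj`` and ``down_proj``, two RMSNorm weights per
layer, the final norm, the embedding table, and the separate ``lm_head``. ``LlamaShapes`` and
``count_params`` are hypothetical names whose fields mirror ``LlamaConfig``.

.. code-block:: python

    from dataclasses import dataclass


    @dataclass
    class LlamaShapes:
        # Values copied from the LlamaConfig used in this tutorial
        hidden_size: int = 2048
        intermediate_size: int = 5632
        num_attention_heads: int = 32
        num_key_value_heads: int = 4
        head_dim: int = 64
        num_hidden_layers: int = 22
        vocab_size: int = 32000


    def count_params(c: LlamaShapes) -> int:
        qkv = (c.num_attention_heads + 2 * c.num_key_value_heads) * c.head_dim * c.hidden_size
        o_proj = c.num_attention_heads * c.head_dim * c.hidden_size
        gate_up = 2 * c.intermediate_size * c.hidden_size
        down = c.hidden_size * c.intermediate_size
        norms = 2 * c.hidden_size  # input_layernorm + post_attention_layernorm
        per_layer = qkv + o_proj + gate_up + down + norms
        embed = c.vocab_size * c.hidden_size
        lm_head = c.hidden_size * c.vocab_size  # a separate weight in this model
        final_norm = c.hidden_size
        return embed + c.num_hidden_layers * per_layer + final_norm + lm_head


    n = count_params(LlamaShapes())
    print(n)  # 1100048384
    print(f"{n * 2 / 1e9:.2f} GB in float16")  # 2.20 GB in float16

The result, about 1.1 billion parameters, is consistent with the "1.1B" in the TinyLlama model
name.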


.. GENERATED FROM PYTHON SOURCE LINES 373-377

Define Optimization Pipeline
----------------------------
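We define a series of optimization passes tailored to LLMs and register them as a pipeline named
``"opt_llm"`` with ``register_pipeline``, so that it can be retrieved with
``relax.get_pipeline("opt_llm")`` and passed to ``tvm.compile``.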

.. GENERATED FROM PYTHON SOURCE LINES 377-434

.. code-block:: Python



    @register_pipeline("opt_llm")
    def _pipeline(  # pylint: disable=too-many-arguments
        ext_mods: list[nn.ExternModule] | None = None,
    ):
        ext_mods = ext_mods or []

        @tvm.transform.module_pass(opt_level=0)
        def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
            seq = tvm.transform.Sequential(
                [
                    # Phase 1. Passes on high-level operator graph
                    # We can enable cublas for further optimization
                    relax.transform.FuseTransposeMatmul(),
                    # Phase 2. Lowering to TIR, inherited from TVM Relax's official "zero" pipeline
                    relax.transform.LegalizeOps(),
                    relax.transform.AnnotateTIROpPattern(),
                    relax.transform.FoldConstant(),
                    relax.transform.FuseOps(),
                    relax.transform.FuseTIR(),
                    # Phase 3. Passes on TIR
                    relax.transform.DeadCodeElimination(),
                    # Phase 4. Low-level Optimizations
                    dlight.ApplyDefaultSchedule(
                        dlight.gpu.Matmul(),
                        dlight.gpu.GEMV(),
                        dlight.gpu.Reduction(),
                        dlight.gpu.GeneralReduction(),
                        dlight.gpu.Fallback(),
                    ),
                    # Phase 5. Lowering to VM bytecode
                    relax.transform.RewriteDataflowReshape(),
                    relax.transform.ToNonDataflow(),
                    relax.transform.RemovePurityChecking(),
                    relax.transform.CallTIRRewrite(),
                    relax.transform.StaticPlanBlockMemory(),
                    relax.transform.RewriteCUDAGraph(),
                    relax.transform.LowerAllocTensor(),
                    relax.transform.KillAfterLastUse(),
                    relax.transform.LowerRuntimeBuiltin(),
                    relax.transform.VMShapeLower(),
                    relax.transform.AttachGlobalSymbol(),
                    relax.transform.AttachExternModules(ext_mods),
                ]
            )
            mod = seq(mod)
            return mod

        return _pipeline


    with target:
        ex = tvm.compile(mod, target, relax_pipeline=relax.get_pipeline("opt_llm"))
        vm = relax.VirtualMachine(ex, dev)









.. GENERATED FROM PYTHON SOURCE LINES 435-446

Prepare the model weights
-------------------------
We load the pre-trained weights, which are stored in the Hugging Face format. Because our model
fuses the Q/K/V projections and the gate/up projections, the corresponding Hugging Face weights
must be concatenated before they are converted into TVM tensors in the order of ``named_params``.

.. note::

  The following code is not executed in this tutorial's CI build, because the pre-trained
  weights are not available in the CI environment.


.. GENERATED FROM PYTHON SOURCE LINES 446-497

.. code-block:: Python



    IS_IN_CI = os.getenv("CI", "") == "true"

    HF_WEIGHT_PATH = None
    # HF_WEIGHT_PATH = Path("/path/to/TinyLlama-1.1B-Chat-v1.0/")

    if not IS_IN_CI:
        import numpy as np
        import safetensors.torch
        import torch

        if HF_WEIGHT_PATH is None or not HF_WEIGHT_PATH.exists():
            raise ValueError("Please set the HF_WEIGHT_PATH to the path of the pre-trained weights.")

        # Torch format weights
        param_dict = safetensors.torch.load_file(HF_WEIGHT_PATH / "model.safetensors", device="cpu")
        # NumPy format weights (NumPy has no bfloat16 dtype, so bfloat16 tensors are cast to float16)
        param_dict = {
            k: v.half().numpy() if v.dtype == torch.bfloat16 else v.numpy()
            for k, v in param_dict.items()
        }

        named_params = dict(named_params)
        for i in range(model_config.num_hidden_layers):
            # Add QKV in self attention
            attn = f"model.layers.{i}.self_attn"
            param_dict[f"{attn}.qkv_proj.weight"] = np.concatenate(
                [
                    param_dict.pop(f"{attn}.q_proj.weight"),  # Pop the old parameters to save memory
                    param_dict.pop(f"{attn}.k_proj.weight"),
                    param_dict.pop(f"{attn}.v_proj.weight"),
                ],
                axis=0,
            )
            # Add gates in MLP
            mlp = f"model.layers.{i}.mlp"
            param_dict[f"{mlp}.gate_up_proj.weight"] = np.concatenate(
                [
                    param_dict.pop(f"{mlp}.gate_proj.weight"),
                    param_dict.pop(f"{mlp}.up_proj.weight"),
                ],
                axis=0,
            )

        # Convert params into TVM tensors on the device, in the order given by named_params
        params = [
            tvm.runtime.tensor(param_dict[k].astype("float16"), device=dev) for k in named_params.keys()
        ]









.. GENERATED FROM PYTHON SOURCE LINES 498-504

Deploy the compiled model
-------------------------
After the model and weights are ready, we can deploy the compiled model on the target device.
Language model inference consists of two steps: prefill and decode. The prefill step processes
all input tokens at once and stores their keys and values in the KV cache. The decode step then
generates tokens one at a time, reusing the KV cache, until the end-of-sequence token is generated.

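To see why splitting inference into prefill and decode gives the same result as recomputing
everything, the following NumPy sketch uses a single attention head and a hypothetical contiguous
``ToyKVCache`` (not TVM's paged cache). It processes a prompt in one causal pass (prefill), then
feeds the remaining tokens one at a time against the cached keys and values (decode), and
compares both against full recomputation. RoPE, batching, and paging are omitted.

.. code-block:: python

    import numpy as np


    def causal_attention(q, k, v):
        """Single-head causal attention; q, k, v have shape (seq_len, head_dim)."""
        n, d = q.shape
        scores = q @ k.T / np.sqrt(d)
        future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True above the diagonal
        scores = np.where(future, -np.inf, scores)
        p = np.exp(scores - scores.max(axis=-1, keepdims=True))
        return (p / p.sum(axis=-1, keepdims=True)) @ v


    class ToyKVCache:
        """Hypothetical contiguous cache holding the keys/values of one sequence."""

        def __init__(self, head_dim):
            self.k = np.zeros((0, head_dim))
            self.v = np.zeros((0, head_dim))

        def append(self, k, v):
            self.k = np.concatenate([self.k, k], axis=0)
            self.v = np.concatenate([self.v, v], axis=0)


    def decode_step(cache, q_t, k_t, v_t):
        """Append one token's K/V, then attend its query to every cached token."""
        cache.append(k_t, v_t)
        scores = q_t @ cache.k.T / np.sqrt(q_t.shape[-1])  # (1, cached_len)
        p = np.exp(scores - scores.max())
        return (p / p.sum()) @ cache.v  # (1, head_dim)


    rng = np.random.default_rng(0)
    total_len, prompt_len, head_dim = 7, 4, 8
    q, k, v = (rng.standard_normal((total_len, head_dim)) for _ in range(3))
    reference = causal_attention(q, k, v)  # recompute everything at once

    # Prefill: one causal pass over the prompt, storing its K/V in the cache
    cache = ToyKVCache(head_dim)
    cache.append(k[:prompt_len], v[:prompt_len])
    prefill_out = causal_attention(q[:prompt_len], k[:prompt_len], v[:prompt_len])

    # Decode: one token at a time, reusing the cached K/V
    decode_out = [
        decode_step(cache, q[t : t + 1], k[t : t + 1], v[t : t + 1])[0]
        for t in range(prompt_len, total_len)
    ]
    print(np.allclose(prefill_out, reference[:prompt_len]))  # True
    print(np.allclose(decode_out, reference[prompt_len:]))  # True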
.. GENERATED FROM PYTHON SOURCE LINES 507-514

Tokenization
~~~~~~~~~~~~
The first step is to tokenize the input prompt. The tokenization must match the one used by the
original model, so we use the Hugging Face tokenizer loaded from the same directory as the
weights. Note that different models require different tokenization and prompt formats; please
refer to the model documentation for the correct ones.

.. GENERATED FROM PYTHON SOURCE LINES 514-529

.. code-block:: Python



    if not IS_IN_CI:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(HF_WEIGHT_PATH)
        messages = [
            {"role": "user", "content": "What's your name?"},
        ]
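        # add_generation_prompt=True appends the assistant prefix so that the model answers the prompt
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)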
        input_len = len(prompt)

        # Load prompt tokens into a TVM tensor on the target device
        tokens = tvm.runtime.tensor(np.array(prompt).astype("int32"), device=dev)








.. GENERATED FROM PYTHON SOURCE LINES 530-535

Create the KVCache
~~~~~~~~~~~~~~~~~~
Before starting inference, we need to create the KV cache, which stores the key and value tensors
of the attention layers. Apache TVM provides ``PagedKVCache``, which stores them in fixed-size
pages. We create it by calling the ``create_tir_paged_kv_cache`` function exported from our model
with the parameters below.

.. GENERATED FROM PYTHON SOURCE LINES 535-545

.. code-block:: Python


    if not IS_IN_CI:
        kv_cache = vm["create_tir_paged_kv_cache"](
            Shape([1]),  # max_batch_size=1
            Shape([2048]),  # max_total_seq_len=2048
            Shape([2048]),  # prefill_chunk_size=2048
            Shape([16]),  # page_size=16
        )

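For a rough sense of what these parameters mean, the sketch below estimates the memory needed to
hold keys and values for 2048 tokens and the number of 16-token pages involved. It assumes
float16 storage for every layer and KV head and ignores any workspace or metadata that
``TIRPagedKVCache`` may allocate, so treat the numbers as an approximation rather than the actual
allocation. ``kv_cache_bytes`` and ``pages_needed`` are hypothetical helper names.

.. code-block:: python

    import math


    def kv_cache_bytes(num_layers, num_kv_heads, head_dim, num_tokens, bytes_per_elem=2):
        # Keys and values (factor 2) for every layer, KV head, and cached token
        return 2 * num_layers * num_kv_heads * head_dim * num_tokens * bytes_per_elem


    def pages_needed(num_tokens, page_size):
        # A sequence occupies whole pages, so round up
        return math.ceil(num_tokens / page_size)


    # TinyLlama settings from this tutorial, float16 storage
    total = kv_cache_bytes(num_layers=22, num_kv_heads=4, head_dim=64, num_tokens=2048)
    print(total / 2**20, "MiB")  # 44.0 MiB
    print(pages_needed(2048, 16))  # 128
    print(pages_needed(30, 16))  # 2

Because TinyLlama uses 4 KV heads instead of 32, this estimate is 8x smaller than it would be with
full multi-head attention.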








.. GENERATED FROM PYTHON SOURCE LINES 546-550

Embedding
~~~~~~~~~
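The next step is to embed the tokens into hidden states. We use the ``embed`` function compiled in
the Relax IRModule, then reshape its output from ``[seq_len, hidden_size]`` to
``[1, seq_len, hidden_size]`` to match the input shape expected by ``prefill`` and ``decode``.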

.. GENERATED FROM PYTHON SOURCE LINES 550-561

.. code-block:: Python


    nd_view_func = tvm.get_global_func("vm.builtin.reshape")


    def embed(tokens, params):
        _embed = vm["embed"](tokens, params)
        # Reshape hidden from [seq_len, hidden_size] to [1, seq_len, hidden_size]
        _embed = nd_view_func(_embed, Shape([1, _embed.shape[0], _embed.shape[1]]))
        return _embed

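Conceptually, ``embed`` is a row lookup into the ``(vocab_size, hidden_size)`` embedding table,
and the reshape adds a leading batch dimension of size 1. A NumPy sketch of the same shape
transformation, using a toy table rather than the real weights:

.. code-block:: python

    import numpy as np

    vocab_size, hidden_size = 10, 4  # toy sizes, not TinyLlama's
    embed_table = np.arange(vocab_size * hidden_size, dtype=np.float32).reshape(vocab_size, hidden_size)
    token_ids = np.array([3, 1, 7], dtype=np.int32)  # [seq_len]

    hidden = embed_table[token_ids]  # row lookup -> [seq_len, hidden_size]
    hidden = hidden.reshape(1, hidden.shape[0], hidden.shape[1])  # [1, seq_len, hidden_size]
    print(hidden.shape)  # (1, 3, 4)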








.. GENERATED FROM PYTHON SOURCE LINES 562-565

Prefill
~~~~~~~
Before running the forward pass, we first fetch a few helper functions from the TVM runtime that
manage the KV cache state.

.. GENERATED FROM PYTHON SOURCE LINES 565-570

.. code-block:: Python


    add_sequence_func = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
    begin_forward_func = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
    end_forward_func = tvm.get_global_func("vm.builtin.kv_state_end_forward")








.. GENERATED FROM PYTHON SOURCE LINES 571-574

Since we are starting a new sequence, we first call ``add_sequence_func`` to register it in the
KV cache. Each forward pass is then wrapped between ``begin_forward_func``, which receives the
sequence IDs and the number of new tokens for each sequence, and ``end_forward_func``.

.. GENERATED FROM PYTHON SOURCE LINES 574-583

.. code-block:: Python


    if not IS_IN_CI:
        seq_id = 0
        add_sequence_func(kv_cache, seq_id)
        hidden_states = embed(tokens, params)
        begin_forward_func(kv_cache, Shape([seq_id]), Shape([input_len]))
        logits, kv_cache = vm["prefill"](hidden_states, kv_cache, params)
        end_forward_func(kv_cache)








.. GENERATED FROM PYTHON SOURCE LINES 584-590

Now we have the output logits from the prefill step, and we generate the next token by sampling
from them.

In this tutorial, we simplify the sampling process and pick the token with the highest
probability (greedy decoding). In practice, the token is usually sampled from the probability
distribution. Also, to keep the tutorial concise, we run sampling on the CPU.

.. GENERATED FROM PYTHON SOURCE LINES 590-602

.. code-block:: Python



    def sample_token(logits):
        logits_np = logits.numpy()
        return np.argmax(logits_np)


    if not IS_IN_CI:
        last_token = sample_token(logits)
        output_tokens = [last_token]

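One common way to sample from the distribution instead of always taking the argmax is temperature
scaling followed by top-k filtering. The NumPy sketch below is illustrative only: the function
name ``sample_token_top_k`` and its default values are assumptions, not part of TVM. It accepts
the ``(1, 1, vocab_size)`` logits layout produced by ``decode``, once converted with ``.numpy()``.

.. code-block:: python

    import numpy as np


    def sample_token_top_k(logits, temperature=0.7, top_k=50, rng=None):
        """Temperature + top-k sampling over a logits array whose last axis is the vocabulary."""
        rng = np.random.default_rng() if rng is None else rng
        logits = np.asarray(logits, dtype=np.float64).reshape(-1)
        if temperature <= 0:
            return int(np.argmax(logits))  # temperature 0 degenerates to greedy decoding
        logits = logits / temperature
        if top_k is not None and top_k < logits.size:
            # Mask every logit smaller than the k-th largest one
            kth_largest = np.partition(logits, -top_k)[-top_k]
            logits = np.where(logits < kth_largest, -np.inf, logits)
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        return int(rng.choice(logits.size, p=probs))


    example_logits = np.array([[[0.1, 2.0, -1.0, 0.5]]], dtype=np.float32)
    print(sample_token_top_k(example_logits, temperature=0))  # 1 (greedy)
    # With top_k=2, only token 1 or token 3 can be drawn
    print(sample_token_top_k(example_logits, top_k=2, rng=np.random.default_rng(0)))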








.. GENERATED FROM PYTHON SOURCE LINES 603-608

Decode
~~~~~~
After the prefill step, we can start the decode step, which generates one token at a time until
the end-of-sequence token is generated. Each step embeds the previously sampled token and calls
the ``decode`` function compiled in the Relax IRModule.

.. GENERATED FROM PYTHON SOURCE LINES 608-623

.. code-block:: Python


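    if not IS_IN_CI:
        print("The generated text:")

        # Stop at the end-of-sequence token, or before the sequence outgrows the KV cache,
        # which was created with max_total_seq_len=2048
        max_total_seq_len = 2048
        while (
            last_token != tokenizer.eos_token_id
            and input_len + len(output_tokens) < max_total_seq_len
        ):
            tokens = tvm.runtime.tensor(np.array([last_token]).astype("int32"), device=dev)
            hidden_states = embed(tokens, params)
            begin_forward_func(kv_cache, Shape([seq_id]), Shape([1]))
            logits, kv_cache = vm["decode"](hidden_states, kv_cache, params)
            end_forward_func(kv_cache)
            last_token = sample_token(logits)
            output_tokens.append(last_token)

        print(tokenizer.decode(output_tokens, skip_special_tokens=True))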








.. _sphx_glr_download_how_to_tutorials_optimize_llm.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: optimize_llm.ipynb <optimize_llm.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: optimize_llm.py <optimize_llm.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: optimize_llm.zip <optimize_llm.zip>`
