Optimize Large Language Model
Large language models (LLMs) have become a popular research topic in many fields, and deploying them on cloud and edge devices is challenging. In this tutorial, we demonstrate how to optimize a large language model using Apache TVM. We use a pre-trained TinyLlama model from Hugging Face and deploy it on a target device (a CUDA GPU in this example).
Review Overall Flow
The overall flow consists of the following steps:
Construct or Import a Model: Construct a neural network model or import a pre-trained model from another framework (e.g. PyTorch, ONNX), and create a TVM IRModule. The IRModule contains all the information needed for compilation, including high-level Relax functions for the computational graph and low-level TensorIR functions for tensor programs.
Perform Composable Optimizations: Apply a series of optimization transformations, such as graph optimizations, tensor program optimizations, and library dispatching.
Build and Universal Deployment: Build the optimized model into a module deployable on TVM's universal runtime, and execute it on different devices, such as CPUs, GPUs, or other accelerators.
Construct the model architecture
We use a pre-trained TinyLlama model from Hugging Face. However, we usually load only the pre-trained weights from Hugging Face, not the model architecture, so we need to construct the architecture ourselves. Apache TVM provides a PyTorch-like API (tvm.relax.frontend.nn) for this purpose.
import dataclasses
import enum
import os
from pathlib import Path
from pprint import pprint
from typing import List, Optional
import tvm
from tvm import dlight, relax, te, tir
from tvm.relax import register_pipeline
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op
from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache, TIRPagedKVCache
from tvm.runtime import ShapeTuple
First, we define the model configuration. The configuration includes the key parameters of the model, such as the hidden size and the intermediate size. For convenience, we define a constant config specifically for the TinyLlama model. We also select the CUDA device and derive the compilation target from it.
@dataclasses.dataclass
class LlamaConfig:
    hidden_size: int = 2048
    intermediate_size: int = 5632
    num_attention_heads: int = 32
    num_hidden_layers: int = 22
    rms_norm_eps: float = 1e-05
    vocab_size: int = 32000
    rope_theta: int = 10000
    context_window_size: int = 2048
    prefill_chunk_size: int = 2048
    num_key_value_heads: int = 4
    head_dim: int = 64  # hidden_size // num_attention_heads
dev = tvm.device("cuda", 0)
target = tvm.target.Target.from_device(dev)
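As a sanity check on the configuration, the sketch below derives the fused weight shapes and the total parameter count from the same values, using plain Python. `param_shapes` and `numel` are hypothetical helpers. The layout (fused QKV, fused gate/up, and a separate `lm_head` next to the token embedding) mirrors the model defined in this tutorial, so the shapes should agree with the exported parameters printed later.

```python
# Values copied from LlamaConfig above.
hidden_size = 2048
intermediate_size = 5632
num_attention_heads = 32
num_key_value_heads = 4
num_hidden_layers = 22
vocab_size = 32000
head_dim = hidden_size // num_attention_heads  # 64


def param_shapes():
    """Hypothetical helper: weight shapes ([out, in]) of one decoder layer."""
    qkv_out = (num_attention_heads + 2 * num_key_value_heads) * head_dim
    return {
        "qkv_proj": (qkv_out, hidden_size),
        "o_proj": (hidden_size, num_attention_heads * head_dim),
        "gate_up_proj": (2 * intermediate_size, hidden_size),
        "down_proj": (hidden_size, intermediate_size),
        "input_layernorm": (hidden_size,),
        "post_attention_layernorm": (hidden_size,),
    }


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


per_layer = sum(numel(s) for s in param_shapes().values())
embed = vocab_size * hidden_size  # model.embed_tokens
lm_head = vocab_size * hidden_size  # separate output projection, as in this model
final_norm = hidden_size  # model.norm
total = num_hidden_layers * per_layer + embed + lm_head + final_norm

print(param_shapes()["qkv_proj"], param_shapes()["gate_up_proj"])
print(f"{total:,} parameters, {total * 2 / 2**30:.2f} GiB in float16")
```

Running it prints `(2560, 2048) (11264, 2048)` and 1,100,048,384 parameters (about 2.05 GiB in float16). This is consistent with the "1.1B" in the TinyLlama model name.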
Next, we define the RoPE mode of the paged KV cache. RoPE (Rotary Position Embedding) encodes token positions by rotating the query and key tensors, and the RoPE mode controls where that rotation happens. The mode can be NONE, NORMAL, or INLINE. With NONE, the KV cache does not apply RoPE to the query and key tensors. With NORMAL, RoPE is applied to the key tensor before it is added to the cache. With INLINE, RoPE is applied to the query and key tensors on the fly inside the attention kernel.
class RopeMode(enum.IntEnum):
    """The RoPE mode of the Paged KV cache.
    If it is none, the KV cache will not apply RoPE to q and k.
    If it is normal, RoPE will be applied to k before adding k to cache.
    Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
    """
    NONE = 0
    NORMAL = 1
    INLINE = 2
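To make the rotation concrete, here is a minimal pure-Python sketch of RoPE applied to a single head vector. It assumes the "rotate-half" pairing used by the Hugging Face Llama implementation, where dimension i is paired with i + head_dim/2, and it uses the `rope_theta` default from `LlamaConfig`. `apply_rope` and `dot` are hypothetical helpers, not TVM's kernel. The usage lines check the property that motivates RoPE: the query-key dot product depends only on the relative offset between the two positions.

```python
import math
from typing import List


def apply_rope(x: List[float], pos: int, theta: float = 10000.0) -> List[float]:
    """Rotate one head vector `x` to position `pos` (rotate-half convention)."""
    d = len(x)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        # Each pair (i, i + half) rotates at its own frequency theta^(-2i/d).
        angle = pos / (theta ** (2 * i / d))
        c, s = math.cos(angle), math.sin(angle)
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [0.3, -1.2, 0.5, 0.8, 1.1, -0.4, 0.2, 0.9]
k = [1.0, 0.1, -0.7, 0.4, -0.3, 0.6, 0.5, -1.1]
# The same offset (2) at different absolute positions gives the same score.
score_a = dot(apply_rope(q, 5), apply_rope(k, 3))
score_b = dot(apply_rope(q, 12), apply_rope(k, 10))
print(score_a, score_b)
```

Because each pair of dimensions is only rotated, RoPE preserves vector norms. Position 0 leaves the vector unchanged.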
With the configuration in place, we define the model architecture. It consists of three parts:
Embedding layer: The embedding layer converts the input token IDs to the hidden states.
Decoder layers: The decoder layers are the core of the model. Each decoder layer consists of a self-attention layer and a feed-forward network (FFN) layer.
Output layer: The output layer converts the hidden states to the logits.
First, we define the FFN layer. The implementation below is optimized: it fuses the gate and up projections into a single kernel.
The naive FFN implementation is:
FFN(x) = down_proj(silu(gate(x)) * up(x))
Because gate and up are both linear projections of the same input x, we can combine them into one projection for better performance. The optimized implementation is:
concat_x = gate_up(x)
gate_x, up_x = split(concat_x, 2, axis=-1)
FFN(x) = down_proj(silu(gate_x) * up_x)
class LlamaFFN(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.gate_up_proj = nn.Linear(
            in_features=config.hidden_size,
            out_features=2 * config.intermediate_size,
            bias=False,
        )
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
    def forward(self, x: Tensor):
        concat_x1_x2 = self.gate_up_proj(x)
        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)
        return self.down_proj(op.silu(x1) * x2)
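The sketch below uses plain Python and tiny hand-picked matrices to check that the fused form computes the same result as the naive one. Stacking the gate weight rows on top of the up weight rows (nn.Linear stores weights as [out_features, in_features]) and splitting the output in half along the last axis recovers gate(x) and up(x). The helper names (`matvec`, `silu`, `naive_ffn_inner`, `fused_ffn_inner`) are illustrative only. The down projection is omitted because it is identical in both forms.

```python
import math
from typing import List

Matrix = List[List[float]]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    """y = W @ x for a weight stored as [out_features][in_features]."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def naive_ffn_inner(w_gate: Matrix, w_up: Matrix, x: List[float]) -> List[float]:
    # silu(gate(x)) * up(x), computed with two separate projections.
    return [silu(g) * u for g, u in zip(matvec(w_gate, x), matvec(w_up, x))]


def fused_ffn_inner(w_gate_up: Matrix, x: List[float]) -> List[float]:
    # One projection, then split the output into two equal halves.
    concat_x = matvec(w_gate_up, x)
    half = len(concat_x) // 2
    gate_x, up_x = concat_x[:half], concat_x[half:]
    return [silu(g) * u for g, u in zip(gate_x, up_x)]


w_gate = [[0.2, -0.5, 1.0], [0.7, 0.1, -0.3]]
w_up = [[-1.0, 0.4, 0.6], [0.5, 0.9, -0.2]]
w_gate_up = w_gate + w_up  # row-wise concatenation: gate rows first, then up rows
x = [1.5, -0.5, 2.0]
print(naive_ffn_inner(w_gate, w_up, x))
print(fused_ffn_inner(w_gate_up, x))
```

The row order matters: the first half of the fused output must come from the gate weights. This matches the weight-loading code later in the tutorial, which concatenates `gate_proj` before `up_proj`.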
Next, we define the self-attention layer. It consists of three parts:
QKV projection: Projects the input hidden states to the query, key, and value tensors.
Attention: Computes the attention scores, applies the softmax operation, and aggregates the values.
Output projection: Projects the attention output back to the hidden size.
We optimize the different parts of the self-attention layer as follows:
QKV projection: We apply horizontal fusion and compute the Q, K, and V projections with a single linear layer (qkv_proj). Because TinyLlama uses grouped-query attention (4 key/value heads shared by 32 query heads), the fused output has (32 + 2 × 4) × 64 = 2560 features.
Attention: We pass the fused QKV tensor directly to the paged KV cache through attention_with_fused_qkv, so the model code never splits Q, K, and V itself.
class LlamaAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
    def __init__(self, config: LlamaConfig):
        self.head_dim = config.head_dim
        self.num_q_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # horizontal fusion on QKV projection
        self.qkv_proj = nn.Linear(
            in_features=config.hidden_size,
            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,
            bias=False,
        )
        self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)
    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads
        b, s, _ = hidden_states.shape
        # QKV Projection
        qkv = self.qkv_proj(hidden_states)
        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))
        # Attention
        output = op.reshape(
            paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads),
            (b, s, h_q * d),
        )
        # Output Projection
        return self.o_proj(output)
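The fused QKV tensor is reshaped to (b, s, h_q + 2 * h_kv, d). Along the head axis, the first h_q heads are therefore queries, the next h_kv are keys, and the last h_kv are values. With grouped-query attention, each key/value head is shared by h_q / h_kv query heads. The plain-Python sketch below spells out that bookkeeping for the TinyLlama configuration. `split_fused_heads` and `kv_head_for` are hypothetical helpers. The Q, K, V ordering is an assumption that matches the order in which the weight-loading code later in the tutorial concatenates `q_proj`, `k_proj`, and `v_proj`.

```python
from typing import Dict, List, Tuple

h_q, h_kv, d = 32, 4, 64  # num_attention_heads, num_key_value_heads, head_dim


def split_fused_heads(num_q: int, num_kv: int) -> Dict[str, Tuple[int, int]]:
    """Head-axis ranges [start, end) of Q, K and V inside the fused QKV tensor."""
    return {
        "q": (0, num_q),
        "k": (num_q, num_q + num_kv),
        "v": (num_q + num_kv, num_q + 2 * num_kv),
    }


def kv_head_for(q_head: int, num_q: int, num_kv: int) -> int:
    """Index of the key/value head that a query head uses under GQA."""
    group_size = num_q // num_kv  # query heads sharing one KV head
    return q_head // group_size


ranges = split_fused_heads(h_q, h_kv)
fused_features = (h_q + 2 * h_kv) * d  # out_features of qkv_proj
groups: List[List[int]] = [[] for _ in range(h_kv)]
for head in range(h_q):
    groups[kv_head_for(head, h_q, h_kv)].append(head)

print(ranges)
print(fused_features)
print(groups[1])
```

Sharing KV heads makes the K and V slices 8 times smaller than Q here. This is why the fused weight has 2560 rows instead of 3 × 2048.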
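Finally, we assemble the full model from the FFN and self-attention layers: the decoder layer, the backbone LlamaModel, and the causal language model wrapper, which exposes the embed, prefill, decode, and create_tir_paged_kv_cache functions.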
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        rms_norm_eps = config.rms_norm_eps
        self.self_attn = LlamaAttention(config)
        self.mlp = LlamaFFN(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)
    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
        # Pre-norm residual blocks: normalize, run the sublayer, add back the input
        hidden_states += self.self_attn(
            self.input_layernorm(hidden_states), paged_kv_cache, layer_id
        )
        hidden_states += self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states
class LlamaModel(nn.Module):
    def __init__(self, config: LlamaConfig):
        assert config.hidden_size % config.num_attention_heads == 0
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)
    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
        hidden_states = input_embed
        for layer_id, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)
        hidden_states = self.norm(hidden_states)
        return hidden_states
class LlamaForCasualLM(nn.Module):
    def __init__(self, config: LlamaConfig):
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.num_hidden_layers = config.num_hidden_layers
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        self.vocab_size = config.vocab_size
        self.rope_theta = config.rope_theta
        self.dtype = "float32"
    def to(self, dtype: Optional[str] = None):
        super().to(dtype=dtype)
        if dtype is not None:
            self.dtype = dtype
    def embed(self, input_ids: Tensor):
        return self.model.embed_tokens(input_ids)
    def get_logits(self, hidden_states: Tensor):
        logits = self.lm_head(hidden_states)
        if logits.dtype != "float32":
            logits = logits.astype("float32")
        return logits
    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
        def _index(x: te.Tensor):  # select the last position, i.e. x[:, -1:, :]
            b, s, d = x.shape
            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index")
        hidden_states = self.model(input_embed, paged_kv_cache)
        # Only the last position is needed to predict the next token
        hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
        logits = self.get_logits(hidden_states)
        return logits, paged_kv_cache
    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
        hidden_states = self.model(input_embed, paged_kv_cache)
        logits = self.get_logits(hidden_states)
        return logits, paged_kv_cache
    def create_tir_paged_kv_cache(
        self,
        max_batch_size: tir.Var,
        max_total_seq_len: tir.Var,
        prefill_chunk_size: tir.Var,
        page_size: tir.Var,
    ) -> PagedKVCache:
        return TIRPagedKVCache(
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,
            page_size=page_size,
            support_sliding_window=0,
            layer_partition=relax.ShapeExpr([0, self.num_hidden_layers]),
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            head_dim=self.head_dim,
            rope_mode=RopeMode.NORMAL,
            rope_scale=1,
            rope_theta=self.rope_theta,
            rope_scaling={},
            rope_ext_factors=relax.PrimValue(0),
            rotary_dim=self.head_dim,
            dtype=self.dtype,
            target=target,
        )
    def get_default_spec(self):
        mod_spec = {
            "embed": {
                "input_ids": nn.spec.Tensor(["seq_len"], "int32"),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
            "prefill": {
                "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
            "decode": {
                "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
            "create_tir_paged_kv_cache": {
                "max_batch_size": int,
                "max_total_seq_len": int,
                "prefill_chunk_size": int,
                "page_size": int,
                "$": {
                    "param_mode": "none",
                    "effect_mode": "none",
                },
            },
        }
        return nn.spec.ModuleSpec.from_raw(mod_spec, self)
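Each decoder layer above follows the pre-norm pattern. The input is normalized with RMSNorm before attention and before the FFN, and each sublayer's output is added back to its input. The sketch below shows RMSNorm in plain Python using its standard definition, y = x / sqrt(mean(x²) + eps) * weight, with eps matching `rms_norm_eps`. It is an illustration, not TVM's `nn.RMSNorm` kernel. `pre_norm_block` and its toy sublayer are hypothetical.

```python
import math
from typing import Callable, List

Vector = List[float]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-5) -> Vector:
    """Scale x by the reciprocal of its root mean square (no mean subtraction, no bias)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def pre_norm_block(x: Vector, weight: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    """Hypothetical helper: x + sublayer(rms_norm(x)), as in LlamaDecoderLayer.forward."""
    return [a + b for a, b in zip(x, sublayer(rms_norm(x, weight)))]


x = [3.0, -4.0, 0.0, 5.0]
ones = [1.0] * len(x)
normed = rms_norm(x, ones)
rms = math.sqrt(sum(v * v for v in normed) / len(normed))
print(normed, rms)  # rms is just below 1.0 because eps is added inside the sqrt
out = pre_norm_block(x, ones, lambda h: [0.5 * v for v in h])  # toy sublayer
print(out)
```

Unlike LayerNorm, RMSNorm does not subtract the mean, which is why the model passes `bias=False` and uses only a scale weight.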
Export the model to Relax IRModule
After defining the model architecture, we can export the model to a Relax IRModule with export_tvm, which returns the IRModule and the list of named parameters. For demonstration, we show only part of the exported prefill function and the first few parameters.
model_config = LlamaConfig()
model = LlamaForCasualLM(model_config)
model.to("float16")
mod, named_params = model.export_tvm(spec=model.get_default_spec())
prefill_str = mod["prefill"].script()
print(*prefill_str.split("\n")[3:20], sep="\n")  # Only show part of the function for demonstration
print(" ...")
print("\nParameters:")
pprint(named_params[:5]) # Only show the first 5 parameters for demonstration
@R.function
def prefill(input_embed: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((32000, 2048), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), 
dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((32000, 2048), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 32000), dtype="float32"), R.Object):
    seq_len = T.int64()
    R.func_attr({"num_input": 2})
    with R.dataflow():
        model_embed_tokens_weight1: R.Tensor((32000, 2048), dtype="float16") = packed_params[0]
        model_layers_0_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype="float16") = packed_params[1]
        model_layers_0_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype="float16") = packed_params[2]
        model_layers_0_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype="float16") = packed_params[3]
        model_layers_0_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype="float16") = packed_params[4]
        model_layers_0_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[5]
        model_layers_0_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[6]
        model_layers_1_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype="float16") = packed_params[7]
        model_layers_1_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype="float16") = packed_params[8]
        model_layers_1_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype="float16") = packed_params[9]
        model_layers_1_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype="float16") = packed_params[10]
        model_layers_1_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[11]
...
Parameters:
[('model.embed_tokens.weight', Tensor([32000, 2048], "float16")),
('model.layers.0.self_attn.qkv_proj.weight', Tensor([2560, 2048], "float16")),
('model.layers.0.self_attn.o_proj.weight', Tensor([2048, 2048], "float16")),
('model.layers.0.mlp.gate_up_proj.weight', Tensor([11264, 2048], "float16")),
('model.layers.0.mlp.down_proj.weight', Tensor([2048, 5632], "float16"))]
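Define Optimization Pipeline
We define a series of optimization passes, designed specifically for LLMs, to optimize the model. The pipeline is registered under the name "opt_llm", and relax.build retrieves it through relax.get_pipeline when compiling the IRModule for the target.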
@register_pipeline("opt_llm")
def _pipeline(  # pylint: disable=too-many-arguments
    ext_mods: Optional[List[nn.ExternModule]] = None,
):
    ext_mods = ext_mods or []
    @tvm.transform.module_pass(opt_level=0)
    def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
        seq = tvm.transform.Sequential(
            [
                # Phase 1. Passes on high-level operator graph
                # We can enable cublas for further optimization
                relax.transform.FuseTransposeMatmul(),
                # Phase 2. Lowering to TIR, inherited from TVM Relax's official "zero" pipeline
                relax.transform.LegalizeOps(),
                relax.transform.AnnotateTIROpPattern(),
                relax.transform.FoldConstant(),
                relax.transform.FuseOps(),
                relax.transform.FuseTIR(),
                # Phase 3. Passes on TIR
                relax.transform.DeadCodeElimination(),
                # Phase 4. Low-level Optimizations
                dlight.ApplyDefaultSchedule(
                    dlight.gpu.Matmul(),
                    dlight.gpu.GEMV(),
                    dlight.gpu.Reduction(),
                    dlight.gpu.GeneralReduction(),
                    dlight.gpu.Fallback(),
                ),
                # Phase 5. Lowering to VM bytecode
                relax.transform.RewriteDataflowReshape(),
                relax.transform.ToNonDataflow(),
                relax.transform.RemovePurityChecking(),
                relax.transform.CallTIRRewrite(),
                relax.transform.StaticPlanBlockMemory(),
                relax.transform.RewriteCUDAGraph(),
                relax.transform.LowerAllocTensor(),
                relax.transform.KillAfterLastUse(),
                relax.transform.LowerRuntimeBuiltin(),
                relax.transform.VMShapeLower(),
                relax.transform.AttachGlobalSymbol(),
                relax.transform.AttachExternModules(ext_mods),
            ]
        )
        mod = seq(mod)
        return mod
    return _pipeline
with target:
    ex = relax.build(mod, target, pipeline=relax.get_pipeline("opt_llm"))
vm = relax.VirtualMachine(ex, dev)
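The pattern above can be illustrated without TVM: register a named pipeline, then compose passes that run in order. The sketch below is a toy analog only. `PIPELINES`, `register`, `get`, and `sequential` are hypothetical stand-ins for `register_pipeline`, `relax.get_pipeline`, and `tvm.transform.Sequential`, and the "module" is just a list of op names. It also shows why pass order matters: each pass sees the output of the ones before it.

```python
from typing import Callable, Dict, List

Module = List[str]  # toy "IRModule": a flat list of op names
Pass = Callable[[Module], Module]
PIPELINES: Dict[str, Callable[[], Pass]] = {}


def register(name: str):
    """Toy analog of register_pipeline: store a pipeline factory under a name."""
    def decorator(factory: Callable[[], Pass]) -> Callable[[], Pass]:
        PIPELINES[name] = factory
        return factory
    return decorator


def get(name: str) -> Pass:
    """Toy analog of get_pipeline: build the registered pipeline."""
    return PIPELINES[name]()


def sequential(passes: List[Pass]) -> Pass:
    """Toy analog of Sequential: apply the passes one after another."""
    def run(mod: Module) -> Module:
        for p in passes:
            mod = p(mod)
        return mod
    return run


def fuse_transpose_matmul(mod: Module) -> Module:
    """Merge a 'transpose' immediately followed by a 'matmul'."""
    out: Module = []
    for op in mod:
        if op == "matmul" and out and out[-1] == "transpose":
            out[-1] = "transpose_matmul"
        else:
            out.append(op)
    return out


def drop_unused(mod: Module) -> Module:
    """Remove ops marked as unused."""
    return [op for op in mod if op != "unused"]


@register("toy_opt")
def _toy_pipeline() -> Pass:
    return sequential([drop_unused, fuse_transpose_matmul])


print(get("toy_opt")(["transpose", "unused", "matmul", "add"]))
```

With the cleanup first, the transpose and matmul become adjacent and get fused. Running the same two passes in the opposite order leaves them unfused.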
Prepare the model weights
We load the pre-trained weights from Hugging Face, which are stored in the safetensors format, and convert them into the parameters our model expects. Because our architecture fuses the Q, K, and V projections and the gate and up projections, the corresponding Hugging Face weights must be concatenated to match.
Note: We don't execute the following code in CI, because the pre-trained weights are not available in the CI environment. To run it yourself, set HF_WEIGHT_PATH to a local copy of the TinyLlama-1.1B-Chat-v1.0 weights.
IS_IN_CI = os.getenv("CI", "") == "true"
HF_WEIGHT_PATH = None
# HF_WEIGHT_PATH = Path("/path/to/TinyLlama-1.1B-Chat-v1.0/")
if not IS_IN_CI:
    import numpy as np
    import safetensors.torch
    import torch
    if HF_WEIGHT_PATH is None or not HF_WEIGHT_PATH.exists():
        raise ValueError("Please set the HF_WEIGHT_PATH to the path of the pre-trained weights.")
    # Torch format weights
    param_dict = safetensors.torch.load_file(HF_WEIGHT_PATH / "model.safetensors", device="cpu")
    # Numpy format weights (NumPy has no bfloat16, so bfloat16 tensors become float16 first)
    param_dict = {
        k: v.half().numpy() if v.dtype == torch.bfloat16 else v.numpy()
        for k, v in param_dict.items()
    }
    named_params = dict(named_params)
    for i in range(model_config.num_hidden_layers):
        # Fuse the Q, K, and V projections in self attention
        attn = f"model.layers.{i}.self_attn"
        param_dict[f"{attn}.qkv_proj.weight"] = np.concatenate(
            [
                param_dict.pop(f"{attn}.q_proj.weight"),  # Pop the old parameters to save memory
                param_dict.pop(f"{attn}.k_proj.weight"),
                param_dict.pop(f"{attn}.v_proj.weight"),
            ],
            axis=0,
        )
        # Fuse the gate and up projections in the MLP
        mlp = f"model.layers.{i}.mlp"
        param_dict[f"{mlp}.gate_up_proj.weight"] = np.concatenate(
            [
                param_dict.pop(f"{mlp}.gate_proj.weight"),
                param_dict.pop(f"{mlp}.up_proj.weight"),
            ],
            axis=0,
        )
    # Convert params into ndarrays, in the order expected by the exported functions
    params = [
        tvm.nd.array(param_dict[k].astype("float16"), device=dev) for k in named_params.keys()
    ]
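The renaming and concatenation above can be checked on a toy checkpoint without NumPy or real weights. In the sketch below, `fuse_checkpoint` is a hypothetical helper. Each "weight" is a list of rows standing in for an [out_features, in_features] array, so concatenating along axis 0 is list concatenation. The sketch verifies three things: the fused keys appear, the original per-projection keys are consumed, and the fused row counts match what the exported model expects.

```python
from typing import Dict, List

Weight = List[List[float]]  # one inner list per output row

# Toy sizes: 4 query heads, 1 KV head, head_dim 2, intermediate size 3.
h_q, h_kv, d, inter = 4, 1, 2, 3
hidden = h_q * d  # in_features of every projection here


def rows(n: int) -> Weight:
    """A zero weight with n output rows."""
    return [[0.0] * hidden for _ in range(n)]


def fuse_checkpoint(params: Dict[str, Weight], num_layers: int) -> Dict[str, Weight]:
    """Hypothetical helper mirroring the fusion loop above with plain lists."""
    params = dict(params)  # leave the caller's dict untouched
    for i in range(num_layers):
        attn = f"model.layers.{i}.self_attn"
        params[f"{attn}.qkv_proj.weight"] = (
            params.pop(f"{attn}.q_proj.weight")
            + params.pop(f"{attn}.k_proj.weight")
            + params.pop(f"{attn}.v_proj.weight")
        )
        mlp = f"model.layers.{i}.mlp"
        params[f"{mlp}.gate_up_proj.weight"] = (
            params.pop(f"{mlp}.gate_proj.weight") + params.pop(f"{mlp}.up_proj.weight")
        )
    return params


ckpt = {
    "model.layers.0.self_attn.q_proj.weight": rows(h_q * d),
    "model.layers.0.self_attn.k_proj.weight": rows(h_kv * d),
    "model.layers.0.self_attn.v_proj.weight": rows(h_kv * d),
    "model.layers.0.mlp.gate_proj.weight": rows(inter),
    "model.layers.0.mlp.up_proj.weight": rows(inter),
}
fused = fuse_checkpoint(ckpt, num_layers=1)
print(sorted(fused))
print(len(fused["model.layers.0.self_attn.qkv_proj.weight"]))
```

The fused QKV weight has (h_q + 2 * h_kv) * d rows, the same formula used for `out_features` of `qkv_proj` in `LlamaAttention`.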
Deploy the compiled model
With the compiled model and weights ready, we can run the model on the target device. Language model inference has two steps: prefill and decode. The prefill step processes all input tokens at once and stores their keys and values in the KV cache. The decode step then generates one token at a time, reusing the cache, until the end-of-sequence token is generated.
Tokenization
The first step is to tokenize the input prompt with the Hugging Face tokenizer; the tokens are embedded into hidden states in the next step. The tokenization must match the original model. Different models require different tokenization and prompt formats, so refer to the model documentation for the correct ones. Here we build the prompt with the tokenizer's chat template.
if not IS_IN_CI:
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(HF_WEIGHT_PATH)
    messages = [
        {"role": "user", "content": "What's your name?"},
    ]
    prompt = tokenizer.apply_chat_template(messages)
    input_len = len(prompt)
    # Load prompt tokens into TVM ndarray on the target device
    tokens = tvm.nd.array(np.array(prompt).astype("int32"), device=dev)
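Create the KVCache
Before starting inference, we need to create the KV cache. It stores the key and value tensors of every attention layer so that they don't have to be recomputed for each new token. Apache TVM provides a PagedKVCache for this purpose. We create it by calling the exported create_tir_paged_kv_cache function through the VM, passing the maximum batch size, the maximum total sequence length, the prefill chunk size, and the page size.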
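In a paged cache, keys and values are stored in fixed-size pages drawn from a shared pool, so a sequence holds only as many pages as its length requires. The toy class below sketches that bookkeeping in plain Python. `ToyPagedKVCache` and its methods are hypothetical, and they ignore the actual tensors, RoPE, and attention that TVM's PagedKVCache handles. The page size of 16 is an arbitrary choice for illustration.

```python
from typing import Dict, List


class ToyPagedKVCache:
    """Tracks which pages each sequence owns; stores no key/value data."""

    def __init__(self, num_pages: int, page_size: int):
        self.page_size = page_size
        self.free_pages: List[int] = list(range(num_pages))
        self.pages: Dict[int, List[int]] = {}  # seq_id -> owned page indices
        self.lengths: Dict[int, int] = {}  # seq_id -> number of cached tokens

    def add_sequence(self, seq_id: int) -> None:
        self.pages[seq_id] = []
        self.lengths[seq_id] = 0

    def append(self, seq_id: int, num_tokens: int) -> None:
        """Reserve room for `num_tokens` new tokens, allocating pages on demand."""
        new_len = self.lengths[seq_id] + num_tokens
        needed = -(-new_len // self.page_size)  # ceiling division
        while len(self.pages[seq_id]) < needed:
            if not self.free_pages:
                raise MemoryError("KV cache is out of pages")
            self.pages[seq_id].append(self.free_pages.pop(0))
        self.lengths[seq_id] = new_len

    def remove_sequence(self, seq_id: int) -> None:
        """Return all of a sequence's pages to the free pool."""
        self.free_pages.extend(self.pages.pop(seq_id))
        del self.lengths[seq_id]


cache = ToyPagedKVCache(num_pages=8, page_size=16)
cache.add_sequence(0)
cache.append(0, 20)  # prefill with a 20-token prompt needs 2 pages
cache.append(0, 1)  # one decode step still fits in the second page
print(cache.pages[0], len(cache.free_pages))
```

Pages are allocated only when a sequence crosses a page boundary. Memory therefore grows with the actual sequence length rather than with the maximum sequence length reserved up front.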
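Embedding
The next step is to embed the tokens into hidden states using the embed function compiled in the Relax IRModule. The embed function returns a tensor of shape [seq_len, hidden_size], while prefill and decode expect [1, seq_len, hidden_size], so we reshape the result with the vm.builtin.reshape builtin.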
nd_view_func = tvm.get_global_func("vm.builtin.reshape")
def embed(tokens, params):
    _embed = vm["embed"](tokens, params)
    # Reshape hidden from [seq_len, hidden_size] to [1, seq_len, hidden_size]
    _embed = nd_view_func(_embed, ShapeTuple([1, _embed.shape[0], _embed.shape[1]]))
    return _embed
Prefill
Before running the forward pass, we fetch a few helper functions that manage the KV cache state.
add_sequence_func = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
begin_forward_func = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
end_forward_func = tvm.get_global_func("vm.builtin.kv_state_end_forward")
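Because we are starting a new sequence, we first call add_sequence_func to register it in the KV cache. Each forward pass must then be wrapped between two calls: begin_forward_func, which takes the sequence IDs and the number of new tokens per sequence, and end_forward_func. Concretely, with seq_id = 0, the prefill step calls add_sequence_func(kv_cache, seq_id) and embeds the prompt tokens. It then calls begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([input_len])), runs vm["prefill"](hidden_states, kv_cache, params), and finally calls end_forward_func(kv_cache).
The prefill step returns the logits for the last prompt position, which we use to sample the next token.
In this tutorial, we simplify sampling by always picking the token with the highest probability (greedy decoding). In practice, the token is usually sampled from the probability distribution. To keep the tutorial concise, we also run sampling on the CPU.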
def sample_token(logits):
    logits_np = logits.numpy()
    return np.argmax(logits_np)
if not IS_IN_CI:
    last_token = sample_token(logits)
    output_tokens = [last_token]
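Greedy selection, as in sample_token above, always returns the arg-max. The sketch below contrasts it with temperature sampling from the softmax distribution, one common alternative. `softmax`, `greedy`, and `sample_with_temperature` are illustrative helpers in plain Python. They operate on a flat list of logits rather than a TVM NDArray, and they use a seeded `random.Random` so runs are reproducible.

```python
import math
import random
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    scaled = [v / temperature for v in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def greedy(logits: List[float]) -> int:
    """Index of the largest logit, like np.argmax."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sample_with_temperature(logits: List[float], temperature: float, rng: random.Random) -> int:
    """Draw one token index from softmax(logits / temperature)."""
    probs = softmax(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [1.0, 3.0, 0.5, 2.5]
rng = random.Random(0)
print(greedy(logits))
print([sample_with_temperature(logits, 1.0, rng) for _ in range(10)])
```

Lower temperatures sharpen the distribution toward the greedy choice, and higher temperatures flatten it.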
Decode
After the prefill step, we start the decode step, which generates one token per iteration until the end-of-sequence token is produced. Each iteration embeds the last generated token, calls the decode function compiled in the Relax IRModule, and samples the next token from the resulting logits.
if not IS_IN_CI:
    print("The generated text:")
    while last_token != tokenizer.eos_token_id:
        tokens = tvm.nd.array(np.array([last_token]).astype("int32"), device=dev)
        hidden_states = embed(tokens, params)
        begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([1]))
        logits, kv_cache = vm["decode"](hidden_states, kv_cache, params)
        end_forward_func(kv_cache)
        last_token = sample_token(logits)
        output_tokens.append(last_token)
    print(tokenizer.decode(output_tokens))
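The loop above stops only when the model emits the end-of-sequence token. The toy sketch below mirrors the same prefill-then-decode structure with a stand-in "model". It also adds a cap on the number of new tokens, a common safeguard so that generation cannot run past the KV cache capacity if EOS never appears. `toy_next_token`, `EOS`, and `max_new_tokens` are hypothetical.

```python
from typing import List

EOS = 0  # hypothetical end-of-sequence token ID


def toy_next_token(cache: List[int]) -> int:
    """Stand-in for a forward pass plus sampling: counts down toward EOS."""
    return max(cache[-1] - 1, EOS)


def generate(prompt: List[int], max_new_tokens: int) -> List[int]:
    cache = list(prompt)  # prefill: the whole prompt enters the cache at once
    output = [toy_next_token(cache)]
    while output[-1] != EOS and len(output) < max_new_tokens:
        cache.append(output[-1])  # decode: one new token per step
        output.append(toy_next_token(cache))
    return output


print(generate([7, 4], max_new_tokens=10))  # stops at EOS
print(generate([7, 40], max_new_tokens=5))  # stops at the cap
```

In the real loop, the cap would be chosen so that the prompt length plus the new tokens stays within the maximum total sequence length given when the KV cache was created.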