我来用 “hi, who are you” 作为完整示例,通过严格的数学公式和文字描述,详细拆解 vLLM 的 Prefill 和 Decode 全过程。


一、基础设定

1.1 模型参数

参数 符号 数值(示例)
词汇表大小 VVV 32,000
隐藏层维度 ddd 4,096
注意力头数 hhh 32
每头维度 dh=d/hd_h = d/hdh=d/h 128
层数 LLL 32
KV Cache 块大小 BBB 16 tokens
序列最大长度 TmaxT_{max}Tmax 2,048

1.2 输入编码

输入字符串:“hi, who are you”

经过 Tokenizer 编码为 token ID 序列:
x=[x1,x2,x3,x4,x5,x6]=[hi,,,who,are,you,?] \mathbf{x} = [x_1, x_2, x_3, x_4, x_5, x_6] = [\text{hi}, \text{,}, \text{who}, \text{are}, \text{you}, \text{?}] x=[x1,x2,x3,x4,x5,x6]=[hi,,,who,are,you,?]

序列长度 n=6n = 6n=6


二、Prefill 阶段(预填充阶段)

2.1 输入嵌入

将离散的 token ID 映射为连续向量:

E=Embedding(x)∈Rn×d \mathbf{E} = \text{Embedding}(\mathbf{x}) \in \mathbb{R}^{n \times d} E=Embedding(x)Rn×d

其中第 iii 个 token 的嵌入为 ei∈Rd\mathbf{e}_i \in \mathbb{R}^{d}eiRd,因此:
E=[e1⊤e2⊤⋮e6⊤]=[—e1——e2—⋮—e6—]6×4096 \mathbf{E} = \begin{bmatrix} \mathbf{e}_1^\top \\ \mathbf{e}_2^\top \\ \vdots \\ \mathbf{e}_6^\top \end{bmatrix} = \begin{bmatrix} — & \mathbf{e}_1 & — \\ — & \mathbf{e}_2 & — \\ & \vdots & \\ — & \mathbf{e}_6 & — \end{bmatrix}_{6 \times 4096} E= e1e2e6 = e1e2e6 6×4096

加上位置编码 P∈Rn×d\mathbf{P} \in \mathbb{R}^{n \times d}PRn×d
H(0)=E+P \mathbf{H}^{(0)} = \mathbf{E} + \mathbf{P} H(0)=E+P

2.2 逐层 Transformer 计算(以第 lll 层为例)

2.2.1 线性投影生成 Q、K、V

Q(l)=H(l−1)WQ(l)⊤∈Rn×dK(l)=H(l−1)WK(l)⊤∈Rn×dV(l)=H(l−1)WV(l)⊤∈Rn×d \begin{aligned} \mathbf{Q}^{(l)} &= \mathbf{H}^{(l-1)} \mathbf{W}_Q^{(l)\top} \in \mathbb{R}^{n \times d} \\ \mathbf{K}^{(l)} &= \mathbf{H}^{(l-1)} \mathbf{W}_K^{(l)\top} \in \mathbb{R}^{n \times d} \\ \mathbf{V}^{(l)} &= \mathbf{H}^{(l-1)} \mathbf{W}_V^{(l)\top} \in \mathbb{R}^{n \times d} \end{aligned} Q(l)K(l)V(l)=H(l1)WQ(l)Rn×d=H(l1)WK(l)Rn×d=H(l1)WV(l)Rn×d

其中权重矩阵 WQ(l),WK(l),WV(l)∈Rd×d\mathbf{W}_Q^{(l)}, \mathbf{W}_K^{(l)}, \mathbf{W}_V^{(l)} \in \mathbb{R}^{d \times d}WQ(l),WK(l),WV(l)Rd×d

reshape 为多头形式(h=32h=32h=32 个头):
Q(l)=[Q1(l)Q2(l)⋯Qh(l)],Qj(l)∈Rn×dh \mathbf{Q}^{(l)} = \begin{bmatrix} \mathbf{Q}_1^{(l)} & \mathbf{Q}_2^{(l)} & \cdots & \mathbf{Q}_h^{(l)} \end{bmatrix}, \quad \mathbf{Q}_j^{(l)} \in \mathbb{R}^{n \times d_h} Q(l)=[Q1(l)Q2(l)Qh(l)],Qj(l)Rn×dh

2.2.2 KV Cache 存储(PagedAttention 核心)

vLLM 将 K(l)\mathbf{K}^{(l)}K(l)V(l)\mathbf{V}^{(l)}V(l) 分块存储到非连续的物理内存块中。

块分配计算:
所需块数=⌈nB⌉=⌈616⌉=1 个块 \text{所需块数} = \left\lceil \frac{n}{B} \right\rceil = \left\lceil \frac{6}{16} \right\rceil = 1 \text{ 个块} 所需块数=Bn=166=1 个块

分配物理块 Blockphys(l)\text{Block}_{phys}^{(l)}Blockphys(l)(假设为系统空闲块 #42),建立块表(Block Table)
T(l)=[42] \mathcal{T}^{(l)} = [42] T(l)=[42]

将 KV 张量写入块 #42:
Block42.K=K(l)∈R6×d(实际存储为 6×4096)Block42.V=V(l)∈R6×d \begin{aligned} \text{Block}_{42}.\mathbf{K} &= \mathbf{K}^{(l)} \in \mathbb{R}^{6 \times d} \quad (\text{实际存储为 } 6 \times 4096) \\ \text{Block}_{42}.\mathbf{V} &= \mathbf{V}^{(l)} \in \mathbb{R}^{6 \times d} \end{aligned} Block42.KBlock42.V=K(l)R6×d(实际存储为 6×4096)=V(l)R6×d

注意:块容量为 16,当前只用了前 6 个位置,剩余 10 个位置预留给后续 decode 阶段追加。

2.2.3 注意力计算(因果掩码)

对每个头 j∈{1,…,h}j \in \{1, \dots, h\}j{1,,h}

Sj=Qj(l)Kj(l)⊤dh∈Rn×n \mathbf{S}_j = \frac{\mathbf{Q}_j^{(l)} \mathbf{K}_j^{(l)\top}}{\sqrt{d_h}} \in \mathbb{R}^{n \times n} Sj=dh Qj(l)Kj(l)Rn×n

应用因果掩码 M∈{0,−∞}n×n\mathbf{M} \in \{0, -\infty\}^{n \times n}M{0,}n×n(下三角为 0,上三角为 −∞-\infty):
Sjmasked=Sj+M \mathbf{S}_j^{masked} = \mathbf{S}_j + \mathbf{M} Sjmasked=Sj+M

其中掩码矩阵:
Mij={0if i≥j−∞if i<j \mathbf{M}_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases} Mij={0if ijif i<j

Softmax 归一化:
Aj=softmax(Sjmasked)∈Rn×n \mathbf{A}_j = \text{softmax}(\mathbf{S}_j^{masked}) \in \mathbb{R}^{n \times n} Aj=softmax(Sjmasked)Rn×n

注意力输出:
Oj(l)=AjVj(l)∈Rn×dh \mathbf{O}_j^{(l)} = \mathbf{A}_j \mathbf{V}_j^{(l)} \in \mathbb{R}^{n \times d_h} Oj(l)=AjVj(l)Rn×dh

拼接所有头:
O(l)=Concat[O1(l),…,Oh(l)]WO(l)⊤∈Rn×d \mathbf{O}^{(l)} = \text{Concat}[\mathbf{O}_1^{(l)}, \dots, \mathbf{O}_h^{(l)}] \mathbf{W}_O^{(l)\top} \in \mathbb{R}^{n \times d} O(l)=Concat[O1(l),,Oh(l)]WO(l)Rn×d

2.2.4 前馈网络与残差连接

H′(l)=LayerNorm(H(l−1)+O(l))F(l)=FFN(H′(l))=σ(H′(l)W1⊤)W2⊤H(l)=LayerNorm(H′(l)+F(l)) \begin{aligned} \mathbf{H}'^{(l)} &= \text{LayerNorm}(\mathbf{H}^{(l-1)} + \mathbf{O}^{(l)}) \\ \mathbf{F}^{(l)} &= \text{FFN}(\mathbf{H}'^{(l)}) = \sigma(\mathbf{H}'^{(l)} \mathbf{W}_1^\top) \mathbf{W}_2^\top \\ \mathbf{H}^{(l)} &= \text{LayerNorm}(\mathbf{H}'^{(l)} + \mathbf{F}^{(l)}) \end{aligned} H(l)F(l)H(l)=LayerNorm(H(l1)+O(l))=FFN(H(l))=σ(H(l)W1)W2=LayerNorm(H(l)+F(l))

重复上述过程 L=32L=32L=32 层,最终得到:
H(L)∈R6×d \mathbf{H}^{(L)} \in \mathbb{R}^{6 \times d} H(L)R6×d

2.3 生成第一个 Token

取最后一个位置的隐藏状态:
hlast=H6,:(L)∈Rd \mathbf{h}_{last} = \mathbf{H}^{(L)}_{6,:} \in \mathbb{R}^{d} hlast=H6,:(L)Rd

计算 logits:
z=hlastWlm⊤∈RV \mathbf{z} = \mathbf{h}_{last} \mathbf{W}_{lm}^\top \in \mathbb{R}^{V} z=hlastWlmRV

其中 Wlm∈RV×d\mathbf{W}_{lm} \in \mathbb{R}^{V \times d}WlmRV×d 是语言模型头。

应用 softmax 得到概率分布:
p(y∣x)=softmax(z)∈RV,∑v=1Vpv=1 p(y | \mathbf{x}) = \text{softmax}(\mathbf{z}) \in \mathbb{R}^{V}, \quad \sum_{v=1}^{V} p_v = 1 p(yx)=softmax(z)RV,v=1Vpv=1

采样(以贪心为例):
y1=arg⁡max⁡v pv y_1 = \arg\max_v \, p_v y1=argvmaxpv

假设生成:“I”(token ID 为 100)


三、Decode 阶段(解码阶段)

现在进入自回归生成。当前序列状态:

  • 已生成 token:[x1,…,x6,y1]=[hi,,,who,are,you,?,I][x_1, \dots, x_6, y_1] = [\text{hi}, \text{,}, \text{who}, \text{are}, \text{you}, \text{?}, \text{I}][x1,,x6,y1]=[hi,,,who,are,you,?,I]
  • 当前长度:n1=7n_1 = 7n1=7
  • 各层 KV Cache:块 #42 中存储了 6 个 token 的 KV,剩余空间 10

3.1 Decode Step 1:生成第二个 token

3.1.1 单 token 嵌入

输入新 token y1="I"y_1 = \text{"I"}y1="I"
ey1=Embedding(y1)∈Rd \mathbf{e}_{y_1} = \text{Embedding}(y_1) \in \mathbb{R}^{d} ey1=Embedding(y1)Rd

加上位置编码(位置 7):
hinput(0)=ey1+p7∈Rd \mathbf{h}_{input}^{(0)} = \mathbf{e}_{y_1} + \mathbf{p}_7 \in \mathbb{R}^{d} hinput(0)=ey1+p7Rd

注意:此时输入是单个向量,而非矩阵。

3.1.2 逐层计算(关键差异:复用 KV Cache)

lll 层计算:

步骤 A:生成新 token 的 Q、K、V
qnew(l)=hinput(l−1)WQ(l)⊤∈Rdknew(l)=hinput(l−1)WK(l)⊤∈Rdvnew(l)=hinput(l−1)WV(l)⊤∈Rd \begin{aligned} \mathbf{q}_{new}^{(l)} &= \mathbf{h}_{input}^{(l-1)} \mathbf{W}_Q^{(l)\top} \in \mathbb{R}^{d} \\ \mathbf{k}_{new}^{(l)} &= \mathbf{h}_{input}^{(l-1)} \mathbf{W}_K^{(l)\top} \in \mathbb{R}^{d} \\ \mathbf{v}_{new}^{(l)} &= \mathbf{h}_{input}^{(l-1)} \mathbf{W}_V^{(l)\top} \in \mathbb{R}^{d} \end{aligned} qnew(l)knew(l)vnew(l)=hinput(l1)WQ(l)Rd=hinput(l1)WK(l)Rd=hinput(l1)WV(l)Rd

reshape 为多头:
qnew,j(l)∈Rdh,knew,j(l)∈Rdh,vnew,j(l)∈Rdhfor j=1,…,h \mathbf{q}_{new,j}^{(l)} \in \mathbb{R}^{d_h}, \quad \mathbf{k}_{new,j}^{(l)} \in \mathbb{R}^{d_h}, \quad \mathbf{v}_{new,j}^{(l)} \in \mathbb{R}^{d_h} \quad \text{for } j=1,\dots,h qnew,j(l)Rdh,knew,j(l)Rdh,vnew,j(l)Rdhfor j=1,,h

步骤 B:追加 KV 到 Cache

检查块 #42 是否有空间:已用 6,容量 16,有空间。

追加写入:
Block42.K[6,:]=knew(l)Block42.V[6,:]=vnew(l) \begin{aligned} \text{Block}_{42}.\mathbf{K}[6,:] &= \mathbf{k}_{new}^{(l)} \\ \text{Block}_{42}.\mathbf{V}[6,:] &= \mathbf{v}_{new}^{(l)} \end{aligned} Block42.K[6,:]Block42.V[6,:]=knew(l)=vnew(l)

现在块 #42 包含 7 个 KV 向量。

步骤 C:PagedAttention 计算(核心优化)

需要从块 #42 读取所有历史 KV(7 个向量)来计算注意力。

对每个头 jjj,构造完整的 K 和 V:
Kfull,j=[k1,j(l)⊤⋮k7,j(l)⊤]∈R7×dh,Vfull,j=[v1,j(l)⊤⋮v7,j(l)⊤]∈R7×dh \mathbf{K}_{full,j} = \begin{bmatrix} \mathbf{k}_{1,j}^{(l)\top} \\ \vdots \\ \mathbf{k}_{7,j}^{(l)\top} \end{bmatrix} \in \mathbb{R}^{7 \times d_h}, \quad \mathbf{V}_{full,j} = \begin{bmatrix} \mathbf{v}_{1,j}^{(l)\top} \\ \vdots \\ \mathbf{v}_{7,j}^{(l)\top} \end{bmatrix} \in \mathbb{R}^{7 \times d_h} Kfull,j= k1,j(l)k7,j(l) R7×dh,Vfull,j= v1,j(l)v7,j(l) R7×dh

注意力分数(单查询对多 key):
sj=qnew,j(l)Kfull,j⊤dh∈R7 \mathbf{s}_j = \frac{\mathbf{q}_{new,j}^{(l)} \mathbf{K}_{full,j}^\top}{\sqrt{d_h}} \in \mathbb{R}^{7} sj=dh qnew,j(l)Kfull,jR7

Softmax:
aj=softmax(sj)∈R7,∑i=17aj,i=1 \mathbf{a}_j = \text{softmax}(\mathbf{s}_j) \in \mathbb{R}^{7}, \quad \sum_{i=1}^{7} a_{j,i} = 1 aj=softmax(sj)R7,i=17aj,i=1

注意力输出:
onew,j(l)=ajVfull,j=∑i=17aj,i⋅vi,j(l)∈Rdh \mathbf{o}_{new,j}^{(l)} = \mathbf{a}_j \mathbf{V}_{full,j} = \sum_{i=1}^{7} a_{j,i} \cdot \mathbf{v}_{i,j}^{(l)} \in \mathbb{R}^{d_h} onew,j(l)=ajVfull,j=i=17aj,ivi,j(l)Rdh

拼接所有头并投影:
onew(l)=Concat[onew,1(l),…,onew,h(l)]WO(l)⊤∈Rd \mathbf{o}_{new}^{(l)} = \text{Concat}[\mathbf{o}_{new,1}^{(l)}, \dots, \mathbf{o}_{new,h}^{(l)}] \mathbf{W}_O^{(l)\top} \in \mathbb{R}^{d} onew(l)=Concat[onew,1(l),,onew,h(l)]WO(l)Rd

步骤 D:残差与 FFN
hinput′(l)=LayerNorm(hinput(l−1)+onew(l))f(l)=FFN(hinput′(l))hinput(l)=LayerNorm(hinput′(l)+f(l)) \begin{aligned} \mathbf{h}'^{(l)}_{input} &= \text{LayerNorm}(\mathbf{h}_{input}^{(l-1)} + \mathbf{o}_{new}^{(l)}) \\ \mathbf{f}^{(l)} &= \text{FFN}(\mathbf{h}'^{(l)}_{input}) \\ \mathbf{h}_{input}^{(l)} &= \text{LayerNorm}(\mathbf{h}'^{(l)}_{input} + \mathbf{f}^{(l)}) \end{aligned} hinput(l)f(l)hinput(l)=LayerNorm(hinput(l1)+onew(l))=FFN(hinput(l))=LayerNorm(hinput(l)+f(l))

经过 LLL 层后得到:
houtput=hinput(L)∈Rd \mathbf{h}_{output} = \mathbf{h}_{input}^{(L)} \in \mathbb{R}^{d} houtput=hinput(L)Rd

3.1.3 生成第二个 token

z=houtputWlm⊤∈RV \mathbf{z} = \mathbf{h}_{output} \mathbf{W}_{lm}^\top \in \mathbb{R}^{V} z=houtputWlmRV

y2=arg⁡max⁡v softmax(z)v y_2 = \arg\max_v \, \text{softmax}(\mathbf{z})_v y2=argvmaxsoftmax(z)v

假设生成:“am”


3.2 Decode Step 2:生成第三个 token

当前序列长度 n2=8n_2 = 8n2=8

输入:y2="am"y_2 = \text{"am"}y2="am"

关键操作:

  1. 嵌入ey2+p8\mathbf{e}_{y_2} + \mathbf{p}_8ey2+p8

  2. KV Cache 追加:块 #42 已用 7,追加后变为 8

  3. PagedAttention:从块 #42 读取 8 个 KV 向量
    Kfull∈R8×d,Vfull∈R8×d \mathbf{K}_{full} \in \mathbb{R}^{8 \times d}, \quad \mathbf{V}_{full} \in \mathbb{R}^{8 \times d} KfullR8×d,VfullR8×d

  4. 注意力计算
    s=qnewKfull⊤dh∈R8 \mathbf{s} = \frac{\mathbf{q}_{new} \mathbf{K}_{full}^\top}{\sqrt{d_h}} \in \mathbb{R}^{8} s=dh qnewKfullR8
    注意:查询始终是单个向量,key 序列随长度增长。

生成 y3="an"y_3 = \text{"an"}y3="an"(假设)


3.3 Decode Step 3:生成第四个 token

当前长度 n3=9n_3 = 9n3=9

生成 y4="AI"y_4 = \text{"AI"}y4="AI"(假设)


3.4 Decode Step 4:生成第五个 token

当前长度 n4=10n_4 = 10n4=10

生成 y5="assistant"y_5 = \text{"assistant"}y5="assistant"(假设)


3.5 Decode Step 5:遇到 EOS 停止

生成 y6=""y_6 = \text{""}y6=""(EOS token)

停止生成。


四、完整序列的数学表示

最终生成的完整序列:
y=[y1,y2,y3,y4,y5,y6]=[I,am,an,AI,assistant,] \mathbf{y} = [y_1, y_2, y_3, y_4, y_5, y_6] = [\text{I}, \text{am}, \text{an}, \text{AI}, \text{assistant}, \text{}] y=[y1,y2,y3,y4,y5,y6]=[I,am,an,AI,assistant,]

完整对话:
Input: "hi, who are you"→Output: "I am an AI assistant" \text{Input: "hi, who are you"} \rightarrow \text{Output: "I am an AI assistant"} Input: "hi, who are you"Output: "I am an AI assistant"


五、关键公式对比总结

阶段 Prefill Decode Step ttt
输入维度 X∈Rn×d\mathbf{X} \in \mathbb{R}^{n \times d}XRn×d xnew∈Rd\mathbf{x}_{new} \in \mathbb{R}^{d}xnewRd(单向量)
Q 维度 Q∈Rn×d\mathbf{Q} \in \mathbb{R}^{n \times d}QRn×d qnew∈Rd\mathbf{q}_{new} \in \mathbb{R}^{d}qnewRd
K/V 计算 全序列计算 K,V∈Rn×d\mathbf{K}, \mathbf{V} \in \mathbb{R}^{n \times d}K,VRn×d 仅计算新 token:knew,vnew∈Rd\mathbf{k}_{new}, \mathbf{v}_{new} \in \mathbb{R}^{d}knew,vnewRd
KV Cache 操作 创建并写入块 读取历史 + 追加新 KV
注意力分数 S∈Rn×n\mathbf{S} \in \mathbb{R}^{n \times n}SRn×n(矩阵) s∈Rn+t\mathbf{s} \in \mathbb{R}^{n+t}sRn+t(向量)
计算复杂度 O(n2⋅d)O(n^2 \cdot d)O(n2d) O((n+t)⋅d)O((n+t) \cdot d)O((n+t)d) per step
内存访问 写入 nnn 个 KV 读取 (n+t)(n+t)(n+t) 个 KV + 写入 1 个 KV

六、PagedAttention 的内存布局可视化

物理内存块 #42(大小 B=16):
┌─────────────────────────────────────────────────────────────┐
│  位置:  0    1    2    3    4    5    6    7    8    9  ... 15 │
├─────────────────────────────────────────────────────────────┤
│ Prefill: [hi] [,] [who][are][you][ ? ]                      │
│          │    │    │    │    │    │                         │
│          K1   K2   K3   K4   K5   K6   ← 存储的 Key 向量      │
│          V1   V2   V3   V4   V5   V6   ← 存储的 Value 向量    │
├─────────────────────────────────────────────────────────────┤
│ Decode1:                      [ I ]                         │
│                                │                            │
│                               K7   ← 追加                     │
│                               V7   ← 追加                     │
├─────────────────────────────────────────────────────────────┤
│ Decode2:                           [am]                     │
│                                     │                       │
│                                    K8  ← 追加                 │
│                                    V8  ← 追加                 │
├─────────────────────────────────────────────────────────────┤
│ ... 继续直到块满或生成结束                                       │
└─────────────────────────────────────────────────────────────┘

块表 Block Table: [42]  (逻辑块 0 → 物理块 42)

当序列超过 16 个 token 时,系统会分配第二个物理块(如 #57),块表变为 [42,57][42, 57][42,57],注意力内核通过块表非连续读取 KV。


七、核心优化点总结

  1. Prefill 的并行性:利用矩阵乘法一次性计算 n×nn \times nn×n 注意力矩阵,充分发挥 GPU 算力。

  2. Decode 的 KV 复用:避免重复计算历史 token 的 K、V,将 O(n2)O(n^2)O(n2) 降为 O(n)O(n)O(n) 每步。

  3. 分页内存管理:通过块表实现非连续存储,支持动态长度、内存共享和高效批处理。

  4. 内存带宽优化:PagedAttention 内核通过分块加载和共享内存,减少全局内存访问次数,缓解 Decode 阶段的带宽瓶颈。

Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐