模型推理Prefill和Decode全流程公式
本文以输入"hi, who are you"为例,详细拆解了vLLM推理的Prefill和Decode过程。Prefill阶段对6个输入token进行嵌入、位置编码和32层Transformer计算,其中KV Cache采用分块存储(块大小16),仅需1个物理块。Decode阶段则自回归生成新token,每次仅处理单个token并复用预填充的KV Cache。整个过程严格遵循因
我来用 “hi, who are you” 作为完整示例,通过严格的数学公式和文字描述,详细拆解 vLLM 的 Prefill 和 Decode 全过程。
一、基础设定
1.1 模型参数
| 参数 | 符号 | 数值(示例) |
|---|---|---|
| 词汇表大小 | VVV | 32,000 |
| 隐藏层维度 | ddd | 4,096 |
| 注意力头数 | hhh | 32 |
| 每头维度 | dh=d/hd_h = d/hdh=d/h | 128 |
| 层数 | LLL | 32 |
| KV Cache 块大小 | BBB | 16 tokens |
| 序列最大长度 | TmaxT_{max}Tmax | 2,048 |
1.2 输入编码
输入字符串:“hi, who are you”
经过 Tokenizer 编码为 token ID 序列:
x=[x1,x2,x3,x4,x5,x6]=[hi,,,who,are,you,?] \mathbf{x} = [x_1, x_2, x_3, x_4, x_5, x_6] = [\text{hi}, \text{,}, \text{who}, \text{are}, \text{you}, \text{?}] x=[x1,x2,x3,x4,x5,x6]=[hi,,,who,are,you,?]
序列长度 n=6n = 6n=6。
二、Prefill 阶段(预填充阶段)
2.1 输入嵌入
将离散的 token ID 映射为连续向量:
E=Embedding(x)∈Rn×d \mathbf{E} = \text{Embedding}(\mathbf{x}) \in \mathbb{R}^{n \times d} E=Embedding(x)∈Rn×d
其中第 iii 个 token 的嵌入为 ei∈Rd\mathbf{e}_i \in \mathbb{R}^{d}ei∈Rd,因此:
E=[e1⊤e2⊤⋮e6⊤]=[—e1——e2—⋮—e6—]6×4096 \mathbf{E} = \begin{bmatrix} \mathbf{e}_1^\top \\ \mathbf{e}_2^\top \\ \vdots \\ \mathbf{e}_6^\top \end{bmatrix} = \begin{bmatrix} — & \mathbf{e}_1 & — \\ — & \mathbf{e}_2 & — \\ & \vdots & \\ — & \mathbf{e}_6 & — \end{bmatrix}_{6 \times 4096} E=
e1⊤e2⊤⋮e6⊤
=
———e1e2⋮e6———
6×4096
加上位置编码 P∈Rn×d\mathbf{P} \in \mathbb{R}^{n \times d}P∈Rn×d:
H(0)=E+P \mathbf{H}^{(0)} = \mathbf{E} + \mathbf{P} H(0)=E+P
2.2 逐层 Transformer 计算(以第 lll 层为例)
2.2.1 线性投影生成 Q、K、V
Q(l)=H(l−1)WQ(l)⊤∈Rn×dK(l)=H(l−1)WK(l)⊤∈Rn×dV(l)=H(l−1)WV(l)⊤∈Rn×d \begin{aligned} \mathbf{Q}^{(l)} &= \mathbf{H}^{(l-1)} \mathbf{W}_Q^{(l)\top} \in \mathbb{R}^{n \times d} \\ \mathbf{K}^{(l)} &= \mathbf{H}^{(l-1)} \mathbf{W}_K^{(l)\top} \in \mathbb{R}^{n \times d} \\ \mathbf{V}^{(l)} &= \mathbf{H}^{(l-1)} \mathbf{W}_V^{(l)\top} \in \mathbb{R}^{n \times d} \end{aligned} Q(l)K(l)V(l)=H(l−1)WQ(l)⊤∈Rn×d=H(l−1)WK(l)⊤∈Rn×d=H(l−1)WV(l)⊤∈Rn×d
其中权重矩阵 WQ(l),WK(l),WV(l)∈Rd×d\mathbf{W}_Q^{(l)}, \mathbf{W}_K^{(l)}, \mathbf{W}_V^{(l)} \in \mathbb{R}^{d \times d}WQ(l),WK(l),WV(l)∈Rd×d。
reshape 为多头形式(h=32h=32h=32 个头):
Q(l)=[Q1(l)Q2(l)⋯Qh(l)],Qj(l)∈Rn×dh \mathbf{Q}^{(l)} = \begin{bmatrix} \mathbf{Q}_1^{(l)} & \mathbf{Q}_2^{(l)} & \cdots & \mathbf{Q}_h^{(l)} \end{bmatrix}, \quad \mathbf{Q}_j^{(l)} \in \mathbb{R}^{n \times d_h} Q(l)=[Q1(l)Q2(l)⋯Qh(l)],Qj(l)∈Rn×dh
2.2.2 KV Cache 存储(PagedAttention 核心)
vLLM 将 K(l)\mathbf{K}^{(l)}K(l) 和 V(l)\mathbf{V}^{(l)}V(l) 分块存储到非连续的物理内存块中。
块分配计算:
所需块数=⌈nB⌉=⌈616⌉=1 个块 \text{所需块数} = \left\lceil \frac{n}{B} \right\rceil = \left\lceil \frac{6}{16} \right\rceil = 1 \text{ 个块} 所需块数=⌈Bn⌉=⌈166⌉=1 个块
分配物理块 Blockphys(l)\text{Block}_{phys}^{(l)}Blockphys(l)(假设为系统空闲块 #42),建立块表(Block Table):
T(l)=[42] \mathcal{T}^{(l)} = [42] T(l)=[42]
将 KV 张量写入块 #42:
Block42.K=K(l)∈R6×d(实际存储为 6×4096)Block42.V=V(l)∈R6×d \begin{aligned} \text{Block}_{42}.\mathbf{K} &= \mathbf{K}^{(l)} \in \mathbb{R}^{6 \times d} \quad (\text{实际存储为 } 6 \times 4096) \\ \text{Block}_{42}.\mathbf{V} &= \mathbf{V}^{(l)} \in \mathbb{R}^{6 \times d} \end{aligned} Block42.KBlock42.V=K(l)∈R6×d(实际存储为 6×4096)=V(l)∈R6×d
注意:块容量为 16,当前只用了前 6 个位置,剩余 10 个位置预留给后续 decode 阶段追加。
2.2.3 注意力计算(因果掩码)
对每个头 j∈{1,…,h}j \in \{1, \dots, h\}j∈{1,…,h}:
Sj=Qj(l)Kj(l)⊤dh∈Rn×n \mathbf{S}_j = \frac{\mathbf{Q}_j^{(l)} \mathbf{K}_j^{(l)\top}}{\sqrt{d_h}} \in \mathbb{R}^{n \times n} Sj=dhQj(l)Kj(l)⊤∈Rn×n
应用因果掩码 M∈{0,−∞}n×n\mathbf{M} \in \{0, -\infty\}^{n \times n}M∈{0,−∞}n×n(下三角为 0,上三角为 −∞-\infty−∞):
Sjmasked=Sj+M \mathbf{S}_j^{masked} = \mathbf{S}_j + \mathbf{M} Sjmasked=Sj+M
其中掩码矩阵:
Mij={0if i≥j−∞if i<j \mathbf{M}_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases} Mij={0−∞if i≥jif i<j
Softmax 归一化:
Aj=softmax(Sjmasked)∈Rn×n \mathbf{A}_j = \text{softmax}(\mathbf{S}_j^{masked}) \in \mathbb{R}^{n \times n} Aj=softmax(Sjmasked)∈Rn×n
注意力输出:
Oj(l)=AjVj(l)∈Rn×dh \mathbf{O}_j^{(l)} = \mathbf{A}_j \mathbf{V}_j^{(l)} \in \mathbb{R}^{n \times d_h} Oj(l)=AjVj(l)∈Rn×dh
拼接所有头:
O(l)=Concat[O1(l),…,Oh(l)]WO(l)⊤∈Rn×d \mathbf{O}^{(l)} = \text{Concat}[\mathbf{O}_1^{(l)}, \dots, \mathbf{O}_h^{(l)}] \mathbf{W}_O^{(l)\top} \in \mathbb{R}^{n \times d} O(l)=Concat[O1(l),…,Oh(l)]WO(l)⊤∈Rn×d
2.2.4 前馈网络与残差连接
H′(l)=LayerNorm(H(l−1)+O(l))F(l)=FFN(H′(l))=σ(H′(l)W1⊤)W2⊤H(l)=LayerNorm(H′(l)+F(l)) \begin{aligned} \mathbf{H}'^{(l)} &= \text{LayerNorm}(\mathbf{H}^{(l-1)} + \mathbf{O}^{(l)}) \\ \mathbf{F}^{(l)} &= \text{FFN}(\mathbf{H}'^{(l)}) = \sigma(\mathbf{H}'^{(l)} \mathbf{W}_1^\top) \mathbf{W}_2^\top \\ \mathbf{H}^{(l)} &= \text{LayerNorm}(\mathbf{H}'^{(l)} + \mathbf{F}^{(l)}) \end{aligned} H′(l)F(l)H(l)=LayerNorm(H(l−1)+O(l))=FFN(H′(l))=σ(H′(l)W1⊤)W2⊤=LayerNorm(H′(l)+F(l))
重复上述过程 L=32L=32L=32 层,最终得到:
H(L)∈R6×d \mathbf{H}^{(L)} \in \mathbb{R}^{6 \times d} H(L)∈R6×d
2.3 生成第一个 Token
取最后一个位置的隐藏状态:
hlast=H6,:(L)∈Rd \mathbf{h}_{last} = \mathbf{H}^{(L)}_{6,:} \in \mathbb{R}^{d} hlast=H6,:(L)∈Rd
计算 logits:
z=hlastWlm⊤∈RV \mathbf{z} = \mathbf{h}_{last} \mathbf{W}_{lm}^\top \in \mathbb{R}^{V} z=hlastWlm⊤∈RV
其中 Wlm∈RV×d\mathbf{W}_{lm} \in \mathbb{R}^{V \times d}Wlm∈RV×d 是语言模型头。
应用 softmax 得到概率分布:
p(y∣x)=softmax(z)∈RV,∑v=1Vpv=1 p(y | \mathbf{x}) = \text{softmax}(\mathbf{z}) \in \mathbb{R}^{V}, \quad \sum_{v=1}^{V} p_v = 1 p(y∣x)=softmax(z)∈RV,v=1∑Vpv=1
采样(以贪心为例):
y1=argmaxv pv y_1 = \arg\max_v \, p_v y1=argvmaxpv
假设生成:“I”(token ID 为 100)
三、Decode 阶段(解码阶段)
现在进入自回归生成。当前序列状态:
- 已生成 token:[x1,…,x6,y1]=[hi,,,who,are,you,?,I][x_1, \dots, x_6, y_1] = [\text{hi}, \text{,}, \text{who}, \text{are}, \text{you}, \text{?}, \text{I}][x1,…,x6,y1]=[hi,,,who,are,you,?,I]
- 当前长度:n1=7n_1 = 7n1=7
- 各层 KV Cache:块 #42 中存储了 6 个 token 的 KV,剩余空间 10
3.1 Decode Step 1:生成第二个 token
3.1.1 单 token 嵌入
输入新 token y1="I"y_1 = \text{"I"}y1="I":
ey1=Embedding(y1)∈Rd \mathbf{e}_{y_1} = \text{Embedding}(y_1) \in \mathbb{R}^{d} ey1=Embedding(y1)∈Rd
加上位置编码(位置 7):
hinput(0)=ey1+p7∈Rd \mathbf{h}_{input}^{(0)} = \mathbf{e}_{y_1} + \mathbf{p}_7 \in \mathbb{R}^{d} hinput(0)=ey1+p7∈Rd
注意:此时输入是单个向量,而非矩阵。
3.1.2 逐层计算(关键差异:复用 KV Cache)
第 lll 层计算:
步骤 A:生成新 token 的 Q、K、V
qnew(l)=hinput(l−1)WQ(l)⊤∈Rdknew(l)=hinput(l−1)WK(l)⊤∈Rdvnew(l)=hinput(l−1)WV(l)⊤∈Rd \begin{aligned} \mathbf{q}_{new}^{(l)} &= \mathbf{h}_{input}^{(l-1)} \mathbf{W}_Q^{(l)\top} \in \mathbb{R}^{d} \\ \mathbf{k}_{new}^{(l)} &= \mathbf{h}_{input}^{(l-1)} \mathbf{W}_K^{(l)\top} \in \mathbb{R}^{d} \\ \mathbf{v}_{new}^{(l)} &= \mathbf{h}_{input}^{(l-1)} \mathbf{W}_V^{(l)\top} \in \mathbb{R}^{d} \end{aligned} qnew(l)knew(l)vnew(l)=hinput(l−1)WQ(l)⊤∈Rd=hinput(l−1)WK(l)⊤∈Rd=hinput(l−1)WV(l)⊤∈Rd
reshape 为多头:
qnew,j(l)∈Rdh,knew,j(l)∈Rdh,vnew,j(l)∈Rdhfor j=1,…,h \mathbf{q}_{new,j}^{(l)} \in \mathbb{R}^{d_h}, \quad \mathbf{k}_{new,j}^{(l)} \in \mathbb{R}^{d_h}, \quad \mathbf{v}_{new,j}^{(l)} \in \mathbb{R}^{d_h} \quad \text{for } j=1,\dots,h qnew,j(l)∈Rdh,knew,j(l)∈Rdh,vnew,j(l)∈Rdhfor j=1,…,h
步骤 B:追加 KV 到 Cache
检查块 #42 是否有空间:已用 6,容量 16,有空间。
追加写入:
Block42.K[6,:]=knew(l)Block42.V[6,:]=vnew(l) \begin{aligned} \text{Block}_{42}.\mathbf{K}[6,:] &= \mathbf{k}_{new}^{(l)} \\ \text{Block}_{42}.\mathbf{V}[6,:] &= \mathbf{v}_{new}^{(l)} \end{aligned} Block42.K[6,:]Block42.V[6,:]=knew(l)=vnew(l)
现在块 #42 包含 7 个 KV 向量。
步骤 C:PagedAttention 计算(核心优化)
需要从块 #42 读取所有历史 KV(7 个向量)来计算注意力。
对每个头 jjj,构造完整的 K 和 V:
Kfull,j=[k1,j(l)⊤⋮k7,j(l)⊤]∈R7×dh,Vfull,j=[v1,j(l)⊤⋮v7,j(l)⊤]∈R7×dh \mathbf{K}_{full,j} = \begin{bmatrix} \mathbf{k}_{1,j}^{(l)\top} \\ \vdots \\ \mathbf{k}_{7,j}^{(l)\top} \end{bmatrix} \in \mathbb{R}^{7 \times d_h}, \quad \mathbf{V}_{full,j} = \begin{bmatrix} \mathbf{v}_{1,j}^{(l)\top} \\ \vdots \\ \mathbf{v}_{7,j}^{(l)\top} \end{bmatrix} \in \mathbb{R}^{7 \times d_h} Kfull,j=
k1,j(l)⊤⋮k7,j(l)⊤
∈R7×dh,Vfull,j=
v1,j(l)⊤⋮v7,j(l)⊤
∈R7×dh
注意力分数(单查询对多 key):
sj=qnew,j(l)Kfull,j⊤dh∈R7 \mathbf{s}_j = \frac{\mathbf{q}_{new,j}^{(l)} \mathbf{K}_{full,j}^\top}{\sqrt{d_h}} \in \mathbb{R}^{7} sj=dhqnew,j(l)Kfull,j⊤∈R7
Softmax:
aj=softmax(sj)∈R7,∑i=17aj,i=1 \mathbf{a}_j = \text{softmax}(\mathbf{s}_j) \in \mathbb{R}^{7}, \quad \sum_{i=1}^{7} a_{j,i} = 1 aj=softmax(sj)∈R7,i=1∑7aj,i=1
注意力输出:
onew,j(l)=ajVfull,j=∑i=17aj,i⋅vi,j(l)∈Rdh \mathbf{o}_{new,j}^{(l)} = \mathbf{a}_j \mathbf{V}_{full,j} = \sum_{i=1}^{7} a_{j,i} \cdot \mathbf{v}_{i,j}^{(l)} \in \mathbb{R}^{d_h} onew,j(l)=ajVfull,j=i=1∑7aj,i⋅vi,j(l)∈Rdh
拼接所有头并投影:
onew(l)=Concat[onew,1(l),…,onew,h(l)]WO(l)⊤∈Rd \mathbf{o}_{new}^{(l)} = \text{Concat}[\mathbf{o}_{new,1}^{(l)}, \dots, \mathbf{o}_{new,h}^{(l)}] \mathbf{W}_O^{(l)\top} \in \mathbb{R}^{d} onew(l)=Concat[onew,1(l),…,onew,h(l)]WO(l)⊤∈Rd
步骤 D:残差与 FFN
hinput′(l)=LayerNorm(hinput(l−1)+onew(l))f(l)=FFN(hinput′(l))hinput(l)=LayerNorm(hinput′(l)+f(l)) \begin{aligned} \mathbf{h}'^{(l)}_{input} &= \text{LayerNorm}(\mathbf{h}_{input}^{(l-1)} + \mathbf{o}_{new}^{(l)}) \\ \mathbf{f}^{(l)} &= \text{FFN}(\mathbf{h}'^{(l)}_{input}) \\ \mathbf{h}_{input}^{(l)} &= \text{LayerNorm}(\mathbf{h}'^{(l)}_{input} + \mathbf{f}^{(l)}) \end{aligned} hinput′(l)f(l)hinput(l)=LayerNorm(hinput(l−1)+onew(l))=FFN(hinput′(l))=LayerNorm(hinput′(l)+f(l))
经过 LLL 层后得到:
houtput=hinput(L)∈Rd \mathbf{h}_{output} = \mathbf{h}_{input}^{(L)} \in \mathbb{R}^{d} houtput=hinput(L)∈Rd
3.1.3 生成第二个 token
z=houtputWlm⊤∈RV \mathbf{z} = \mathbf{h}_{output} \mathbf{W}_{lm}^\top \in \mathbb{R}^{V} z=houtputWlm⊤∈RV
y2=argmaxv softmax(z)v y_2 = \arg\max_v \, \text{softmax}(\mathbf{z})_v y2=argvmaxsoftmax(z)v
假设生成:“am”
3.2 Decode Step 2:生成第三个 token
当前序列长度 n2=8n_2 = 8n2=8。
输入:y2="am"y_2 = \text{"am"}y2="am"
关键操作:
-
嵌入:ey2+p8\mathbf{e}_{y_2} + \mathbf{p}_8ey2+p8
-
KV Cache 追加:块 #42 已用 7,追加后变为 8
-
PagedAttention:从块 #42 读取 8 个 KV 向量
Kfull∈R8×d,Vfull∈R8×d \mathbf{K}_{full} \in \mathbb{R}^{8 \times d}, \quad \mathbf{V}_{full} \in \mathbb{R}^{8 \times d} Kfull∈R8×d,Vfull∈R8×d -
注意力计算:
s=qnewKfull⊤dh∈R8 \mathbf{s} = \frac{\mathbf{q}_{new} \mathbf{K}_{full}^\top}{\sqrt{d_h}} \in \mathbb{R}^{8} s=dhqnewKfull⊤∈R8
注意:查询始终是单个向量,key 序列随长度增长。
生成 y3="an"y_3 = \text{"an"}y3="an"(假设)
3.3 Decode Step 3:生成第四个 token
当前长度 n3=9n_3 = 9n3=9。
生成 y4="AI"y_4 = \text{"AI"}y4="AI"(假设)
3.4 Decode Step 4:生成第五个 token
当前长度 n4=10n_4 = 10n4=10。
生成 y5="assistant"y_5 = \text{"assistant"}y5="assistant"(假设)
3.5 Decode Step 5:遇到 EOS 停止
生成 y6=""y_6 = \text{""}y6=""(EOS token)
停止生成。
四、完整序列的数学表示
最终生成的完整序列:
y=[y1,y2,y3,y4,y5,y6]=[I,am,an,AI,assistant,] \mathbf{y} = [y_1, y_2, y_3, y_4, y_5, y_6] = [\text{I}, \text{am}, \text{an}, \text{AI}, \text{assistant}, \text{}] y=[y1,y2,y3,y4,y5,y6]=[I,am,an,AI,assistant,]
完整对话:
Input: "hi, who are you"→Output: "I am an AI assistant" \text{Input: "hi, who are you"} \rightarrow \text{Output: "I am an AI assistant"} Input: "hi, who are you"→Output: "I am an AI assistant"
五、关键公式对比总结
| 阶段 | Prefill | Decode Step ttt |
|---|---|---|
| 输入维度 | X∈Rn×d\mathbf{X} \in \mathbb{R}^{n \times d}X∈Rn×d | xnew∈Rd\mathbf{x}_{new} \in \mathbb{R}^{d}xnew∈Rd(单向量) |
| Q 维度 | Q∈Rn×d\mathbf{Q} \in \mathbb{R}^{n \times d}Q∈Rn×d | qnew∈Rd\mathbf{q}_{new} \in \mathbb{R}^{d}qnew∈Rd |
| K/V 计算 | 全序列计算 K,V∈Rn×d\mathbf{K}, \mathbf{V} \in \mathbb{R}^{n \times d}K,V∈Rn×d | 仅计算新 token:knew,vnew∈Rd\mathbf{k}_{new}, \mathbf{v}_{new} \in \mathbb{R}^{d}knew,vnew∈Rd |
| KV Cache 操作 | 创建并写入块 | 读取历史 + 追加新 KV |
| 注意力分数 | S∈Rn×n\mathbf{S} \in \mathbb{R}^{n \times n}S∈Rn×n(矩阵) | s∈Rn+t\mathbf{s} \in \mathbb{R}^{n+t}s∈Rn+t(向量) |
| 计算复杂度 | O(n2⋅d)O(n^2 \cdot d)O(n2⋅d) | O((n+t)⋅d)O((n+t) \cdot d)O((n+t)⋅d) per step |
| 内存访问 | 写入 nnn 个 KV | 读取 (n+t)(n+t)(n+t) 个 KV + 写入 1 个 KV |
六、PagedAttention 的内存布局可视化
物理内存块 #42(大小 B=16):
┌─────────────────────────────────────────────────────────────┐
│ 位置: 0 1 2 3 4 5 6 7 8 9 ... 15 │
├─────────────────────────────────────────────────────────────┤
│ Prefill: [hi] [,] [who][are][you][ ? ] │
│ │ │ │ │ │ │ │
│ K1 K2 K3 K4 K5 K6 ← 存储的 Key 向量 │
│ V1 V2 V3 V4 V5 V6 ← 存储的 Value 向量 │
├─────────────────────────────────────────────────────────────┤
│ Decode1: [ I ] │
│ │ │
│ K7 ← 追加 │
│ V7 ← 追加 │
├─────────────────────────────────────────────────────────────┤
│ Decode2: [am] │
│ │ │
│ K8 ← 追加 │
│ V8 ← 追加 │
├─────────────────────────────────────────────────────────────┤
│ ... 继续直到块满或生成结束 │
└─────────────────────────────────────────────────────────────┘
块表 Block Table: [42] (逻辑块 0 → 物理块 42)
当序列超过 16 个 token 时,系统会分配第二个物理块(如 #57),块表变为 [42,57][42, 57][42,57],注意力内核通过块表非连续读取 KV。
七、核心优化点总结
-
Prefill 的并行性:利用矩阵乘法一次性计算 n×nn \times nn×n 注意力矩阵,充分发挥 GPU 算力。
-
Decode 的 KV 复用:避免重复计算历史 token 的 K、V,将 O(n2)O(n^2)O(n2) 降为 O(n)O(n)O(n) 每步。
-
分页内存管理:通过块表实现非连续存储,支持动态长度、内存共享和高效批处理。
-
内存带宽优化:PagedAttention 内核通过分块加载和共享内存,减少全局内存访问次数,缓解 Decode 阶段的带宽瓶颈。
昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链
更多推荐



所有评论(0)