Transformer Decoder 全流程

训练(并行)与推理(自回归 + KV Cache)的数据流与张量维度

🎓 一、训练时(并行输入整段序列)

记号约定

  • $B$ — batch size
  • $L$ — 序列长度
  • $d_{\text{model}}$ — 隐藏维度
  • $h$ — head 数
  • $d_h = d_{\text{model}} / h$ — 每头维度
  • $V$ — 词表大小
  • $d_{\text{ff}}$ — MLP 中间维度(通常 $4d_{\text{model}}$)
  • $T$ — 推理时已生成的 token 数
1

输入与嵌入

输入 token ids:

$\text{input\_ids} \in \mathbb{Z}^{B \times L}$

经过嵌入矩阵 $W_e \in \mathbb{R}^{V \times d_{\text{model}}}$(查表操作):

$X_0 = W_e[\text{input\_ids}] \in \mathbb{R}^{B \times L \times d_{\text{model}}}$
维度:$(B, L)$ 整数索引 $\xrightarrow{\text{lookup } W_e}$ $(B, L, d_{\text{model}})$ 浮点张量

老式绝对位置编码:$X_0 = W_e[\text{input\_ids}] + PE$;现代 LLM 常用 RoPE,此处通常不直接加 PE

2

一个 Decoder Block(以第 $i$ 层为例)

设本层输入 $X \in \mathbb{R}^{B \times L \times d_{\text{model}}}$

(1) Q/K/V 线性投影

通过三组可学习的参数矩阵做线性变换:

$Q = X\, W_Q, \quad K = X\, W_K, \quad V = X\, W_V$

$W_Q, W_K, W_V \in \mathbb{R}^{d_{\text{model}} \times (h \cdot d_h)}$ 为本层可学习参数(因 $h \cdot d_h = d_{\text{model}}$,实际就是 $d_{\text{model}} \times d_{\text{model}}$ 的方阵)

矩阵乘法维度:$\underbrace{(B, L, d_{\text{model}})}_{X} \times \underbrace{(d_{\text{model}}, h \cdot d_h)}_{W_Q} = \underbrace{(B, L, h \cdot d_h)}_{Q}$

reshape 为多头形式,再 transpose 把 head 维提前:

$(B, L, h \cdot d_h) \xrightarrow{\text{reshape}} (B, L, h, d_h) \xrightarrow{\text{transpose}} (B, h, L, d_h)$
$Q, K, V \in \mathbb{R}^{B \times h \times L \times d_h}$

(2) 位置编码(RoPE)

对 $Q, K$ 的每个位置 $\text{pos} = 0, 1, \ldots, L{-}1$ 施加旋转矩阵(逐元素,不改变维度):

$Q_r = \text{RoPE}(Q), \quad K_r = \text{RoPE}(K) \quad \in \mathbb{R}^{B \times h \times L \times d_h}$

RoPE 核心公式——将每两个维度视为 2D 向量,按位置做旋转:

$\begin{pmatrix} q_{2i}^{(t)} \\ q_{2i+1}^{(t)} \end{pmatrix} = \begin{pmatrix} \cos(t\theta_i) & -\sin(t\theta_i) \\ \sin(t\theta_i) & \cos(t\theta_i) \end{pmatrix} \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}, \quad \theta_i = 10000^{-2i/d_h}$

关键性质:$Q_t^\top K_{t+\Delta} = (R_\theta^t q)^\top (R_\theta^{t+\Delta} k) = q^\top R_\theta^\Delta k$,内积只依赖相对位置差 $\Delta$,与绝对位置无关。$V$ 不做位置编码。

(3) 注意力分数

$Q_r$ 与 $K_r^\top$ 做矩阵乘法(在最后两维上):

$S = \dfrac{Q_r\, K_r^\top}{\sqrt{d_h}}$
$\underbrace{(B, h, L, d_h)}_{Q_r} \times \underbrace{(B, h, d_h, L)}_{K_r^\top} = \underbrace{(B, h, L, L)}_{S}$

$S$ 的含义:每个 head 上,$L$ 个 query 位置对 $L$ 个 key 位置的原始注意力分数。

加 causal mask(上三角置 $-\infty$)后 softmax:

$A = \text{softmax}(\text{mask}(S)) \in \mathbb{R}^{B \times h \times L \times L}$

Causal mask 形状(以 $L=5$ 为例):

    k₀  k₁  k₂  k₃  k₄
q₀  ✓  ·  ·  ·  ·
q₁  ✓  ✓  ·  ·  ·
q₂  ✓  ✓  ✓  ·  ·
q₃  ✓  ✓  ✓  ✓  ·
q₄  ✓  ✓  ✓  ✓  ✓

下三角(含对角线)保留原 score,上三角置 $-\infty$,softmax 后变 0——确保位置 $t$ 只能看到位置 $\le t$ 的 token。

(4) 加权求和

$O_{\text{head}} = A\, V$
$\underbrace{(B, h, L, L)}_{A} \times \underbrace{(B, h, L, d_h)}_{V} = \underbrace{(B, h, L, d_h)}_{O_{\text{head}}}$

每个 query 位置得到 $d_h$ 维的加权上下文向量。

(5) 多头拼接 + 输出投影

将 $h$ 个 head 沿最后一维拼接:

$(B, h, L, d_h) \xrightarrow{\text{transpose}} (B, L, h, d_h) \xrightarrow{\text{reshape}} (B, L,\, h \cdot d_h)$

经过输出投影矩阵 $W_O \in \mathbb{R}^{(h \cdot d_h) \times d_{\text{model}}}$:

$O_{\text{attn}} = O_{\text{concat}}\, W_O$
$\underbrace{(B, L, h \cdot d_h)}_{O_{\text{concat}}} \times \underbrace{(h \cdot d_h, d_{\text{model}})}_{W_O} = \underbrace{(B, L, d_{\text{model}})}_{O_{\text{attn}}}$

(6) Pre-Norm 残差结构 + SwiGLU MLP

RMSNorm(现代 LLM 取代 LayerNorm)—— 只保留缩放,去掉减均值:

$\text{RMSNorm}(x) = \dfrac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \varepsilon}} \odot \gamma$

$\gamma \in \mathbb{R}^d$ 为可学习缩放参数;$\varepsilon = 10^{-6}$ 防止除零。比 LayerNorm 少一次均值计算,更快且效果相当。

Pre-Norm 残差(现代 LLM 标准结构,与原始 Transformer 的 Post-Norm 不同):

$X_{\text{mid}} = X + \text{Attention}\bigl(\text{RMSNorm}(X)\bigr)$

注:上文 Step 2(1)-(5) 中的 $X$ 实际是 $\text{RMSNorm}(X_{\text{原始}})$,为简化记号没有显式写出。Pre-Norm 让深层网络梯度更稳定,无需 warm-up。

SwiGLU(LLaMA 等现代模型的激活方案)—— 用门控机制替代单一激活:

$\text{SwiGLU}(x) = \bigl(\text{SiLU}(x W_g)\bigr) \odot (x W_u) \cdot W_d$

$W_g, W_u \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ 为门控/升维矩阵,$W_d \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ 为降维矩阵;$\text{SiLU}(z) = z \cdot \sigma(z)$ 为 Sigmoid-Linear Unit

$\underbrace{(B, L, d_{\text{model}})}_{x} \xrightarrow{W_g,\,W_u} \underbrace{(B, L, d_{\text{ff}})}_{\text{gate},\,\text{up}} \xrightarrow{\text{SiLU} \odot} (B, L, d_{\text{ff}}) \xrightarrow{W_d} \underbrace{(B, L, d_{\text{model}})}_{\text{output}}$

对比传统 MLP $\sigma(x W_1) W_2$:SwiGLU 用两个投影 $W_g, W_u$ 替代一个,参数量略多但表达能力更强。

第二个 Pre-Norm 残差(套在 SwiGLU 上):

$X' = X_{\text{mid}} + \text{SwiGLU}\bigl(\text{RMSNorm}(X_{\text{mid}})\bigr) \in \mathbb{R}^{B \times L \times d_{\text{model}}}$
Pre-Norm vs Post-Norm:原始 Transformer 论文是 $X' = \text{Norm}(X + \text{Sublayer}(X))$(Norm 在残差之后);现代 LLM 全部用 $X' = X + \text{Sublayer}(\text{Norm}(X))$(Norm 在子层之前)。后者训练更稳。
3

重复 $N$ 个 Decoder Block

$X_0 \to X_1 \to X_2 \to \cdots \to X_N$

每层的输入和输出 shape 恒为 $(B, L, d_{\text{model}})$,不变。

4

最终输出 logits

最后一层输出 $H \in \mathbb{R}^{B \times L \times d_{\text{model}}}$,经过 LM Head:

$Z = H\, W_{\text{lm}}$
$\underbrace{(B, L, d_{\text{model}})}_{H} \times \underbrace{(d_{\text{model}}, V)}_{W_{\text{lm}}} = \underbrace{(B, L, V)}_{Z}$

很多模型中 $W_{\text{lm}} = W_e^\top$(weight tying),无额外参数。

5

Loss

labels $\in \mathbb{Z}^{B \times L}$,对每个位置做词表大小 $V$ 的分类:

$\mathcal{L} = \text{CrossEntropy}(Z_{[:, :-1, :]},\ \text{labels}_{[:, 1:]})$

实际上是用位置 $t$ 的预测 $Z_t \in \mathbb{R}^V$ 去预测位置 $t{+}1$ 的 token。

$\text{input\_ids}\ (B,L) \;\xrightarrow{W_e}\; X_0\ (B,L,d) \;\xrightarrow{N \times \text{Block}}\; H\ (B,L,d) \;\xrightarrow{W_{\text{lm}}}\; Z\ (B,L,V) \;\to\; \text{CE loss}$

🚀 二、推理时(自回归 + KV Cache)

推理专用记号

  • $L_p$ — prompt 长度(用户输入的 token 数)
  • $t$ — 当前解码步的位置索引(从 $L_p$ 开始)
  • $T$ — cache 中已有的 token 数(位置 $0 \ldots T{-}1$ 的 K/V 已缓存)
  • 下标 $_t$ — "当前这一个 token" 的量(序列维 = 1)
  • 下标 $_{\text{all}}$ — "历史 cache + 当前 token" 拼接后的量(序列维 = $T{+}1$)
  • 下标 $_{\text{cache}}$ — "仅历史缓存" 的量(序列维 = $T$)
推理分两个阶段① Prefill(处理 prompt,一次性并行,下方 Step 0) → ② Decode(逐 token 自回归,下方 Step 1-4,每步一个 token)。 两阶段用同一份模型参数,差异只在输入长度和是否带 cache。
0

Prefill 阶段:预填充 prompt 到 KV Cache

用户给定 prompt:

$\text{prompt\_ids} \in \mathbb{Z}^{B \times L_p}$

这一步不是逐 token,而是把整段 prompt 一次性送进模型——shape 完全等同训练:

$(B, L_p) \xrightarrow{W_e} (B, L_p, d_{\text{model}}) \xrightarrow{N \times \text{Block (causal mask)}} (B, L_p, d_{\text{model}}) \xrightarrow{W_{\text{lm}}} (B, L_p, V)$

Prefill 阶段每层算出的 K/V 全部存入 cache

$K_{\text{cache}}^{(i)},\, V_{\text{cache}}^{(i)} \in \mathbb{R}^{B \times h \times L_p \times d_h}, \quad i = 1, \ldots, N$

从 $(B, L_p, V)$ 的 logits 中只取最后一个位置:

$Z_{[:, -1, :]} \in \mathbb{R}^{B \times V} \;\xrightarrow{\text{sample}}\; \text{token}_{L_p}$

这就是模型生成的第一个新 token。Prefill 完成后 $T = L_p$,进入 Decode 阶段。

Prefill 是计算密集型(一次 $L_p \times L_p$ 注意力),Decode 是显存带宽密集型(每步只算 $1 \times (T{+}1)$,但要从 HBM 读全部 KV cache)——这是工业部署中两阶段独立优化的根本原因。
1

Decode 阶段:当前输入

当前步只输入最新一个 token(不是整段序列):

$\text{input\_id}_t \in \mathbb{Z}^{B \times 1}$

经过嵌入矩阵 $W_e$(同训练时的同一个矩阵):

$X_t^{(0)} = W_e[\text{input\_id}_t] \in \mathbb{R}^{B \times 1 \times d_{\text{model}}}$
注意序列维 = $\mathbf{1}$(只有一个 token),而训练时序列维 = $L$
2

一个 Decoder Block(以第 $i$ 层为例)

设本层输入 $X_t \in \mathbb{R}^{B \times \mathbf{1} \times d_{\text{model}}}$(只有当前 1 个 token)

(1) 当前 token 的 Q/K/V 投影

使用与训练时完全相同的参数矩阵 $W_Q, W_K, W_V$:

$Q_t = X_t\, W_Q, \quad K_t = X_t\, W_K, \quad V_t = X_t\, W_V$
$\underbrace{(B, \mathbf{1}, d_{\text{model}})}_{X_t} \times \underbrace{(d_{\text{model}}, h \cdot d_h)}_{W_Q} = \underbrace{(B, \mathbf{1}, h \cdot d_h)}_{Q_t}$

reshape + transpose 为多头形式:

$(B, \mathbf{1}, h \cdot d_h) \xrightarrow{\text{reshape}} (B, \mathbf{1}, h, d_h) \xrightarrow{\text{transpose}} (B, h, \mathbf{1}, d_h)$
$Q_t, K_t, V_t \in \mathbb{R}^{B \times h \times \mathbf{1} \times d_h}$
关键:序列维度 = 1,因为只投影了当前这一个 token。训练时这里是 $L$。

(2) RoPE

对当前位置 $t$(不是位置 0!)施加旋转位置编码:

$Q_t \leftarrow \text{RoPE}(Q_t,\ \text{pos}{=}t), \quad K_t \leftarrow \text{RoPE}(K_t,\ \text{pos}{=}t)$
shape 不变:$(B, h, \mathbf{1}, d_h)$

历史 cache 里的 $K$ 在之前各步已经各自 RoPE 过了(pos=0, 1, ..., T-1),无需重做。

(3) KV Cache 原理与拼接

为什么需要 KV Cache?

如果不用 cache,每生成一个 token 都要重新算全部 $T{+}1$ 个历史位置的 $K, V$,计算量随序列长度平方增长。KV Cache 的核心思想:历史位置的 $K, V$ 已经算过,存起来复用

$K_{\text{cache}}^{(t)} = \big[\, K^{(0)}, K^{(1)}, \ldots, K^{(t-1)} \,\big] \in \mathbb{R}^{B \times h \times t \times d_h}$

第 $t$ 步时,cache 中存放了位置 $0$ 到 $t{-}1$ 共 $t$ 个 token 的 Key(每步做完 RoPE 后立即存入)。

沿序列维(dim=2)拼接当前 token 的 $K_t, V_t$:

$K_{\text{all}} = \text{cat}(K_{\text{cache}},\, K_t)$, $\quad V_{\text{all}} = \text{cat}(V_{\text{cache}},\, V_t)$
$\underbrace{(B, h, T, d_h)}_{K_{\text{cache}}} \oplus \underbrace{(B, h, \mathbf{1}, d_h)}_{K_t} = \underbrace{(B, h, \mathbf{T{+}1}, d_h)}_{K_{\text{all}}}$
$Q_t$ 序列维仍是 1(只有当前 query);$K_{\text{all}}, V_{\text{all}}$ 序列维是 $T{+}1$(全部上下文)。这种不对称是 KV cache 的核心。

计算量对比(增量 vs 全量):

不用 cache:每步 $O((T{+}1)^2 \cdot d_h)$(全部 token 两两打分)
用 cache:每步 $O(1 \cdot (T{+}1) \cdot d_h)$(1 个 query $\times$ $T{+}1$ 个 key)

从 $O(T^2)$ 降到 $O(T)$,生成长文本时加速显著。代价是显存中要存 $2 \cdot N \cdot B \cdot h \cdot T \cdot d_h$ 个 float16 值($N$ 层,$K$ 和 $V$ 各一份)。

(4) 当前一步 Attention

注意力分数——当前 1 个 query 对全部 $T{+}1$ 个 key 打分:

$S_t = \dfrac{Q_t\, K_{\text{all}}^\top}{\sqrt{d_h}}$
$\underbrace{(B, h, \mathbf{1}, d_h)}_{Q_t} \times \underbrace{(B, h, d_h, \mathbf{T{+}1})}_{K_{\text{all}}^\top} = \underbrace{(B, h, \mathbf{1}, \mathbf{T{+}1})}_{S_t}$

对比训练时:$S \in \mathbb{R}^{B \times h \times L \times L}$($L$ 个 query $\times$ $L$ 个 key)。推理时只有 $1 \times (T{+}1)$。

softmax

$A_t = \text{softmax}(S_t) \in \mathbb{R}^{B \times h \times \mathbf{1} \times (T{+}1)}$

推理无需 causal mask(当前 token 本来就只能看到自己和之前的 token)。

加权求和——用注意力权重对全部 $T{+}1$ 个 value 加权:

$O_t^{\text{head}} = A_t\, V_{\text{all}}$
$\underbrace{(B, h, \mathbf{1}, T{+}1)}_{A_t} \times \underbrace{(B, h, T{+}1, d_h)}_{V_{\text{all}}} = \underbrace{(B, h, \mathbf{1}, d_h)}_{O_t^{\text{head}}}$
输出序列维仍是 1——一个 query 得到一个 $d_h$ 维的上下文向量,但它融合了全部 $T{+}1$ 个位置的信息。

(5) 多头拼接 + 输出投影

$(B, h, \mathbf{1}, d_h) \xrightarrow{\text{transpose}} (B, \mathbf{1}, h, d_h) \xrightarrow{\text{reshape}} (B, \mathbf{1},\, h \cdot d_h)$
$O_{t,\text{attn}} = O_{t,\text{concat}}\, W_O$
$\underbrace{(B, \mathbf{1}, h \cdot d_h)}_{O_{t,\text{concat}}} \times \underbrace{(h \cdot d_h, d_{\text{model}})}_{W_O} = \underbrace{(B, \mathbf{1}, d_{\text{model}})}_{O_{t,\text{attn}}}$

(6) Pre-Norm 残差 + SwiGLU(与训练完全一致,仅序列维 = 1)

$X_{t,\text{mid}} = X_t + \text{Attention}\bigl(\text{RMSNorm}(X_t)\bigr)$
SwiGLU:$(B, \mathbf{1}, d_{\text{model}}) \xrightarrow{W_g,W_u} (B, \mathbf{1}, d_{\text{ff}}) \xrightarrow{\text{SiLU}\odot} (B, \mathbf{1}, d_{\text{ff}}) \xrightarrow{W_d} (B, \mathbf{1}, d_{\text{model}})$
$X_t' = X_{t,\text{mid}} + \text{SwiGLU}\bigl(\text{RMSNorm}(X_{t,\text{mid}})\bigr) \in \mathbb{R}^{B \times \mathbf{1} \times d_{\text{model}}}$

同时更新本层 cache:$K_{\text{cache}} \leftarrow K_{\text{all}}$,$V_{\text{cache}} \leftarrow V_{\text{all}}$(序列维从 $T$ 增长到 $T{+}1$)。

3

重复 $N$ 个 Decoder Block

$X_t^{(0)} \to X_t^{(1)} \to \cdots \to X_t^{(N)}$

每层的输入和输出 shape 恒为 $(B, \mathbf{1}, d_{\text{model}})$。

每层独立维护自己的 KV cache,共 $N$ 层,每层的 cache shape 为 $(B, h, T{+}1, d_h)$。

4

输出 logits + 采样策略

最终 hidden state $H_t \in \mathbb{R}^{B \times \mathbf{1} \times d_{\text{model}}}$,经过 LM Head:

$Z_t = H_t\, W_{\text{lm}}$
$\underbrace{(B, \mathbf{1}, d_{\text{model}})}_{H_t} \times \underbrace{(d_{\text{model}}, V)}_{W_{\text{lm}}} = \underbrace{(B, \mathbf{1}, V)}_{Z_t}$

挤掉序列维得到 $z = Z_t[:, 0, :] \in \mathbb{R}^{B \times V}$,通过下面的流水线产出 next_token:

① Temperature:$z' = z / \tau$
    $\tau \to 0$ 等价 argmax(确定性);$\tau = 1$ 原始分布;$\tau \to \infty$ 趋近均匀分布
② Top-k:保留 $z'$ 中最大的 $k$ 个,其余置 $-\infty$
③ Top-p(nucleus):累积概率最先达到 $p$ 的最小集合,其余置 $-\infty$
④ Repetition penalty:已出现 token 的 logit 除以 $\rho > 1$(降低重复)
⑤ Softmax + 采样:$\text{next\_token} = \text{multinomial}(\text{softmax}(z'), 1)$

这个 next_token 就成为下一步的 $\text{input\_id}_{t+1}$,如此循环直到生成 EOS 或达到最大长度。

5

变体:GQA / MQA —— 压缩 KV Cache 的关键

标准 MHA 中 $h_q = h_{kv} = h$(Q/K/V head 数相同)。Grouped-Query Attention 让 K/V head 数远小于 Q head 数:

$h_q = h, \quad h_{kv} \ll h, \quad \text{每 } h/h_{kv} \text{ 个 Q head 共享一组 K/V head}$
项目MHA(标准)GQAMQA(极端)
$W_K, W_V$ 形状 $d_{\text{model}} \times h \cdot d_h$ $d_{\text{model}} \times h_{kv} \cdot d_h$ $d_{\text{model}} \times d_h$
K/V cache shape $(B, h, T, d_h)$ $(B, h_{kv}, T, d_h)$ $(B, 1, T, d_h)$
Cache 显存倍数 $h_{kv}/h$ × (如 8/64 = 1/8) $1/h$ × (如 1/64)
注意力计算前 把 K/V 沿 head 维 repeat $h/h_{kv}$ 倍后再算(数学等价于 broadcast)

实例:LLaMA 3 70B 用 GQA-8($h_q=64, h_{kv}=8$),KV cache 缩 8×;DeepSeek-V2 进一步用 MLA(Multi-head Latent Attention)压缩到 $\sim$1/16。这是长上下文部署的关键瓶颈优化。

6

KV Cache 显存核算(部署关键指标)

$\text{KV mem (bytes)} = 2 \cdot N \cdot B \cdot h_{kv} \cdot T \cdot d_h \cdot s$

系数 2 因为 K 和 V 各一份;$s$ = 每元素字节数(fp16/bf16 = 2,fp8 = 1,int4 = 0.5)。

实例 1(LLaMA 3 8B,MHA):$N{=}32, h{=}32, d_h{=}128, B{=}1, T{=}8192, s{=}2$

$2 \times 32 \times 1 \times 32 \times 8192 \times 128 \times 2 = \mathbf{4.3\ \text{GB}}$

实例 2(LLaMA 3 70B,GQA-8):$N{=}80, h_{kv}{=}8, d_h{=}128, B{=}1, T{=}8192, s{=}2$

$2 \times 80 \times 1 \times 8 \times 8192 \times 128 \times 2 = \mathbf{2.7\ \text{GB}}$
KV cache 与 $T$ 线性增长,且不分 batch 共享——这是长上下文(128k)和高并发推理的首要显存瓶颈,PageAttention/vLLM 等系统正是为它而生。
$\text{prompt}\ (B, L_p) \xrightarrow{\text{Prefill}} \text{cache}\ (T{=}L_p) + \text{token}_{L_p} \xrightarrow{\text{Decode loop}} \text{token}_{L_p+1}, \text{token}_{L_p+2}, \ldots$

⚖️ 三、训练 vs 推理:核心区别

训练 推理(第 $t$ 步)
输入 整段序列 $(B, L)$ 单个 token $(B, 1)$
Embedding 输出 $(B, L, d_{\text{model}})$ $(B, \mathbf{1}, d_{\text{model}})$
投影公式 $Q = X W_Q$ $Q_t = X_t W_Q$(同一个 $W_Q$)
$Q$ shape $(B, h, L, d_h)$
L 个 query
$(B, h, \mathbf{1}, d_h)$
1 个 query
$K, V$ shape $(B, h, L, d_h)$
当前投影得到
$K_t, V_t$: $(B, h, \mathbf{1}, d_h)$ 当前投影
$K_{\text{all}}, V_{\text{all}}$: $(B, h, \mathbf{T{+}1}, d_h)$ 拼接 cache 后
Attn Score $(B, h, L, L)$
$L$ query $\times$ $L$ key
$(B, h, \mathbf{1}, \mathbf{T{+}1})$
1 query $\times$ $(T{+}1)$ key
Attn 输出 $(B, h, L, d_h)$ $(B, h, \mathbf{1}, d_h)$
KV Cache 不需要 每层维护 $(B, h, T, d_h)$,每步追加 1 行
最终 logits $(B, L, V)$ $(B, \mathbf{1}, V)$
参数矩阵 $W_e, W_Q, W_K, W_V, W_O, W_1, W_2, W_{\text{lm}}$ 完全共享,推理不引入新参数

📦 四、全部可学习参数一览(每层)

参数 维度 作用位置
$W_e$ $V \times d_{\text{model}}$ Token 嵌入(查表)
$W_Q^{(i)}$ $d_{\text{model}} \times (h \cdot d_h)$ 第 $i$ 层 — 将隐状态投影为 Query
$W_K^{(i)}$ $d_{\text{model}} \times (h \cdot d_h)$ 第 $i$ 层 — 将隐状态投影为 Key
$W_V^{(i)}$ $d_{\text{model}} \times (h \cdot d_h)$ 第 $i$ 层 — 将隐状态投影为 Value
$W_O^{(i)}$ $(h \cdot d_h) \times d_{\text{model}}$ 第 $i$ 层 — 多头拼接后投影回 $d_{\text{model}}$
$W_1^{(i)}$ $d_{\text{model}} \times d_{\text{ff}}$ 第 $i$ 层 MLP — 升维到 $d_{\text{ff}}$
$W_2^{(i)}$ $d_{\text{ff}} \times d_{\text{model}}$ 第 $i$ 层 MLP — 降维回 $d_{\text{model}}$
$\gamma^{(i)}, \beta^{(i)}$ $d_{\text{model}}$ 各两组 第 $i$ 层 LayerNorm 缩放与偏移
$W_{\text{lm}}$ $d_{\text{model}} \times V$ LM Head(常与 $W_e^\top$ 共享)

共 $N$ 层。单层注意力参数 = $4 d_{\text{model}}^2$($W_Q + W_K + W_V + W_O$),单层 MLP 参数 = $2 \cdot d_{\text{model}} \cdot d_{\text{ff}} \approx 8 d_{\text{model}}^2$,总参数量约 $12 N d_{\text{model}}^2 + V \cdot d_{\text{model}}$。

训练时:$Q{=}XW_Q$ 一次算出 $L$ 个 query,attention score 为 $(L \times L)$,输出 $(B, L, V)$
推理时:$Q_t{=}X_t W_Q$ 只算 1 个 query,$K_{\text{all}}{=}\text{cat}(K_{\text{cache}}, K_t)$ 拼出 $T{+}1$ 个 key,
attention score 为 $(1 \times (T{+}1))$,输出 $(B, 1, V)$,历史 KV 不重算

💻 五、全套 PyTorch 实现(与公式一一对应)

以下代码为自包含的纯 PyTorch 实现,可直接运行。每个代码块均标注了对应的前文公式编号。

① RMSNorm(归一化) 对应:训练 Step 2(6) $\text{RMSNorm}(\cdot)$
class RMSNorm(nn.Module): # 公式: y = x / sqrt(mean(x²) + ε) * weight def __init__(self, dim, eps=1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): # x: [B, T, C] → rms: [B, T, 1] rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) return x * rms * self.weight
维度映射:输入 [B, T, C](即前文 $(B, L, d_{\text{model}})$)→ rms 沿最后一维求均值 → [B, T, 1],逐元素乘回原张量。比 LayerNorm 少了减去均值的步骤,更高效。
② SwiGLU(MLP 激活) 对应:训练 Step 2(6) $\text{MLP}(x) = \sigma(x W_1 + b_1) W_2 + b_2$
class SwiGLU(nn.Module): """SwiGLU(x) = SiLU(gate_proj(x)) * up_proj(x)""" def __init__(self, dim, hidden_dim): self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) # 公式中 W₁ self.up_proj = nn.Linear(dim, hidden_dim, bias=False) self.down_proj = nn.Linear(hidden_dim, dim, bias=False) # 公式中 W₂ def forward(self, x): gate = self.gate_proj(x) # [B,T,C] → [B,T,d_ff] up = self.up_proj(x) # [B,T,C] → [B,T,d_ff] x = F.silu(gate) * up # SiLU(gate) × up,逐元素 x = self.down_proj(x) # [B,T,d_ff] → [B,T,C] return x
维度映射:[B, T, d_model] $\xrightarrow{W_1}$ [B, T, d_ff] $\xrightarrow{\sigma}$ [B, T, d_ff] $\xrightarrow{W_2}$ [B, T, d_model],与前文 MLP 维度链完全一致。现代 LLM(LLaMA 等)用 SwiGLU 替代 ReLU/GeLU。
③ RoPE 旋转位置编码 对应:训练 Step 2(2) $Q_r = \text{RoPE}(Q)$
def build_rope_cache(head_dim, seq_len, theta=10000.0, offset=0, device=None): # freq: [D/2] | pos: [T] | angle: [T, D/2] freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) pos = torch.arange(offset, offset + seq_len).float() angle = torch.outer(pos, freq) cos = angle.cos()[None, None, :, :] # [1, 1, T, D/2] sin = angle.sin()[None, None, :, :] return cos, sin def apply_rope(x, cos, sin): # x: [B,H,T,D] | cos,sin: [1,1,T,D/2] xe = x[..., 0::2] # 偶数维 [B,H,T,D/2] xo = x[..., 1::2] # 奇数维 [B,H,T,D/2] out_even = xe * cos - xo * sin # 2D 旋转矩阵的第一行 out_odd = xe * sin + xo * cos # 2D 旋转矩阵的第二行 out = torch.empty_like(x) out[..., 0::2] = out_even out[..., 1::2] = out_odd return out
对应公式:旋转矩阵 $R(\theta_t) = \begin{pmatrix}\cos\theta_t & -\sin\theta_t \\ \sin\theta_t & \cos\theta_t\end{pmatrix}$,每两个维度一组做 2D 旋转。offset 参数使推理时当前位置从 $t$ 开始(不是从 0),对应前文"推理 Step 2(2)"的 RoPE(pos=$t$)。
④ MultiHeadSelfAttention(QKV + RoPE + KV Cache + Attention) 对应:训练/推理 Step 2(1)-(5)
class MultiHeadSelfAttention(nn.Module): def __init__(self, dim, n_heads): self.n_heads = n_heads self.head_dim = dim // n_heads # 三组参数矩阵(训练/推理共享) self.q_proj = nn.Linear(dim, dim, bias=False) self.k_proj = nn.Linear(dim, dim, bias=False) self.v_proj = nn.Linear(dim, dim, bias=False) self.o_proj = nn.Linear(dim, dim, bias=False) def forward(self, x, attention_mask=None, past_kv=None, use_cache=False): B, T, C = x.shape # 1. QKV 投影: [B,T,C] × [C,C] → [B,T,C] q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x) # 2. reshape + transpose → [B, H, T, D] q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # 3. RoPE(推理时 offset 使 pos 从 t 开始) past_len = 0 if past_kv is None else past_kv[0].shape[2] cos, sin = build_rope_cache(self.head_dim, T, offset=past_len) q = apply_rope(q, cos, sin) k = apply_rope(k, cos, sin) # 4. 拼接 KV Cache(仅推理时) if past_kv is not None: past_k, past_v = past_kv k = torch.cat([past_k, k], dim=2) # [B,H,T+1,D] v = torch.cat([past_v, v], dim=2) present_kv = (k, v) if use_cache else None # 5. Attention 分数: [B,H,T,D] × [B,H,D,S] → [B,H,T,S] scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # 6. Causal Mask(仅训练时需要) if attention_mask is not None: scores = scores.masked_fill(~attention_mask, float("-inf")) # 7. Softmax + 8. 加权求和 attn = F.softmax(scores, dim=-1) out = torch.matmul(attn, v) # [B,H,T,D] # 9. merge heads + 10. 输出投影 out = out.transpose(1, 2).contiguous().view(B, T, C) out = self.o_proj(out) # [B,T,C] return out, present_kv
训练 vs 推理维度对比:
训练:$T = L$(整段),$q\ [B,h,L,d_h] \times k^\top\ [B,h,d_h,L] \to$ scores [B,h,L,L],需 causal mask
推理:$T = 1$(单 token),$q_t\ [B,h,1,d_h] \times k_{\text{all}}^\top\ [B,h,d_h,T{+}1] \to$ scores [B,h,1,T{+}1],无需 mask
⑤ 一个完整的 Decoder Block 对应:训练 Step 2(6) 残差 + MLP 全流程
class TransformerBlock(nn.Module): def __init__(self, dim, n_heads, mlp_hidden_dim): self.attn_norm = RMSNorm(dim) self.attn = MultiHeadSelfAttention(dim, n_heads) self.ffn_norm = RMSNorm(dim) self.ffn = SwiGLU(dim, mlp_hidden_dim) def forward(self, x, attention_mask=None, past_kv=None, use_cache=False): # Pre-Norm 风格:先归一化,再 attention h = self.attn_norm(x) attn_out, present_kv = self.attn(h, attention_mask, past_kv, use_cache) x = x + attn_out # 残差连接 # Pre-Norm:先归一化,再 MLP h = self.ffn_norm(x) x = x + self.ffn(h) # 残差连接 return x, present_kv
对应公式链:LayerNorm(x) → Attn → x + attn_out → LayerNorm → MLP → x + mlp_out,与前文训练 Step 2(6) 的 $X_{\text{mid}} = \text{LayerNorm}(X + O_{\text{attn}})$ 完全一致。use_cache=True 时返回 $(K_{\text{all}}, V_{\text{all}})$ 供下一步使用。
⑥ 完整模型 TinyTransformer 对应:训练 Step 3-5 与 推理 Step 3-4
class TinyTransformer(nn.Module): def __init__(self, vocab_size, dim, n_heads, mlp_dim, n_layers): self.embed = nn.Embedding(vocab_size, dim) # W_e self.layers = nn.ModuleList([ # N 层 TransformerBlock(dim, n_heads, mlp_dim) for _ in range(n_layers) ]) self.norm = RMSNorm(dim) self.lm_head = nn.Linear(dim, vocab_size, bias=False) # W_lm def forward(self, input_ids, attention_mask=None, past_kvs=None, use_cache=False): x = self.embed(input_ids) # [B,T,V] → [B,T,C] ← Step 1 new_kvs = [] for layer, past_kv in zip(self.layers, past_kvs): x, present_kv = layer(x, attention_mask, past_kv, use_cache) # Step 2×N new_kvs.append(present_kv) x = self.norm(x) logits = self.lm_head(x) # [B,T,C] → [B,T,V] ← Step 4 return logits, new_kvs if use_cache else None
使用示例对比:
# 训练:整段序列 + causal mask
logits, _ = model(input_ids, attention_mask=causal_mask, use_cache=False)
# logits: [B, L, V] ← 对应训练 Step 4

# 推理:逐 token + KV cache
for step in range(4):
    logits, past_kvs = model(step_input, use_cache=True)
# logits: [B, 1, V] ← 对应推理 Step 4
⑦ 完整训练流程(Forward + Loss + Backward + Update) 对应:训练 Step 1-5 全流程
# ============ 1. 初始化 ============ model = TinyTransformer( vocab_size=32000, dim=4096, n_heads=32, mlp_dim=11008, n_layers=32, ).cuda() optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100000) # ============ 2. 训练循环 ============ for epoch in range(num_epochs): for batch in dataloader: input_ids = batch["input_ids"].cuda() # [B, L] labels = batch["labels"].cuda() # [B, L](= input_ids 左移1位) # --- Forward --- causal_mask = torch.tril(torch.ones(L, L, device="cuda")) # [L,L] logits, _ = model(input_ids, attention_mask=causal_mask, use_cache=False) # logits: [B, L, V] # --- Loss: 用 t 位置预测 t+1 位置的 token --- loss = F.cross_entropy( logits[:, :-1].reshape(-1, vocab_size), # [B*(L-1), V] labels[:, 1:].reshape(-1), # [B*(L-1)] ) # --- Backward --- optimizer.zero_grad() loss.backward() # 反向传播,填充 .grad torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # --- Update --- optimizer.step() # AdamW 更新参数 scheduler.step()
训练全流程拆解:
Embedding → N×(Attn+MLP) → LM Head → logits [B,L,V]CrossEntropybackward()AdamW.step()
关键细节:labels 是 input_ids 左移 1 位(位置 $t$ 预测位置 $t{+}1$),logits[:, :-1] 去掉最后一个位置(没有下一个 token 可预测)。
⑧ 推理生成流程(自回归 + KV Cache) 对应:推理 Step 1-4 全流程
# ============ 推理模式:逐步自回归生成 ============ model.eval() past_kvs = None generated = [bos_token_id] # 以 BOS 开头 for step in range(max_new_tokens): step_input = torch.tensor([generated[-1]]).cuda() # [1,1] 当前 token logits, past_kvs = model( step_input[None, :], # [B=1, T=1] past_kvs=past_kvs, use_cache=True ) # logits: [1, 1, V] next_token = torch.argmax(logits[:, -1], dim=-1).item() generated.append(next_token) if next_token == eos_token_id: break # past_kvs 逐层膨胀: 每层 (B,H,S,D),S 从 1 增长到 T # 总显存: 2 × N_layers × B × H × T × D × 2 bytes (fp16)
推理全流程拆解:
Embed 1 token → N×(Attn+KV Cache+MLP) → LM Head → logits [B,1,V]argmax/sample → 新 token → 循环
每步只算 1 个 token,历史 $K,V$ 从 cache 读取不重算。生成 $T$ 个 token 的总计算量 = $O(T^2 \cdot d \cdot h)$,比无 cache 的 $O(T^3 \cdot d \cdot h)$ 快一个数量级。