🎓 一、训练时(并行输入整段序列)
记号约定
- $B$ — batch size
- $L$ — 序列长度
- $d_{\text{model}}$ — 隐藏维度
- $h$ — head 数
- $d_h = d_{\text{model}} / h$ — 每头维度
- $V$ — 词表大小
- $d_{\text{ff}}$ — MLP 中间维度(通常 $4d_{\text{model}}$)
- $T$ — 推理时已生成的 token 数
输入 token ids:
$\text{input\_ids} \in \mathbb{Z}^{B \times L}$
经过嵌入矩阵 $W_e \in \mathbb{R}^{V \times d_{\text{model}}}$(查表操作):
$X_0 = W_e[\text{input\_ids}] \in \mathbb{R}^{B \times L \times d_{\text{model}}}$
维度:$(B, L)$ 整数索引 $\xrightarrow{\text{lookup } W_e}$ $(B, L, d_{\text{model}})$ 浮点张量
老式绝对位置编码:$X_0 = W_e[\text{input\_ids}] + PE$;现代 LLM 常用 RoPE,此处通常不直接加 PE。
设本层输入 $X \in \mathbb{R}^{B \times L \times d_{\text{model}}}$
(1) Q/K/V 线性投影
通过三组可学习的参数矩阵做线性变换:
$Q = X\, W_Q, \quad K = X\, W_K, \quad V = X\, W_V$
$W_Q, W_K, W_V \in \mathbb{R}^{d_{\text{model}} \times (h \cdot d_h)}$ 为本层可学习参数(因 $h \cdot d_h = d_{\text{model}}$,实际就是 $d_{\text{model}} \times d_{\text{model}}$ 的方阵)
矩阵乘法维度:$\underbrace{(B, L, d_{\text{model}})}_{X} \times \underbrace{(d_{\text{model}}, h \cdot d_h)}_{W_Q} = \underbrace{(B, L, h \cdot d_h)}_{Q}$
reshape 为多头形式,再 transpose 把 head 维提前:
$(B, L, h \cdot d_h) \xrightarrow{\text{reshape}} (B, L, h, d_h) \xrightarrow{\text{transpose}} (B, h, L, d_h)$
$Q, K, V \in \mathbb{R}^{B \times h \times L \times d_h}$
(2) 位置编码(RoPE)
对 $Q, K$ 的每个位置 $\text{pos} = 0, 1, \ldots, L{-}1$ 施加旋转矩阵(逐元素,不改变维度):
$Q_r = \text{RoPE}(Q), \quad K_r = \text{RoPE}(K) \quad \in \mathbb{R}^{B \times h \times L \times d_h}$
RoPE 核心公式——将每两个维度视为 2D 向量,按位置做旋转:
$\begin{pmatrix} q_{2i}^{(t)} \\ q_{2i+1}^{(t)} \end{pmatrix} = \begin{pmatrix} \cos(t\theta_i) & -\sin(t\theta_i) \\ \sin(t\theta_i) & \cos(t\theta_i) \end{pmatrix} \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}, \quad \theta_i = 10000^{-2i/d_h}$
关键性质:$Q_t^\top K_{t+\Delta} = (R_\theta^t q)^\top (R_\theta^{t+\Delta} k) = q^\top R_\theta^\Delta k$,内积只依赖相对位置差 $\Delta$,与绝对位置无关。$V$ 不做位置编码。
(3) 注意力分数
$Q_r$ 与 $K_r^\top$ 做矩阵乘法(在最后两维上):
$S = \dfrac{Q_r\, K_r^\top}{\sqrt{d_h}}$
$\underbrace{(B, h, L, d_h)}_{Q_r} \times \underbrace{(B, h, d_h, L)}_{K_r^\top} = \underbrace{(B, h, L, L)}_{S}$
$S$ 的含义:每个 head 上,$L$ 个 query 位置对 $L$ 个 key 位置的原始注意力分数。
加 causal mask(上三角置 $-\infty$)后 softmax:
$A = \text{softmax}(\text{mask}(S)) \in \mathbb{R}^{B \times h \times L \times L}$
Causal mask 形状(以 $L=5$ 为例):
k₀ k₁ k₂ k₃ k₄
q₀ ✓ · · · ·
q₁ ✓ ✓ · · ·
q₂ ✓ ✓ ✓ · ·
q₃ ✓ ✓ ✓ ✓ ·
q₄ ✓ ✓ ✓ ✓ ✓
下三角(含对角线)保留原 score,上三角置 $-\infty$,softmax 后变 0——确保位置 $t$ 只能看到位置 $\le t$ 的 token。
(4) 加权求和
$O_{\text{head}} = A\, V$
$\underbrace{(B, h, L, L)}_{A} \times \underbrace{(B, h, L, d_h)}_{V} = \underbrace{(B, h, L, d_h)}_{O_{\text{head}}}$
每个 query 位置得到 $d_h$ 维的加权上下文向量。
(5) 多头拼接 + 输出投影
将 $h$ 个 head 沿最后一维拼接:
$(B, h, L, d_h) \xrightarrow{\text{transpose}} (B, L, h, d_h) \xrightarrow{\text{reshape}} (B, L,\, h \cdot d_h)$
经过输出投影矩阵 $W_O \in \mathbb{R}^{(h \cdot d_h) \times d_{\text{model}}}$:
$O_{\text{attn}} = O_{\text{concat}}\, W_O$
$\underbrace{(B, L, h \cdot d_h)}_{O_{\text{concat}}} \times \underbrace{(h \cdot d_h, d_{\text{model}})}_{W_O} = \underbrace{(B, L, d_{\text{model}})}_{O_{\text{attn}}}$
(6) Pre-Norm 残差结构 + SwiGLU MLP
RMSNorm(现代 LLM 取代 LayerNorm)—— 只保留缩放,去掉减均值:
$\text{RMSNorm}(x) = \dfrac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \varepsilon}} \odot \gamma$
$\gamma \in \mathbb{R}^d$ 为可学习缩放参数;$\varepsilon = 10^{-6}$ 防止除零。比 LayerNorm 少一次均值计算,更快且效果相当。
Pre-Norm 残差(现代 LLM 标准结构,与原始 Transformer 的 Post-Norm 不同):
$X_{\text{mid}} = X + \text{Attention}\bigl(\text{RMSNorm}(X)\bigr)$
注:上文 Step 2(1)-(5) 中的 $X$ 实际是 $\text{RMSNorm}(X_{\text{原始}})$,为简化记号没有显式写出。Pre-Norm 让深层网络梯度更稳定,无需 warm-up。
SwiGLU(LLaMA 等现代模型的激活方案)—— 用门控机制替代单一激活:
$\text{SwiGLU}(x) = \bigl(\text{SiLU}(x W_g)\bigr) \odot (x W_u) \cdot W_d$
$W_g, W_u \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ 为门控/升维矩阵,$W_d \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ 为降维矩阵;$\text{SiLU}(z) = z \cdot \sigma(z)$ 为 Sigmoid-Linear Unit
$\underbrace{(B, L, d_{\text{model}})}_{x} \xrightarrow{W_g,\,W_u} \underbrace{(B, L, d_{\text{ff}})}_{\text{gate},\,\text{up}} \xrightarrow{\text{SiLU} \odot} (B, L, d_{\text{ff}}) \xrightarrow{W_d} \underbrace{(B, L, d_{\text{model}})}_{\text{output}}$
对比传统 MLP $\sigma(x W_1) W_2$:SwiGLU 用两个投影 $W_g, W_u$ 替代一个,参数量略多但表达能力更强。
第二个 Pre-Norm 残差(套在 SwiGLU 上):
$X' = X_{\text{mid}} + \text{SwiGLU}\bigl(\text{RMSNorm}(X_{\text{mid}})\bigr) \in \mathbb{R}^{B \times L \times d_{\text{model}}}$
Pre-Norm vs Post-Norm:原始 Transformer 论文是 $X' = \text{Norm}(X + \text{Sublayer}(X))$(Norm 在残差之后);现代 LLM 全部用 $X' = X + \text{Sublayer}(\text{Norm}(X))$(Norm 在子层之前)。后者训练更稳。
$X_0 \to X_1 \to X_2 \to \cdots \to X_N$
每层的输入和输出 shape 恒为 $(B, L, d_{\text{model}})$,不变。
最后一层输出 $H \in \mathbb{R}^{B \times L \times d_{\text{model}}}$,经过 LM Head:
$Z = H\, W_{\text{lm}}$
$\underbrace{(B, L, d_{\text{model}})}_{H} \times \underbrace{(d_{\text{model}}, V)}_{W_{\text{lm}}} = \underbrace{(B, L, V)}_{Z}$
很多模型中 $W_{\text{lm}} = W_e^\top$(weight tying),无额外参数。
labels $\in \mathbb{Z}^{B \times L}$,对每个位置做词表大小 $V$ 的分类:
$\mathcal{L} = \text{CrossEntropy}(Z_{[:, :-1, :]},\ \text{labels}_{[:, 1:]})$
实际上是用位置 $t$ 的预测 $Z_t \in \mathbb{R}^V$ 去预测位置 $t{+}1$ 的 token。
$\text{input\_ids}\ (B,L) \;\xrightarrow{W_e}\; X_0\ (B,L,d) \;\xrightarrow{N \times \text{Block}}\; H\ (B,L,d) \;\xrightarrow{W_{\text{lm}}}\; Z\ (B,L,V) \;\to\; \text{CE loss}$
🚀 二、推理时(自回归 + KV Cache)
推理专用记号
- $L_p$ — prompt 长度(用户输入的 token 数)
- $t$ — 当前解码步的位置索引(从 $L_p$ 开始)
- $T$ — cache 中已有的 token 数(位置 $0 \ldots T{-}1$ 的 K/V 已缓存)
- 下标 $_t$ — "当前这一个 token" 的量(序列维 = 1)
- 下标 $_{\text{all}}$ — "历史 cache + 当前 token" 拼接后的量(序列维 = $T{+}1$)
- 下标 $_{\text{cache}}$ — "仅历史缓存" 的量(序列维 = $T$)
推理分两个阶段:
① Prefill(处理 prompt,一次性并行,下方 Step 0)
→
② Decode(逐 token 自回归,下方 Step 1-4,每步一个 token)。
两阶段用同一份模型参数,差异只在输入长度和是否带 cache。
用户给定 prompt:
$\text{prompt\_ids} \in \mathbb{Z}^{B \times L_p}$
这一步不是逐 token,而是把整段 prompt 一次性送进模型——shape 完全等同训练:
$(B, L_p) \xrightarrow{W_e} (B, L_p, d_{\text{model}}) \xrightarrow{N \times \text{Block (causal mask)}} (B, L_p, d_{\text{model}}) \xrightarrow{W_{\text{lm}}} (B, L_p, V)$
Prefill 阶段每层算出的 K/V 全部存入 cache:
$K_{\text{cache}}^{(i)},\, V_{\text{cache}}^{(i)} \in \mathbb{R}^{B \times h \times L_p \times d_h}, \quad i = 1, \ldots, N$
从 $(B, L_p, V)$ 的 logits 中只取最后一个位置:
$Z_{[:, -1, :]} \in \mathbb{R}^{B \times V} \;\xrightarrow{\text{sample}}\; \text{token}_{L_p}$
这就是模型生成的第一个新 token。Prefill 完成后 $T = L_p$,进入 Decode 阶段。
Prefill 是计算密集型(一次 $L_p \times L_p$ 注意力),Decode 是显存带宽密集型(每步只算 $1 \times (T{+}1)$,但要从 HBM 读全部 KV cache)——这是工业部署中两阶段独立优化的根本原因。
当前步只输入最新一个 token(不是整段序列):
$\text{input\_id}_t \in \mathbb{Z}^{B \times 1}$
经过嵌入矩阵 $W_e$(同训练时的同一个矩阵):
$X_t^{(0)} = W_e[\text{input\_id}_t] \in \mathbb{R}^{B \times 1 \times d_{\text{model}}}$
注意序列维 = $\mathbf{1}$(只有一个 token),而训练时序列维 = $L$
设本层输入 $X_t \in \mathbb{R}^{B \times \mathbf{1} \times d_{\text{model}}}$(只有当前 1 个 token)
(1) 当前 token 的 Q/K/V 投影
使用与训练时完全相同的参数矩阵 $W_Q, W_K, W_V$:
$Q_t = X_t\, W_Q, \quad K_t = X_t\, W_K, \quad V_t = X_t\, W_V$
$\underbrace{(B, \mathbf{1}, d_{\text{model}})}_{X_t} \times \underbrace{(d_{\text{model}}, h \cdot d_h)}_{W_Q} = \underbrace{(B, \mathbf{1}, h \cdot d_h)}_{Q_t}$
reshape + transpose 为多头形式:
$(B, \mathbf{1}, h \cdot d_h) \xrightarrow{\text{reshape}} (B, \mathbf{1}, h, d_h) \xrightarrow{\text{transpose}} (B, h, \mathbf{1}, d_h)$
$Q_t, K_t, V_t \in \mathbb{R}^{B \times h \times \mathbf{1} \times d_h}$
关键:序列维度 = 1,因为只投影了当前这一个 token。训练时这里是 $L$。
(2) RoPE
对当前位置 $t$(不是位置 0!)施加旋转位置编码:
$Q_t \leftarrow \text{RoPE}(Q_t,\ \text{pos}{=}t), \quad K_t \leftarrow \text{RoPE}(K_t,\ \text{pos}{=}t)$
shape 不变:$(B, h, \mathbf{1}, d_h)$
历史 cache 里的 $K$ 在之前各步已经各自 RoPE 过了(pos=0, 1, ..., T-1),无需重做。
(3) KV Cache 原理与拼接
为什么需要 KV Cache?
如果不用 cache,每生成一个 token 都要重新算全部 $T{+}1$ 个历史位置的 $K, V$,计算量随序列长度平方增长。KV Cache 的核心思想:历史位置的 $K, V$ 已经算过,存起来复用。
$K_{\text{cache}}^{(t)} = \big[\, K^{(0)}, K^{(1)}, \ldots, K^{(t-1)} \,\big] \in \mathbb{R}^{B \times h \times t \times d_h}$
第 $t$ 步时,cache 中存放了位置 $0$ 到 $t{-}1$ 共 $t$ 个 token 的 Key(每步做完 RoPE 后立即存入)。
沿序列维(dim=2)拼接当前 token 的 $K_t, V_t$:
$K_{\text{all}} = \text{cat}(K_{\text{cache}},\, K_t)$, $\quad V_{\text{all}} = \text{cat}(V_{\text{cache}},\, V_t)$
$\underbrace{(B, h, T, d_h)}_{K_{\text{cache}}} \oplus \underbrace{(B, h, \mathbf{1}, d_h)}_{K_t} = \underbrace{(B, h, \mathbf{T{+}1}, d_h)}_{K_{\text{all}}}$
$Q_t$ 序列维仍是 1(只有当前 query);$K_{\text{all}}, V_{\text{all}}$ 序列维是 $T{+}1$(全部上下文)。这种不对称是 KV cache 的核心。
计算量对比(增量 vs 全量):
不用 cache:每步 $O((T{+}1)^2 \cdot d_h)$(全部 token 两两打分)
用 cache:每步 $O(1 \cdot (T{+}1) \cdot d_h)$(1 个 query $\times$ $T{+}1$ 个 key)
从 $O(T^2)$ 降到 $O(T)$,生成长文本时加速显著。代价是显存中要存 $2 \cdot N \cdot B \cdot h \cdot T \cdot d_h$ 个 float16 值($N$ 层,$K$ 和 $V$ 各一份)。
(4) 当前一步 Attention
注意力分数——当前 1 个 query 对全部 $T{+}1$ 个 key 打分:
$S_t = \dfrac{Q_t\, K_{\text{all}}^\top}{\sqrt{d_h}}$
$\underbrace{(B, h, \mathbf{1}, d_h)}_{Q_t} \times \underbrace{(B, h, d_h, \mathbf{T{+}1})}_{K_{\text{all}}^\top} = \underbrace{(B, h, \mathbf{1}, \mathbf{T{+}1})}_{S_t}$
对比训练时:$S \in \mathbb{R}^{B \times h \times L \times L}$($L$ 个 query $\times$ $L$ 个 key)。推理时只有 $1 \times (T{+}1)$。
softmax:
$A_t = \text{softmax}(S_t) \in \mathbb{R}^{B \times h \times \mathbf{1} \times (T{+}1)}$
推理无需 causal mask(当前 token 本来就只能看到自己和之前的 token)。
加权求和——用注意力权重对全部 $T{+}1$ 个 value 加权:
$O_t^{\text{head}} = A_t\, V_{\text{all}}$
$\underbrace{(B, h, \mathbf{1}, T{+}1)}_{A_t} \times \underbrace{(B, h, T{+}1, d_h)}_{V_{\text{all}}} = \underbrace{(B, h, \mathbf{1}, d_h)}_{O_t^{\text{head}}}$
输出序列维仍是 1——一个 query 得到一个 $d_h$ 维的上下文向量,但它融合了全部 $T{+}1$ 个位置的信息。
(5) 多头拼接 + 输出投影
$(B, h, \mathbf{1}, d_h) \xrightarrow{\text{transpose}} (B, \mathbf{1}, h, d_h) \xrightarrow{\text{reshape}} (B, \mathbf{1},\, h \cdot d_h)$
$O_{t,\text{attn}} = O_{t,\text{concat}}\, W_O$
$\underbrace{(B, \mathbf{1}, h \cdot d_h)}_{O_{t,\text{concat}}} \times \underbrace{(h \cdot d_h, d_{\text{model}})}_{W_O} = \underbrace{(B, \mathbf{1}, d_{\text{model}})}_{O_{t,\text{attn}}}$
(6) Pre-Norm 残差 + SwiGLU(与训练完全一致,仅序列维 = 1)
$X_{t,\text{mid}} = X_t + \text{Attention}\bigl(\text{RMSNorm}(X_t)\bigr)$
SwiGLU:$(B, \mathbf{1}, d_{\text{model}}) \xrightarrow{W_g,W_u} (B, \mathbf{1}, d_{\text{ff}}) \xrightarrow{\text{SiLU}\odot} (B, \mathbf{1}, d_{\text{ff}}) \xrightarrow{W_d} (B, \mathbf{1}, d_{\text{model}})$
$X_t' = X_{t,\text{mid}} + \text{SwiGLU}\bigl(\text{RMSNorm}(X_{t,\text{mid}})\bigr) \in \mathbb{R}^{B \times \mathbf{1} \times d_{\text{model}}}$
同时更新本层 cache:$K_{\text{cache}} \leftarrow K_{\text{all}}$,$V_{\text{cache}} \leftarrow V_{\text{all}}$(序列维从 $T$ 增长到 $T{+}1$)。
$X_t^{(0)} \to X_t^{(1)} \to \cdots \to X_t^{(N)}$
每层的输入和输出 shape 恒为 $(B, \mathbf{1}, d_{\text{model}})$。
每层独立维护自己的 KV cache,共 $N$ 层,每层的 cache shape 为 $(B, h, T{+}1, d_h)$。
最终 hidden state $H_t \in \mathbb{R}^{B \times \mathbf{1} \times d_{\text{model}}}$,经过 LM Head:
$Z_t = H_t\, W_{\text{lm}}$
$\underbrace{(B, \mathbf{1}, d_{\text{model}})}_{H_t} \times \underbrace{(d_{\text{model}}, V)}_{W_{\text{lm}}} = \underbrace{(B, \mathbf{1}, V)}_{Z_t}$
挤掉序列维得到 $z = Z_t[:, 0, :] \in \mathbb{R}^{B \times V}$,通过下面的流水线产出 next_token:
① Temperature:$z' = z / \tau$
$\tau \to 0$ 等价 argmax(确定性);$\tau = 1$ 原始分布;$\tau \to \infty$ 趋近均匀分布
② Top-k:保留 $z'$ 中最大的 $k$ 个,其余置 $-\infty$
③ Top-p(nucleus):累积概率最先达到 $p$ 的最小集合,其余置 $-\infty$
④ Repetition penalty:已出现 token 的 logit 除以 $\rho > 1$(降低重复)
⑤ Softmax + 采样:$\text{next\_token} = \text{multinomial}(\text{softmax}(z'), 1)$
这个 next_token 就成为下一步的 $\text{input\_id}_{t+1}$,如此循环直到生成 EOS 或达到最大长度。
标准 MHA 中 $h_q = h_{kv} = h$(Q/K/V head 数相同)。Grouped-Query Attention 让 K/V head 数远小于 Q head 数:
$h_q = h, \quad h_{kv} \ll h, \quad \text{每 } h/h_{kv} \text{ 个 Q head 共享一组 K/V head}$
| 项目 | MHA(标准) | GQA | MQA(极端) |
| $W_K, W_V$ 形状 |
$d_{\text{model}} \times h \cdot d_h$ |
$d_{\text{model}} \times h_{kv} \cdot d_h$ |
$d_{\text{model}} \times d_h$ |
| K/V cache shape |
$(B, h, T, d_h)$ |
$(B, h_{kv}, T, d_h)$ |
$(B, 1, T, d_h)$ |
| Cache 显存倍数 |
1× |
$h_{kv}/h$ × (如 8/64 = 1/8) |
$1/h$ × (如 1/64) |
| 注意力计算前 |
— |
把 K/V 沿 head 维 repeat $h/h_{kv}$ 倍后再算(数学等价于 broadcast) |
实例:LLaMA 3 70B 用 GQA-8($h_q=64, h_{kv}=8$),KV cache 缩 8×;DeepSeek-V2 进一步用 MLA(Multi-head Latent Attention)压缩到 $\sim$1/16。这是长上下文部署的关键瓶颈优化。
$\text{KV mem (bytes)} = 2 \cdot N \cdot B \cdot h_{kv} \cdot T \cdot d_h \cdot s$
系数 2 因为 K 和 V 各一份;$s$ = 每元素字节数(fp16/bf16 = 2,fp8 = 1,int4 = 0.5)。
实例 1(LLaMA 3 8B,MHA):$N{=}32, h{=}32, d_h{=}128, B{=}1, T{=}8192, s{=}2$
$2 \times 32 \times 1 \times 32 \times 8192 \times 128 \times 2 = \mathbf{4.3\ \text{GB}}$
实例 2(LLaMA 3 70B,GQA-8):$N{=}80, h_{kv}{=}8, d_h{=}128, B{=}1, T{=}8192, s{=}2$
$2 \times 80 \times 1 \times 8 \times 8192 \times 128 \times 2 = \mathbf{2.7\ \text{GB}}$
KV cache 与 $T$ 线性增长,且不分 batch 共享——这是长上下文(128k)和高并发推理的首要显存瓶颈,PageAttention/vLLM 等系统正是为它而生。
$\text{prompt}\ (B, L_p) \xrightarrow{\text{Prefill}} \text{cache}\ (T{=}L_p) + \text{token}_{L_p} \xrightarrow{\text{Decode loop}} \text{token}_{L_p+1}, \text{token}_{L_p+2}, \ldots$
💻 五、全套 PyTorch 实现(与公式一一对应)
以下代码为自包含的纯 PyTorch 实现,可直接运行。每个代码块均标注了对应的前文公式编号。
class RMSNorm(nn.Module):
# 公式: y = x / sqrt(mean(x²) + ε) * weight
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
# x: [B, T, C] → rms: [B, T, 1]
rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
return x * rms * self.weight
维度映射:输入 [B, T, C](即前文 $(B, L, d_{\text{model}})$)→ rms 沿最后一维求均值 → [B, T, 1],逐元素乘回原张量。比 LayerNorm 少了减去均值的步骤,更高效。
class SwiGLU(nn.Module):
"""SwiGLU(x) = SiLU(gate_proj(x)) * up_proj(x)"""
def __init__(self, dim, hidden_dim):
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) # 公式中 W₁
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False) # 公式中 W₂
def forward(self, x):
gate = self.gate_proj(x) # [B,T,C] → [B,T,d_ff]
up = self.up_proj(x) # [B,T,C] → [B,T,d_ff]
x = F.silu(gate) * up # SiLU(gate) × up,逐元素
x = self.down_proj(x) # [B,T,d_ff] → [B,T,C]
return x
维度映射:[B, T, d_model] $\xrightarrow{W_1}$ [B, T, d_ff] $\xrightarrow{\sigma}$ [B, T, d_ff] $\xrightarrow{W_2}$ [B, T, d_model],与前文 MLP 维度链完全一致。现代 LLM(LLaMA 等)用 SwiGLU 替代 ReLU/GeLU。
def build_rope_cache(head_dim, seq_len, theta=10000.0, offset=0, device=None):
# freq: [D/2] | pos: [T] | angle: [T, D/2]
freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
pos = torch.arange(offset, offset + seq_len).float()
angle = torch.outer(pos, freq)
cos = angle.cos()[None, None, :, :] # [1, 1, T, D/2]
sin = angle.sin()[None, None, :, :]
return cos, sin
def apply_rope(x, cos, sin):
# x: [B,H,T,D] | cos,sin: [1,1,T,D/2]
xe = x[..., 0::2] # 偶数维 [B,H,T,D/2]
xo = x[..., 1::2] # 奇数维 [B,H,T,D/2]
out_even = xe * cos - xo * sin # 2D 旋转矩阵的第一行
out_odd = xe * sin + xo * cos # 2D 旋转矩阵的第二行
out = torch.empty_like(x)
out[..., 0::2] = out_even
out[..., 1::2] = out_odd
return out
对应公式:旋转矩阵 $R(\theta_t) = \begin{pmatrix}\cos\theta_t & -\sin\theta_t \\ \sin\theta_t & \cos\theta_t\end{pmatrix}$,每两个维度一组做 2D 旋转。offset 参数使推理时当前位置从 $t$ 开始(不是从 0),对应前文"推理 Step 2(2)"的 RoPE(pos=$t$)。
class MultiHeadSelfAttention(nn.Module):
def __init__(self, dim, n_heads):
self.n_heads = n_heads
self.head_dim = dim // n_heads
# 三组参数矩阵(训练/推理共享)
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
self.o_proj = nn.Linear(dim, dim, bias=False)
def forward(self, x, attention_mask=None, past_kv=None, use_cache=False):
B, T, C = x.shape
# 1. QKV 投影: [B,T,C] × [C,C] → [B,T,C]
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
# 2. reshape + transpose → [B, H, T, D]
q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
# 3. RoPE(推理时 offset 使 pos 从 t 开始)
past_len = 0 if past_kv is None else past_kv[0].shape[2]
cos, sin = build_rope_cache(self.head_dim, T, offset=past_len)
q = apply_rope(q, cos, sin)
k = apply_rope(k, cos, sin)
# 4. 拼接 KV Cache(仅推理时)
if past_kv is not None:
past_k, past_v = past_kv
k = torch.cat([past_k, k], dim=2) # [B,H,T+1,D]
v = torch.cat([past_v, v], dim=2)
present_kv = (k, v) if use_cache else None
# 5. Attention 分数: [B,H,T,D] × [B,H,D,S] → [B,H,T,S]
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# 6. Causal Mask(仅训练时需要)
if attention_mask is not None:
scores = scores.masked_fill(~attention_mask, float("-inf"))
# 7. Softmax + 8. 加权求和
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v) # [B,H,T,D]
# 9. merge heads + 10. 输出投影
out = out.transpose(1, 2).contiguous().view(B, T, C)
out = self.o_proj(out) # [B,T,C]
return out, present_kv
训练 vs 推理维度对比:
训练:$T = L$(整段),$q\ [B,h,L,d_h] \times k^\top\ [B,h,d_h,L] \to$ scores [B,h,L,L],需 causal mask
推理:$T = 1$(单 token),$q_t\ [B,h,1,d_h] \times k_{\text{all}}^\top\ [B,h,d_h,T{+}1] \to$ scores [B,h,1,T{+}1],无需 mask
class TransformerBlock(nn.Module):
def __init__(self, dim, n_heads, mlp_hidden_dim):
self.attn_norm = RMSNorm(dim)
self.attn = MultiHeadSelfAttention(dim, n_heads)
self.ffn_norm = RMSNorm(dim)
self.ffn = SwiGLU(dim, mlp_hidden_dim)
def forward(self, x, attention_mask=None, past_kv=None, use_cache=False):
# Pre-Norm 风格:先归一化,再 attention
h = self.attn_norm(x)
attn_out, present_kv = self.attn(h, attention_mask, past_kv, use_cache)
x = x + attn_out # 残差连接
# Pre-Norm:先归一化,再 MLP
h = self.ffn_norm(x)
x = x + self.ffn(h) # 残差连接
return x, present_kv
对应公式链:LayerNorm(x) → Attn → x + attn_out → LayerNorm → MLP → x + mlp_out,与前文训练 Step 2(6) 的 $X_{\text{mid}} = \text{LayerNorm}(X + O_{\text{attn}})$ 完全一致。use_cache=True 时返回 $(K_{\text{all}}, V_{\text{all}})$ 供下一步使用。
class TinyTransformer(nn.Module):
def __init__(self, vocab_size, dim, n_heads, mlp_dim, n_layers):
self.embed = nn.Embedding(vocab_size, dim) # W_e
self.layers = nn.ModuleList([ # N 层
TransformerBlock(dim, n_heads, mlp_dim) for _ in range(n_layers)
])
self.norm = RMSNorm(dim)
self.lm_head = nn.Linear(dim, vocab_size, bias=False) # W_lm
def forward(self, input_ids, attention_mask=None, past_kvs=None, use_cache=False):
x = self.embed(input_ids) # [B,T,V] → [B,T,C] ← Step 1
new_kvs = []
for layer, past_kv in zip(self.layers, past_kvs):
x, present_kv = layer(x, attention_mask, past_kv, use_cache) # Step 2×N
new_kvs.append(present_kv)
x = self.norm(x)
logits = self.lm_head(x) # [B,T,C] → [B,T,V] ← Step 4
return logits, new_kvs if use_cache else None
使用示例对比:
# 训练:整段序列 + causal mask
logits, _ = model(input_ids, attention_mask=causal_mask, use_cache=False)
# logits: [B, L, V] ← 对应训练 Step 4
# 推理:逐 token + KV cache
for step in range(4):
logits, past_kvs = model(step_input, use_cache=True)
# logits: [B, 1, V] ← 对应推理 Step 4
# ============ 1. 初始化 ============
model = TinyTransformer(
vocab_size=32000, dim=4096, n_heads=32,
mlp_dim=11008, n_layers=32,
).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100000)
# ============ 2. 训练循环 ============
for epoch in range(num_epochs):
for batch in dataloader:
input_ids = batch["input_ids"].cuda() # [B, L]
labels = batch["labels"].cuda() # [B, L](= input_ids 左移1位)
# --- Forward ---
causal_mask = torch.tril(torch.ones(L, L, device="cuda")) # [L,L]
logits, _ = model(input_ids, attention_mask=causal_mask, use_cache=False)
# logits: [B, L, V]
# --- Loss: 用 t 位置预测 t+1 位置的 token ---
loss = F.cross_entropy(
logits[:, :-1].reshape(-1, vocab_size), # [B*(L-1), V]
labels[:, 1:].reshape(-1), # [B*(L-1)]
)
# --- Backward ---
optimizer.zero_grad()
loss.backward() # 反向传播,填充 .grad
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# --- Update ---
optimizer.step() # AdamW 更新参数
scheduler.step()
训练全流程拆解:
Embedding → N×(Attn+MLP) → LM Head → logits [B,L,V] → CrossEntropy → backward() → AdamW.step()
关键细节:labels 是 input_ids 左移 1 位(位置 $t$ 预测位置 $t{+}1$),logits[:, :-1] 去掉最后一个位置(没有下一个 token 可预测)。
# ============ 推理模式:逐步自回归生成 ============
model.eval()
past_kvs = None
generated = [bos_token_id] # 以 BOS 开头
for step in range(max_new_tokens):
step_input = torch.tensor([generated[-1]]).cuda() # [1,1] 当前 token
logits, past_kvs = model(
step_input[None, :], # [B=1, T=1]
past_kvs=past_kvs, use_cache=True
)
# logits: [1, 1, V]
next_token = torch.argmax(logits[:, -1], dim=-1).item()
generated.append(next_token)
if next_token == eos_token_id:
break
# past_kvs 逐层膨胀: 每层 (B,H,S,D),S 从 1 增长到 T
# 总显存: 2 × N_layers × B × H × T × D × 2 bytes (fp16)
推理全流程拆解:
Embed 1 token → N×(Attn+KV Cache+MLP) → LM Head → logits [B,1,V] → argmax/sample → 新 token → 循环
每步只算 1 个 token,历史 $K,V$ 从 cache 读取不重算。生成 $T$ 个 token 的总计算量 = $O(T^2 \cdot d \cdot h)$,比无 cache 的 $O(T^3 \cdot d \cdot h)$ 快一个数量级。