03: 大模型架构详解
核心主题:LLaMA2 实现、GQA、SwiGLU、MoE、解码策略、KV Cache
来源:Happy-LLM 第五章 + Base-LLM 第六章
1. 现代 LLM 架构总览
以 LLaMA2 为代表的现代 LLM 架构公式:
LLM = Decoder-Only + Pre-Norm(RMSNorm) + RoPE + GQA + SwiGLU
数据流:
Token Embedding
|
v
N x [ RMSNorm -> GQA(+RoPE) -> Residual -> RMSNorm -> SwiGLU -> Residual ]
|
v
Final RMSNorm -> Linear Head -> Logits
2. RMSNorm
vs LayerNorm:省去减均值步骤,计算更快,效果持平。
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x):
norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return norm * self.weight
3. GQA (Grouped-Query Attention)
核心思想
Q 保持原始头数,K/V 使用更少的头数(分组共享)。减少 KV-Cache 显存占用而几乎不损失质量。
| 注意力类型 | Q Heads | KV Heads | KV Cache 大小 |
|---|---|---|---|
| MHA (Multi-Head) | $h$ | $h$ | $2 \times n \times h \times d_k$ |
| GQA (Grouped-Query) | $h$ | $g$ ($g < h$) | $2 \times n \times g \times d_k$ |
| MQA (Multi-Query) | $h$ | $1$ | $2 \times n \times d_k$ |
LLaMA2-70B:64 Q heads, 8 KV heads → KV-cache 减少 8 倍。
# GQA: expand KV to match Q head count
def repeat_kv(x, n_rep):
"""(B, n_kv_heads, T, d) -> (B, n_heads, T, d)"""
if n_rep == 1:
return x
return x[:, :, None, :, :].expand(B, n_kv_heads, n_rep, T, d) \
.reshape(B, n_kv_heads * n_rep, T, d)
4. SwiGLU FFN
SwiGLU 公式展开(逐步维度推导)
设输入 $x \in \mathbb{R}^{B \times T \times d}$,隐层维度 $d_{ff} = \frac{8d}{3}$:
Step 1 — 门控分支(gate)和上投影分支(up):
说明:$W_1, W_2, W_3$ 均为可学习权重矩阵,无偏置(LLaMA 所有线性层都不加 bias,因为后续有 LayerNorm)。$W_1$ 产生门控信号,$W_3$ 产生待过滤信息,$W_2$ 将门控后的结果投影回 $d$ 维。
Step 2 — 门控激活(Swish / SiLU):
说明:Swish(Google, 2017)和 SiLU(Sigmoid-weighted Linear Unit)是同一个函数:Swish(x) = x · σ(x) = x / (1 + e^(-x))。PyTorch 中叫 F.silu(),LLaMA 论文中叫 Swish。它是 ReLU 的平滑版本:正半轴近似线性,负半轴有微小激活而非归零。
Step 3 — 门控 × 上投影(逐元素乘法):
说明:$\otimes$ 表示逐元素乘法(element-wise multiply,等价于 `*`)。SwiGLU 名称即来自 Swish + GLU(Gated Linear Unit)的组合:用 Swish 激活值作为"门",控制上投影 $u$ 有多少信息通过。
Step 4 — 下投影(down-projection)回 $d$ 维:
说明:将门控后的 $d_{ff}$ 维特征投影回输入维度 $d$,保证 FFN 输出 shape 与输入一致(残差连接需要)。输出 $y$ 经过残差加回:$x_{\text{out}} = x + y$。
vs 标准 FFN:
- 标准:2 个线性层 ($W_1, W_2$),隐层 $4d$
- SwiGLU:3 个线性层 ($W_1, W_2, W_3$),隐层 $\frac{8d}{3}$(对齐到硬件友好的倍数)
- 参数量相当,但性能更好(门控机制控制信息流)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, multiple_of=256):
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False) # gate
self.w2 = nn.Linear(hidden_dim, dim, bias=False) # down
self.w3 = nn.Linear(dim, hidden_dim, bias=False) # up
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
5. LLaMA2 完整配置
| 参数 | 7B | 13B | 70B |
|---|---|---|---|
| Layers | 32 | 40 | 80 |
| Hidden Size | 4096 | 5120 | 8192 |
| Q Heads | 32 | 40 | 64 |
| KV Heads | 32 (MHA) | 40 (MHA) | 8 (GQA) |
| Vocab Size | 32,000 | ||
| Context | 4,096 | ||
| 训练数据 | 2T tokens | ||
关键实现细节:
- Weight Tying:Embedding 和 Output 共享权重,减少参数
- 初始化:Normal(0, 0.02);残差投影缩放 $1/\sqrt{2N}$
- Flash Attention:IO-aware 算法,减少 HBM 访问
6. MoE (Mixture of Experts)
6.1 核心原理
稀疏激活:每个 token 只激活部分专家,实现"参数量大但计算量小"。
MoE 替换 Transformer Block 中的 FFN 层(通常隔层替换)。
6.2 路由机制与负载均衡
Router:$p(x) = \text{softmax}(xW_g)$,选择 Top-K 个专家。
负载均衡损失(防止专家坍塌):
$f_i$ = 实际分配比例, $P_i$ = 平均路由概率
6.3 代表模型
| 模型 | 总参数 | 活跃参数 | 专家数 | Top-K |
|---|---|---|---|---|
| Switch Transformer | 1.6T | - | 128 | 1 |
| Mistral 8x7B | 47B | 13B | 8 | 2 |
| DeepSeek-V2 | 236B | 21B | 160 | 6 |
| Qwen2-MoE | 57B | 14B | 64 | 8 |
DeepSeekMoE 创新:共享专家 + 路由专家
$$y = \underbrace{\sum_{i \in S} E_i(x)}_{\text{Shared (always active)}} + \underbrace{\sum_{j \in \text{TopK}(R)} p_j(x) \cdot E_j(x)}_{\text{Routed (sparse)}}$$共享专家处理通用知识,路由专家处理特定能力。
7. 解码策略
7.1 四种基本策略
| 策略 | 参数设置 | 特点 |
|---|---|---|
| Greedy | do_sample=False, num_beams=1 | 确定性,always argmax |
| Sampling | do_sample=True | 按概率采样,有随机性 |
| Beam Search | num_beams>1 | 维护多条候选,选全局最优 |
| Beam+Sample | num_beams>1, do_sample=True | Beam 框架 + 采样随机性 |
7.2 采样参数
Temperature:
- $T < 1$:分布更尖锐(更确定性)
- $T > 1$:分布更平坦(更随机)
- $T \to 0$:退化为 Greedy
Top-k:只保留概率最高的 K 个 token,其余设为 $-\infty$
Top-p (Nucleus):保留累积概率 $\geq p$ 的最小 token 集合(动态候选数)
8. KV Cache
KV Cache 原理
自回归生成时,已生成 token 的 K/V 不变。缓存它们避免重复计算。
- 无缓存:生成第 $t$ 个 token 时计算所有 $t$ 个位置的 K/V → $O(t^2)$ 总计算
- 有缓存:只计算新 token 的 K/V,拼接到缓存 → $O(t)$ 总计算
显存占用:$2 \times L \times n_{kv} \times d_k \times \text{seq\_len} \times \text{batch} \times \text{dtype\_bytes}$
面试要点总结
高频面试题
- GQA vs MHA vs MQA? MQA: 1个KV头,极端压缩但质量下降;GQA: 折中方案,几个Q头共享一个KV头
- RoPE 的优势? 相对位置编码、无额外参数、支持长度外推(NTK-aware扩展)
- SwiGLU 为什么比 ReLU 好? 门控机制允许模型学习"什么信息该通过",更平滑的梯度
- MoE 的 VRAM 问题? 虽然只激活部分专家,但所有专家参数都需加载到显存
- KV Cache 大小怎么算? $2 \times L \times n_{kv\_heads} \times d_{head} \times \text{seq\_len}$ (per sample)
- Temperature 和 Top-p 的区别? Temperature 控制分布锐度,Top-p 控制候选池大小,两者互补
- Weight Tying 的好处? 减少 $V \times d$ 参数量,且语义空间一致性(输入输出共享)
- MoE 中专家是否按领域分工? 实验显示更多按语法/token结构分,非语义领域