04: 训练与微调
核心主题:Pretrain、SFT、LoRA/QLoRA、RLHF/DPO/GRPO、分布式训练
来源:Happy-LLM 第四/六章 + Base-LLM 第十一/十二章
1. 三阶段训练流程
| 阶段 | 目标 | 数据 | 方法 |
|---|---|---|---|
| Pretrain | 学习世界知识 | 万亿级无标注文本 | CLM 自回归 |
| SFT | 学习指令遵循 | 10万级指令-回复对 | 监督微调 |
| RLHF/DPO | 对齐人类偏好 | 偏好对 (chosen/rejected) | 强化学习/直接偏好优化 |
直觉理解
- Pretrain:学会"说话"(语言能力 + 世界知识)
- SFT:学会"听话"(按指令格式输出)
- RLHF:学会"说好话"(有用、真实、无害)
2. 预训练 (Pretrain)
2.1 数据与目标
数据处理:Tokenize → 拼接所有文本 → 切分为固定长度块(如 2048/4096 tokens)
数据规模:
- GPT-3: 300B tokens, LLaMA-2: 2T tokens, LLaMA-3: 15T tokens
- Chinchilla 最优:$D = 20N$
2.2 关键训练技巧
| 技巧 | 作用 |
|---|---|
| Cosine LR Schedule + Warmup | 稳定训练初期,后期平滑衰减 |
| Gradient Accumulation | 模拟大 batch(accumulate K steps = batch × K) |
| Mixed Precision (BF16) | 节省显存,加速计算(前向 BF16,梯度 FP32) |
| Gradient Checkpointing | 用计算换显存(~30% 慢,~60% 省显存) |
| Flash Attention | 减少 HBM 读写,IO-aware 分块计算 |
3. SFT (Supervised Fine-Tuning)
3.1 数据格式(Qwen/ChatML 风格)
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
{用户问题}<|im_end|>
<|im_start|>assistant
{模型回答}<|im_end|>
3.2 Loss Mask
核心区别:SFT vs Pretrain
SFT 只在 assistant 回复部分计算 loss,system/user 部分用 ignore_index=-100 屏蔽。
多轮对话最佳实践:同时预测所有 assistant turns(而非只预测最后一轮)。
# Loss mask: only compute on assistant tokens
labels = input_ids.clone()
# Mask everything except assistant responses
labels[~assistant_mask] = -100 # IGNORE_TOKEN_ID
loss = F.cross_entropy(logits, labels, ignore_index=-100)
4. PEFT 参数高效微调总览
4.1 方法对比
| 方法 | 机制 | 额外推理延迟 | 适用规模 |
|---|---|---|---|
| Adapter | 串行插入瓶颈层 (down→up+residual) | 有(串行) | 全规模 |
| Prefix Tuning | 每层加可学习 prefix K/V | 有 | 全规模 |
| Prompt Tuning | 仅输入层加 soft tokens | 极小 | 10B+ 才有效 |
| P-Tuning v2 | 每层加 deep prompts | 有 | 全规模 |
| LoRA | 并行低秩矩阵分解 | 无(可合并) | 全规模 |
5. LoRA 详解
5.1 原理与公式
核心假设
微调时的权重更新 $\Delta W$ 具有低内在秩(intrinsic rank)。
初始化:$A$ ~ Gaussian, $B$ = 0 → 初始 $\Delta W = 0$(从预训练起点开始)
关键超参:
| 超参 | 含义 | 典型值 |
|---|---|---|
r | 秩(可训练参数量 $\propto r$) | 4, 8, 16 |
lora_alpha | 缩放因子(实际缩放 = alpha/r) | 16, 32 |
target_modules | 应用到哪些层 | ["q_proj", "v_proj"] 或全部 |
lora_dropout | 正则化 | 0.05, 0.1 |
可训练参数量:$\Theta = 2 \times L_{target} \times d \times r$(通常 <1% 总参数)
LoRA 核心优势
- 零推理延迟:部署时合并 $W' = W_0 + \frac{\alpha}{r} BA$
- 高效:冻结参数无需梯度/优化器状态
- 模块化:不同任务换不同 LoRA adapter
- 可组合:与量化等其他技术正交
5.2 QLoRA
核心创新:Base model 量化为 4-bit NF4,LoRA adapter 保持 BF16。
- NF4(Normal Float 4):信息论最优的4位量化(针对正态分布权重)
- Double Quantization:量化参数本身也量化,进一步压缩
- Paged Optimizers:CPU offload 优化器状态
- 效果:65B 模型可在单张 48GB GPU 上微调
5.3 AdaLoRA
动态分配 LoRA rank:重要层高 rank,不重要层低 rank。
6. 偏好对齐
6.1 RLHF (PPO)
四个模型:
- Actor(SFT初始化,更新):生成回复的策略模型
- Ref Model(SFT初始化,冻结):KL 约束的参考
- Reward Model(冻结):评分模型
- Critic(RM初始化,更新):估计 value function
Reward Model 训练(Bradley-Terry):
PPO 目标:
$r_t = \pi_\theta(a_t|s_t) / \pi_{old}(a_t|s_t)$,KL 惩罚防止过度偏离:
$$\text{obj} = \mathbb{E}\left[r_\theta(x,y) - \beta \text{KL}(\pi_\theta || \pi_{ref})\right]$$6.2 DPO (Direct Preference Optimization)
DPO 核心思想
跳过显式 Reward Model,直接从偏好数据优化策略。将 RL 问题转化为分类问题。
隐式奖励:$\hat{r}(x,y) = \beta\log\frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}$
| 对比 | RLHF (PPO) | DPO |
|---|---|---|
| 模型数量 | 4 | 2(Actor + Ref) |
| Reward Model | 需要显式训练 | 隐式 |
| 在线采样 | 需要 | 不需要(离线数据) |
| 训练稳定性 | 较难调 | 简单(类似 SFT) |
| 效果上限 | 更高(在线探索) | 受限于离线数据分布 |
6.3 GRPO (Group Relative Policy Optimization)
DeepSeek 提出:去掉 Critic,用组内相对排名作为 advantage。
对每个 prompt 采样 $G$ 个回复,用 verifier(如规则/代码执行)评分,组内标准化作为 advantage。
优势:省 Critic 模型(节省显存),特别适合有 verifier 的数学/代码场景。
7. DeepSpeed ZeRO 分布式训练
| ZeRO Stage | 分片内容 | 单卡显存 | 通信量 |
|---|---|---|---|
| ZeRO-1 | Optimizer States | $4\Phi + 12\Phi/N$ | 低 |
| ZeRO-2 | + Gradients | $2\Phi + (2+12)\Phi/N$ | 中 |
| ZeRO-3 | + Parameters | $(2+2+12)\Phi/N$ | 高 |
$\Phi$ = 模型参数量,$N$ = GPU 数量
实践建议
- SFT:ZeRO-2(显存/通信平衡最优)
- 超大模型预训练:ZeRO-3 + Tensor Parallelism
- LoRA 微调:ZeRO-2 甚至单卡即可
面试要点总结
高频面试题
- LoRA 为什么有效? 微调更新矩阵具有低内在秩;且只需学习"预训练未强调但下游关键"的特征
- LoRA rank 怎么选? 简单任务 r=4-8,复杂任务 r=16-64;多层 low-r > 单层 high-r
- LoRA 不适合什么场景? 知识注入(CPT/Pretrain);需要全面改变模型行为的场景
- RLHF 为何需要4个模型? Ref 防遗忘 pretrain 能力,Reward 提供信号,Critic 估计 baseline 减方差
- DPO vs RLHF? DPO 更简单稳定,但缺少在线探索;RLHF 效果上限更高但难训
- GRPO 的创新? 去 Critic(省显存),用组内相对排名替代 value baseline
- Reward Hacking? 模型学会钻 RM 漏洞(如冗长但空洞的回答得高分),需 KL 约束
- SFT 和 Pretrain 的 Loss 区别? SFT 仅在 assistant tokens 上计算 loss (loss mask)
- Gradient Checkpointing 原理? 不存前向中间结果,反向时重算,省 ~60% 显存代价 ~30% 时间