12 训练基础设施与分布式系统
系统性覆盖大模型训练 Infra 核心原理:显存分析、并行策略、参数高效微调、Attention优化、精度策略
1. 训练显存分析
1.1 显存组成
训练大模型时,GPU 显存主要被以下四部分占据:
| 类别 | 内容 | 典型比例 |
|---|---|---|
| 模型参数 (Parameters) | 权重矩阵 W | 占比约 1/4(混合精度下) |
| 梯度 (Gradients) | 反向传播的梯度 | 与参数等大 |
| 优化器状态 (Optimizer States) | Adam: m (一阶动量) + v (二阶动量) + master weights | 参数量的 2~3 倍 |
| 激活值 (Activations) | 前向传播中间结果,留给反向用 | 随 batch size 和 seq_len 增长 |
1.2 计算公式
Adam + 混合精度训练的显存估算
设模型参数量为 $\Phi$(单位:参数数):
- FP16 参数:$2\Phi$ 字节
- FP16 梯度:$2\Phi$ 字节
- FP32 master weights:$4\Phi$ 字节
- FP32 一阶动量 m:$4\Phi$ 字节
- FP32 二阶动量 v:$4\Phi$ 字节
总计:$16\Phi$ 字节(不含激活值)
例如 7B 模型:$16 \times 7 \times 10^9 = 112$ GB(仅优化器+参数+梯度)
激活值显存
激活值显存与 batch_size $\times$ seq_len $\times$ hidden_size $\times$ num_layers 成正比。常用优化手段:
- Gradient Checkpointing (激活重计算):用时间换空间,只保留部分 checkpoint 层的激活值,其余反向时重算
- Selective Activation Recomputation:只重算 Attention 部分(占比大但计算相对便宜)
2. 数据并行 (DDP)
2.1 原理
DistributedDataParallel (DDP) 是 PyTorch 最基础的多卡并行方式:
- 每张卡持有完整模型副本
- 数据按 batch 维度切分给不同卡
- 各卡独立前向+反向
- 反向结束时,通过 AllReduce 同步梯度
- 各卡用同步后的梯度更新参数,保持一致
DDP vs DataParallel (DP)
- DP:单进程多线程,存在 GIL 瓶颈,GPU0 负载不均衡(汇聚梯度)
- DDP:多进程,每卡一个进程,通信用 NCCL,无 GIL 问题,负载均衡
工业界几乎不再使用 DP,DDP 是数据并行的标准做法。
2.2 通信机制
DDP 的关键通信操作是 AllReduce:
- Ring AllReduce:将 N 张卡组成环,分 2(N-1) 步完成,通信量与卡数无关,仅与数据量相关
- 通信量:$2 \times \frac{(N-1)}{N} \times \text{data\_size}$,近似 $2 \times \text{data\_size}$
Gradient Bucketing
DDP 不是等所有梯度算完再通信,而是将参数分桶 (bucket),一桶算完就开始 AllReduce,实现计算与通信重叠。
3. DeepSpeed ZeRO
ZeRO (Zero Redundancy Optimizer) 的核心思想:在数据并行的基础上,将模型状态分片 (partition) 到各卡,消除冗余。
3.1 ZeRO Stage 1 / 2 / 3
| Stage | 切分内容 | 显存节省 | 通信量 |
|---|---|---|---|
| Stage 1 | 优化器状态 (OS) | 从 $16\Phi$ 降至 $4\Phi + \frac{12\Phi}{N}$ | 同 DDP |
| Stage 2 | 优化器状态 + 梯度 (OS+G) | 从 $16\Phi$ 降至 $2\Phi + \frac{14\Phi}{N}$ | 同 DDP |
| Stage 3 | 优化器状态 + 梯度 + 参数 (OS+G+P) | 从 $16\Phi$ 降至 $\frac{16\Phi}{N}$ | 约 1.5× DDP |
ZeRO 各 Stage 详解
Stage 1:每张卡只存 $1/N$ 的优化器状态。更新时各卡只更新自己分片对应的参数,再 AllGather 同步完整参数。
Stage 2:在 Stage 1 基础上,梯度也分片。每个参数的梯度只在"负责该参数的卡"上做 Reduce(用 ReduceScatter 代替 AllReduce),减少显存中的梯度副本。
Stage 3:连参数本身也分片。前向和反向需要某层参数时,通过 AllGather 临时收集,用完释放。代价是通信量增加约 50%(多了前向时的 AllGather)。
面试关键点
- Stage 1/2 通信量与 DDP 相同(都是 $2\Phi$ 数据量),但显存显著降低
- Stage 3 能线性扩展到任意大模型,但通信开销增加
- 实际使用中 Stage 2 是性价比最高的选择(大多数场景够用且无额外通信)
3.2 ZeRO-Offload
将部分计算和存储卸载到 CPU/NVMe:
- CPU Offload:优化器状态和梯度放 CPU 内存,GPU 只保留前向/反向必需的参数和激活
- 参数更新在 CPU 上完成(Adam step),更新后传回 GPU
- 适合单卡/少卡训练大模型(如单 A100 训 13B)
3.3 ZeRO-Infinity
ZeRO-Infinity 在 Offload 基础上进一步利用 NVMe SSD:
- 所有模型状态都可以 offload 到 NVMe
- 通过 prefetch 和 overlap 隐藏 I/O 延迟
- 理论上支持训练万亿参数模型(单节点)
Offload 的代价
CPU offload 会引入 PCIe 传输延迟和 CPU 计算瓶颈,吞吐量通常下降 30-50%。适合"能训起来"比"训得快"更重要的场景。
4. FSDP (Fully Sharded Data Parallel)
4.1 原理与对比
FSDP 是 PyTorch 原生实现的类 ZeRO Stage 3 方案:
- 将模型参数、梯度、优化器状态全部分片到各卡
- 前向时 AllGather 收集当前层参数,反向时 ReduceScatter 分发梯度
- 原生集成在 PyTorch 中,无需额外框架
| 维度 | DeepSpeed ZeRO-3 | FSDP |
|---|---|---|
| 生态 | 独立库,配置文件驱动 | PyTorch 原生,API 驱动 |
| 分片粒度 | 参数级 | Module 级 (FlatParameter) |
| 混合精度 | 自带 FP16/BF16 支持 | 与 torch.amp 集成 |
| Offload | CPU + NVMe | CPU(NVMe 有限支持) |
| 调试友好度 | 配置较多 | 更贴近 PyTorch 风格 |
| 适用场景 | 超大规模、极致优化 | 中大规模、快速上手 |
4.2 Sharding 策略
FSDP 提供三种 sharding 策略:
- FULL_SHARD:等价于 ZeRO-3,参数+梯度+优化器全分片
- SHARD_GRAD_OP:等价于 ZeRO-2,仅分片梯度和优化器状态
- NO_SHARD:等价于 DDP,不分片
FSDP Hybrid Sharding
在多机场景下,可以节点内 FULL_SHARD + 节点间 NO_SHARD(或反之),减少跨机通信。这等价于 ZeRO++ 的 hpZ (hierarchical partitioning)。
5. 张量并行 (Tensor Parallelism)
5.1 Megatron-LM 实现
张量并行 (TP) 的核心思想:将单个矩阵运算切分到多张卡上并行计算,是模型内并行 (intra-layer parallelism)。
Megatron-LM 是 NVIDIA 实现 TP 的经典框架,核心设计:
- 将 Transformer 每层的 MLP 和 Attention 的权重矩阵切分
- 利用矩阵乘法的可拆分性,分配到不同 GPU
- 每层只需 2 次 AllReduce(前向 1 次 + 反向 1 次)
5.2 列切分 (Column Parallel) 与行切分 (Row Parallel)
MLP 层的张量并行
标准 MLP:$Y = \text{GeLU}(XA) \cdot B$
第一个线性层(Column Parallel):
将 $A$ 按列切分为 $[A_1, A_2]$,每张卡计算 $Y_i = \text{GeLU}(XA_i)$
无需通信,因为 GeLU 是逐元素操作
第二个线性层(Row Parallel):
将 $B$ 按行切分,每张卡计算部分结果,最后 AllReduce 求和
$$Y = [Y_1, Y_2] \begin{bmatrix} B_1 \\ B_2 \end{bmatrix} = Y_1 B_1 + Y_2 B_2$$
Self-Attention 的张量并行
QKV 投影矩阵按 head 维度切分:
- 卡 $i$ 负责 head $i$ 的 $W_Q^i, W_K^i, W_V^i$
- 各卡独立计算 Attention
- 输出投影 $W_O$ 用 Row Parallel,AllReduce 聚合
总通信:每个 Transformer 层前向 2 次 AllReduce,反向 2 次 AllReduce。
TP 的限制
- 要求卡间高带宽(通常限制在单机 NVLink 内)
- TP degree 通常 ≤ 8(一台机器的 GPU 数)
- 通信频率高:每层都要通信,对延迟敏感
6. 流水线并行 (Pipeline Parallelism)
6.1 GPipe 与 1F1B
流水线并行 (PP) 将模型按层切分到不同卡上:
| 方案 | 调度方式 | 特点 |
|---|---|---|
| Naive PP | 串行执行全部前向再全部反向 | bubble 极大,几乎无用 |
| GPipe | 将 mini-batch 切为 micro-batches,连续注入流水线 | 减少 bubble,但需存所有 micro-batch 的激活 |
| 1F1B (Interleaved) | 一个 micro-batch 前向完立刻反向 | 峰值显存更低,bubble 不变 |
| Interleaved 1F1B | 每卡分配多个不连续的层块 | Bubble 缩小为原来的 $1/v$(v 为 chunks 数) |
6.2 Bubble 率分析
Pipeline Bubble
设 $p$ 为 pipeline stages 数,$m$ 为 micro-batches 数:
$$\text{Bubble ratio} = \frac{p - 1}{m + p - 1}$$
要使 bubble 率 < 5%,需要 $m \geq 20p$(经验值约 $m \geq 4p$ 时 bubble < 20%)
这意味着 PP 需要较大的 global batch size,否则 bubble 开销太大。
PP 的通信特点
PP 只需在相邻 stage 之间传递激活值(点对点通信),通信量远小于 TP。因此 PP 适合跨机场景(带宽要求低),而 TP 适合机内(带宽要求高)。
7. 3D 并行与混合策略
大规模训练(如 GPT-3 175B, LLaMA-70B)通常组合多种并行:
3D 并行 = DP × TP × PP
总 GPU 数 = $N_{DP} \times N_{TP} \times N_{PP}$
典型配置(以 128 GPUs 训 70B 为例):
- TP = 8(机内 8 卡 NVLink)
- PP = 4(4 个 stage 跨 4 台机器)
- DP = 4(4 路数据并行)
- Global Batch Size = DP × micro_batch × grad_accum_steps
| 并行方式 | 切什么 | 通信模式 | 适合场景 |
|---|---|---|---|
| DP / ZeRO | 数据(batch维) | AllReduce / ReduceScatter | 任意规模,基础必选 |
| TP | 权重矩阵(hidden维) | AllReduce(高频) | 机内 NVLink |
| PP | 层(layer维) | P2P Send/Recv(低频) | 跨机 |
| SP (Sequence Parallel) | 序列(seq维) | AllGather / ReduceScatter | 长序列训练 |
Sequence Parallelism (SP)
Megatron-LM v3 引入的 SP:在 TP 的基础上,对 LayerNorm 和 Dropout 等非 TP 区域也沿序列维度切分,进一步节省激活值显存。SP 不增加额外通信,只是将 AllReduce 拆成 AllGather + ReduceScatter。
8. LoRA 系列 (参数高效微调)
8.1 LoRA 原理
核心公式
$$h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x$$
其中:
- $W_0 \in \mathbb{R}^{d \times k}$:冻结的预训练权重
- $B \in \mathbb{R}^{d \times r}$:初始化为零
- $A \in \mathbb{R}^{r \times k}$:随机高斯初始化
- $r \ll \min(d, k)$:低秩维度(通常 4~64)
- $\alpha$:缩放因子
为什么 LoRA 有效?
- Aghajanyan et al. 2020 发现:微调时的权重更新矩阵 $\Delta W$ 具有低内在维度 (low intrinsic dimensionality)
- 全量微调 7B 模型需要 14GB+ 梯度空间,LoRA 只需训 $r \times (d + k)$ 参数
- 训练完毕后 $W_0 + BA$ 可合并,推理无额外延迟
LoRA 关键设计选择
| 选项 | 建议 | 原因 |
|---|---|---|
| 应用层 | Q, K, V, O + MLP (gate, up, down) | 论文验证同时加更多层效果更好 |
| rank r | 8~64 | 太小表达力不够,太大接近全量微调 |
| alpha | 通常 = r 或 2r | 实际学习率 ∝ alpha/r |
| B 初始化 | 零矩阵 | 保证训练初期 ΔW=0,不破坏预训练 |
8.2 QLoRA
QLoRA 的核心贡献:在4-bit 量化的冻结基座上做 LoRA 微调。
QLoRA 三大技术
- 4-bit NormalFloat (NF4):信息论最优的 4-bit 量化格式,假设权重近似正态分布
- Double Quantization:对量化常数(scale factor)再做一次量化,节省约 0.37 bit/param
- Paged Optimizers:利用 NVIDIA 统一内存,在 GPU 显存不足时自动 offload 到 CPU
效果:单张 A100-48GB 即可微调 65B 模型(QLoRA 论文实验),全量微调需要 >780GB。
面试要点:QLoRA 前向过程
$$h = \text{Dequantize}(W_{\text{NF4}}) \cdot x + \frac{\alpha}{r} B A x$$
反向传播时只更新 B 和 A 的梯度,$W_{\text{NF4}}$ 始终冻结且以 4-bit 存储。
8.3 DoRA 与 rsLoRA
DoRA (Weight-Decomposed Low-Rank Adaptation):
- 将权重分解为幅度 (magnitude) 和方向 (direction) 两部分
- LoRA 只更新方向部分,单独学习幅度
- 更接近全量微调的更新模式
rsLoRA (Rank-Stabilized LoRA):
- 将缩放因子从 $\alpha/r$ 改为 $\alpha/\sqrt{r}$
- 使得不同 rank 下的有效学习率更稳定
- 允许使用更大的 rank 而不需要重新调学习率
9. Flash Attention
9.1 Flash Attention v1:IO-Aware 算法
核心问题:标准 Attention 需要将 $N \times N$ 的注意力矩阵写入 HBM,显存 $O(N^2)$,且 HBM 读写是瓶颈。
GPU 内存层次
| 存储 | 容量 | 带宽 |
|---|---|---|
| SRAM (on-chip) | ~20 MB (A100) | ~19 TB/s |
| HBM (显存) | 40/80 GB | ~2 TB/s |
SRAM 比 HBM 快约 10×,但容量小 1000×。Flash Attention 的目标:尽量在 SRAM 中完成 Attention 计算。
Flash Attention v1 的核心技术:
- Tiling(分块计算):将 Q, K, V 切成小块,逐块加载到 SRAM 中计算
- Online Softmax:在分块计算时,维护 running max 和 running sum,无需完整行的 logits
- 计算 $m_{\text{new}} = \max(m_{\text{old}}, \max(\text{current\_block}))$
- 用 correction factor $e^{m_{\text{old}} - m_{\text{new}}}$ 修正之前的累积值
- 不存中间矩阵:$S = QK^T$ 和 $P = \text{softmax}(S)$ 都不写入 HBM
- 反向传播时重算:不保存 $P$,反向时从 Q, K, V 和输出 O 重算(用保存的 $m, \ell$ 辅助)
FA1 复杂度对比
| 标准 Attention | Flash Attention | |
|---|---|---|
| 计算量 (FLOPs) | $O(N^2 d)$ | $O(N^2 d)$(相同) |
| HBM 访问量 | $O(N^2 + Nd)$ | $O(N^2 d / M)$($M$ = SRAM 大小) |
| 额外显存 | $O(N^2)$ | $O(N)$(只存 m, ℓ) |
9.2 Flash Attention v2
在 v1 基础上的关键优化:
- 减少非矩阵乘法 FLOPs:
- v1 在内循环中做 rescaling,v2 把 rescaling 移到外循环末尾
- 减少了约 50% 的非 matmul 操作(non-matmul FLOPs 在 GPU 上效率很低)
- 优化并行度:
- v1 在 batch × heads 维度并行;v2 额外在序列维度并行
- 长序列时 GPU 利用率更高
- 优化 warp 分工:
- v1:4 个 warp 各算部分 K/V,需要 shared memory 同步
- v2:4 个 warp 分 Q 的不同行,各自独立计算,无需同步
FA2 实际加速效果
相比 FA1 加速约 2×,达到 A100 理论峰值 FLOPS 的 50-73%。已成为所有主流 LLM 训练/推理框架的默认 Attention 实现。
9.3 Flash Attention v3
针对 Hopper 架构 (H100) 的进一步优化:
- Warp specialization:将 warp 分为 producer(负责数据搬运)和 consumer(负责计算),异步执行
- 利用 TMA (Tensor Memory Accelerator):硬件加速的异步数据加载
- FP8 支持:利用 H100 的 FP8 Tensor Core,在精度可接受时进一步加速
- Pingpong scheduling:减少 shared memory 上的 bank conflict
10. 混合精度训练
10.1 FP16 与 BF16
| 格式 | 符号位 | 指数位 | 尾数位 | 范围 | 精度 |
|---|---|---|---|---|---|
| FP32 | 1 | 8 | 23 | 大 | 高 |
| FP16 | 1 | 5 | 10 | 小(±65504) | 中 |
| BF16 | 1 | 8 | 7 | 同 FP32 | 低 |
| FP8 (E4M3) | 1 | 4 | 3 | 较小 | 很低 |
| FP8 (E5M2) | 1 | 5 | 2 | 较大 | 极低 |
为什么 BF16 成为 LLM 训练的主流?
- BF16 指数位与 FP32 相同(8位),不会溢出
- FP16 指数位只有 5 位,梯度值容易 overflow/underflow
- BF16 无需 loss scaling,训练更稳定
- 代价:尾数精度低(7 bit vs FP16 的 10 bit),累积误差略大
- 从 A100 起 NVIDIA GPU 均原生支持 BF16 Tensor Core
10.2 Loss Scaling
FP16 训练时梯度容易 underflow(太小的值变成 0),Loss Scaling 的做法:
- 前向计算 loss 后,乘以一个大的 scale factor(如 1024 或动态值)
- 反向传播时梯度也被 scale 放大
- 更新前将梯度除以 scale factor 恢复
- Dynamic Loss Scaling:自动调整 scale,遇到 NaN/Inf 时减小,一段时间没问题就增大
11. 梯度累积与梯度裁剪
梯度累积 (Gradient Accumulation)
等效 batch_size = micro_batch_size × grad_accum_steps × num_gpus
每个 micro-batch 只做前向+反向(累加梯度),不更新参数。累积 K 步后才做一次 optimizer.step()。
作用:在显存不够用大 batch 时,模拟大 batch 训练效果。
梯度裁剪 (Gradient Clipping)
$$g \leftarrow g \cdot \frac{\text{max\_norm}}{\max(\|g\|, \text{max\_norm})}$$
当梯度范数超过阈值时,等比缩小所有梯度。
作用:防止梯度爆炸导致训练不稳定(loss spike)。大模型训练中通常设置 max_norm=1.0。
12. 面试真题集
Q1:DDP 和 DeepSpeed ZeRO 有什么区别?什么时候用哪个? 高频
核心区别:
- DDP:每卡存完整模型状态(参数+梯度+优化器),通过 AllReduce 同步梯度
- ZeRO:将模型状态分片,消除冗余存储,按 Stage 1/2/3 逐步切分
选择策略:
- 模型能放进单卡 → DDP(最简单高效)
- 模型放不进单卡但放得进单卡除去激活 → ZeRO Stage 2
- 模型完全放不进单卡 → ZeRO Stage 3 或 TP+PP
通信开销:Stage 1/2 和 DDP 相同,Stage 3 多约 50%。
Q2:ZeRO Stage 3 和 FSDP 的区别?选哪个? 经典
相似:原理基本一致,都是参数+梯度+优化器全分片。
差异:
- FSDP 是 PyTorch 原生,API 更 Pythonic,和 torch.compile / torch.amp 集成好
- DeepSpeed 功能更丰富(NVMe offload、ZeRO-Infinity、更多调优选项)
- FSDP 以 FlatParameter 为单位,DeepSpeed 以参数为单位
选择:纯 PyTorch 生态选 FSDP;需要极致优化或 offload 选 DeepSpeed。
Q3:TP 和 PP 的区别是什么?为什么 TP 要放机内? 高频
TP (Tensor Parallelism):
- 切的是权重矩阵(层内切分)
- 每层都有 AllReduce 通信,频率极高
- 对延迟敏感,必须高带宽互联 → 放机内 NVLink(~600 GB/s)
PP (Pipeline Parallelism):
- 切的是层(层间切分)
- 只在 stage 边界做点对点通信,频率低
- 对带宽要求低 → 可以跨机(InfiniBand ~100-400 Gb/s)
关键总结:TP 通信频率高+数据量大→要高带宽低延迟;PP 通信频率低+只传激活→可容忍跨机延迟。
Q4:Flash Attention 为什么能加速?它改变了计算复杂度吗? 高频
不改变计算复杂度(仍然是 $O(N^2d)$ FLOPs),改变的是 IO 复杂度。
核心思路:
- 标准实现:Q, K → 算 $N\times N$ 矩阵 → 写入 HBM → 读回做 softmax → 写入 HBM → 读回乘 V
- Flash Attention:分块加载到 SRAM → 在 SRAM 中完成 QK+softmax+乘V → 只写最终输出到 HBM
加速来源:
- 减少 HBM 读写次数(IO-aware)
- 避免存储 $N\times N$ 中间矩阵
- 用 online softmax 实现分块处理
一句话:Flash Attention 是 IO 优化而非算法优化,它让 Attention 从 memory-bound 变得更接近 compute-bound。
Q5:LoRA 为什么有效?rank 怎么选?为什么 B 初始化为零? 高频
为什么有效:
- 微调时的权重变化 $\Delta W$ 具有低内在维度
- 用低秩矩阵 $BA$ 近似 $\Delta W$ 就够了
- 类比:预训练学了"通用能力",微调只需在少量"方向"上调整
rank 选择:
- 简单任务(单领域 QA):r=8 通常够
- 复杂任务(多任务/代码生成):r=32~64
- 过大的 rank 收益递减,且接近全量微调
B 初始化为零:
- 保证训练开始时 $\Delta W = BA = 0$
- 不破坏预训练学到的表示
- 从预训练权重出发渐进微调
Q6:QLoRA 相比 LoRA 做了什么?4-bit 量化精度够吗? 经典
QLoRA 三大技术贡献:
- NF4 量化:针对正态分布权重设计的 4-bit 格式,信息论最优
- Double Quantization:对量化参数(scale)再量化,额外节省 0.37 bit/param
- Paged Optimizers:利用 CUDA 统一内存自动 CPU offload
精度够吗:
- QLoRA 论文证明:16-bit LoRA ≈ QLoRA 效果,差距 < 1%
- 因为梯度信号通过 LoRA 的 B、A 矩阵传播(这两个是全精度的),量化基座只参与前向
- 但要注意:推理时如果继续用 4-bit,可能在某些精细任务上有损
Q7:为什么 BF16 比 FP16 更适合大模型训练? 经典
根本原因:动态范围。
- FP16 只有 5 位指数 → 范围 ≈ $6\times10^{-8}$ 到 $65504$
- BF16 有 8 位指数 → 范围与 FP32 相同
- 大模型训练中梯度值域很广,FP16 容易 overflow(梯度爆炸)或 underflow(梯度消失)
实际影响:
- FP16 需要 loss scaling 来避免 underflow,增加复杂性
- BF16 无需 loss scaling,训练代码更简单
- BF16 精度略低(7 bit 尾数 vs 10 bit),但对 LLM 训练影响很小
Q8:梯度累积和增大 batch size 有什么区别?完全等价吗?
数学上近似等价,但有细微差别:
- 等价部分:梯度的期望相同,最终参数更新方向一致
- 不等价部分:
- BatchNorm 统计量不同(每个 micro-batch 独立算 BN,但 LLM 用 LayerNorm 无此问题)
- 学习率调度:累积 K 步才更新一次,等效 step 数变为 1/K
- 梯度裁剪时机:如果对每个 micro-batch 单独裁剪 vs 累积后裁剪,行为不同
最佳实践:累积后再做梯度裁剪(即对完整梯度做 clip),lr scheduler 按实际更新步数(而非前向步数)计算。
Q9:训练 70B 模型需要多少卡?怎么配置并行策略? 系统设计
显存估算(70B, BF16 + Adam):
- 参数:$70\text{B} \times 2 = 140$ GB
- 梯度:$140$ GB
- 优化器状态(FP32 master + m + v):$70\text{B} \times 12 = 840$ GB
- 总计:约 1120 GB + 激活值
典型配置(128× A100-80GB):
- TP = 8(单机 8 卡,NVLink 互联)
- PP = 4(4 个 pipeline stage,跨 4 机)
- DP = 4(4 路数据并行 + ZeRO Stage 1)
- Sequence Length = 4096, Micro-batch = 1~2, Grad Accum = 8~16
优化清单:
- Flash Attention 2(节省激活值显存 + 加速)
- Gradient Checkpointing(重算部分激活)
- BF16 混合精度
- Overlapped communication(通信计算重叠)
Q10:Gradient Checkpointing 的原理?空间和时间的 trade-off?
原理:
- 正常训练:前向保存所有层的激活值 → 反向时直接用
- Checkpointing:只保存某些层(checkpoint 层)的激活值 → 反向到非 checkpoint 层时从最近的 checkpoint 重算
Trade-off:
- 空间:激活值显存从 $O(L)$ 降到 $O(\sqrt{L})$($L$ 为层数,均匀选 checkpoint)
- 时间:额外增加约 33% 的前向计算(重算一次)
实践:几乎所有大模型训练都开启 checkpointing。Selective checkpointing(只重算 Attention 不重算 MLP)可以在空间和时间之间找到更好的平衡。
Q11:Ring AllReduce 的原理?为什么通信量与卡数无关?
Ring AllReduce 分两个阶段:
- Reduce-Scatter:N-1 轮,每轮每卡发 $D/N$ 数据给下一卡并做聚合,最终每卡持有 $1/N$ 完整结果
- All-Gather:N-1 轮,每轮每卡发 $D/N$ 数据给下一卡,最终每卡持有完整结果
总通信量:每卡收发 $2 \times \frac{N-1}{N} \times D \approx 2D$ 数据。
为什么与卡数无关:虽然轮数增加了(N-1 轮),但每轮传输量减小了($D/N$),总量恒定。瓶颈在于带宽而非延迟(大数据量时延迟占比很小)。
Q12:Megatron-LM 的 Sequence Parallelism 是什么?和 TP 是什么关系?
问题背景:
- TP 切分了 Attention 和 MLP 的权重矩阵
- 但 LayerNorm、Dropout 等操作不在 TP 范围内,每卡仍需存完整激活
SP 的做法:
- 在非 TP 区域(LN、Dropout),将激活值沿序列维度切分
- TP 区域前做 AllGather 收集完整输入,TP 区域后做 ReduceScatter 分发
- 本质上把 TP 区域的 AllReduce 拆成 AllGather + ReduceScatter,总通信量不变
收益:非 TP 区域的激活值显存减少为 $1/N_{TP}$,无额外通信开销。
Q13:训练中遇到 loss spike 怎么处理? 实战
常见原因:
- 数据问题:某 batch 数据异常(超长文本、乱码、重复)
- 学习率问题:warmup 不够 / lr 过大
- 梯度爆炸:某层梯度异常放大
- 数值问题:FP16 overflow、BF16 精度丢失积累
排查步骤:
- 检查 gradient norm 是否暴增
- 检查当前 batch 数据是否异常
- 检查 loss scale(FP16 场景)
- 观察是否可自恢复(偶发 vs 持续)
处理策略:
- 偶发 spike 且自恢复 → 不处理或跳过该 batch
- 持续不恢复 → 回滚到上一个 checkpoint,调低 lr 或增大 grad clip
- 预防:更严格的数据清洗、gradient clipping = 1.0、z-loss 正则
Q14:LoRA 和全量微调的效果差距大吗?什么时候必须全量微调?
效果对比:
- 大多数下游任务:LoRA 效果接近全量微调(差距 < 1-2%)
- 特别是当基座模型足够强且任务简单时
必须全量微调的场景:
- 持续预训练(Continual Pre-training):需要大量新知识注入
- 跨语言/跨领域大幅迁移
- 对齐训练的 SFT 阶段(工业界通常全量)
- LoRA 无法达到目标效果且已尝试增大 rank
工业界实践:SFT 阶段倾向全量微调;后续的偏好学习 (DPO/RLHF) 阶段常用 LoRA 以节省资源。
Q15:如何计算训练的 token 吞吐量?MFU 是什么?
Token 吞吐量:
$$\text{Throughput} = \frac{\text{global\_batch\_size} \times \text{seq\_len}}{\text{step\_time}}$$
MFU (Model FLOPs Utilization):
$$\text{MFU} = \frac{\text{实际 FLOPs/s}}{\text{硬件理论峰值 FLOPs/s}}$$
其中每个 token 的 FLOPs ≈ $6 \times \text{params}$(前向 2×params, 反向 4×params)
工业标准:
- MFU 40-55%:正常水平
- MFU > 55%:优秀
- MFU < 30%:存在严重瓶颈(通信/IO/bubble)
典型数据(LLaMA-65B on 2048× A100):MFU ~43%,约 380 tokens/s/GPU