NLP & LLM Course Study Notes

Lecture 08: RLHF — SFT, Reward Model, PPO

Core topics: the InstructGPT three-stage pipeline, supervised fine-tuning, reward modeling, and alignment with PPO

1. The InstructGPT Three-Stage Pipeline (core)

┌───────────────────────────────────────────────────────────────────────────┐
│                       InstructGPT Training Pipeline                       │
├───────────────────┬────────────────────────┬──────────────────────────────┤
│  Stage 1: SFT     │  Stage 2: Reward Model │  Stage 3: PPO (RLHF)         │
│                   │                        │                              │
│  Pretrained LLM   │  SFT model (as init)   │  RM (frozen) + SFT (ref)     │
│       │           │       │                │  │             │             │
│       ▼           │       ▼                │  ▼             ▼             │
│  Human demos      │  Human preferences     │  RL optimization             │
│  (prompt, resp)   │  (x, y_w ≻ y_l)        │  max r(x,y) - β·KL           │
│       │           │       │                │       │                      │
│       ▼           │       ▼                │       ▼                      │
│  SFT model        │  Reward model r_θ      │  PPO policy π_φ              │
└───────────────────┴────────────────────────┴──────────────────────────────┘

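| | Stage 1: SFT | Stage 2: RM | Stage 3: PPO |
|---|---|---|---|
| Data | ~13K demonstration prompts | ~33K prompts, each with 4-9 ranked responses | ~31K prompts |
| Labeling cost | High (labelers write full responses) | Medium (labelers rank model responses) | None (responses are sampled by the policy and scored by the RM) |

Core idea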

2. Stage 1: Supervised Fine-Tuning (SFT)

2.1 SFT Loss: computed on response tokens only

SFT loss: $$\mathcal{L}_{\text{SFT}} = -\sum_{t \in \text{response}} \log p_\theta(x_t \mid x_{<t})$$

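Key detail: prompt tokens get label -100 (the default ignore_index of PyTorch's cross-entropy loss), so only the response tokens, including the final EOS, contribute to the loss.

$$\text{Full sequence: } \underbrace{[\text{BOS}] \text{ instruction tokens}}_{\text{label} = -100} \underbrace{\text{response tokens } [\text{EOS}]}_{\text{loss computed}}$$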
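The masking rule can be checked in a few lines of plain Python. This is a minimal sketch, not the course's code: it assumes the per-position log-probabilities of the target tokens are already available, and masked_nll and IGNORE_INDEX are hypothetical names (-100 is chosen to match PyTorch's default ignore_index). Like Hugging Face causal-LM losses, it averages over response tokens, whereas the formula above sums; for a single sequence the two differ only by a constant factor.

```python
import math

IGNORE_INDEX = -100  # same default as PyTorch's CrossEntropyLoss(ignore_index=-100)

def masked_nll(token_logprobs, labels):
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX.

    token_logprobs[i] is assumed to be log p(target_i | prefix), already computed.
    """
    kept = [lp for lp, lab in zip(token_logprobs, labels) if lab != IGNORE_INDEX]
    if not kept:
        raise ValueError("no response tokens to train on")
    return -sum(kept) / len(kept)

# Toy example: 3 prompt positions (masked) followed by 2 response positions.
logprobs = [math.log(0.1), math.log(0.2), math.log(0.3), math.log(0.5), math.log(0.25)]
labels = [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 42, 7]
loss = masked_nll(logprobs, labels)  # only log(0.5) and log(0.25) are used
```

Changing the log-probabilities at the three masked positions leaves the loss unchanged, which is exactly what the -100 labels are for.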

2.2 Prompt Template (Alpaca format)

Below is an instruction that describes a task, paired with an input
that provides further context. Write a response that appropriately
completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
{output}
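A small helper that fills in the template above, as a sketch: build_example is a hypothetical name, the preamble is treated as a single line (assuming the wrapping above is only for display), and the default eos string "</s>" is the Llama-2 EOS token; in practice pass tokenizer.eos_token. Only the returned target would receive loss; the prompt is masked with -100.

```python
# Alpaca prompt template with an input field; the preamble is assumed to be one line.
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. Write a response that appropriately "
    "completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:\n"
)

def build_example(instruction, input_text, output, eos="</s>"):
    """Return (prompt, target); the loss is computed on target only."""
    prompt = ALPACA_TEMPLATE.format(instruction=instruction, input=input_text)
    target = output + eos  # EOS teaches the model where to stop
    return prompt, target

prompt, target = build_example("Translate to French.", "Hello", "Bonjour")
```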

Template design principles

2.3 Common SFT Datasets

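| Dataset | Size | Source | Characteristics |
|---|---|---|---|
| Alpaca | 52K | Generated by GPT-3.5 | Self-Instruct, low cost |
| ShareGPT | 90K | Collected user conversations | Multi-turn, real user distribution |
| Dolly | 15K | Human-written | Written by Databricks employees |
| OASST1 | 160K | Crowdsourced conversation trees | Multi-turn + preference annotations |
| FLAN Collection | 15M | Mixture of NLP tasks | 1800+ tasks, large scale |
| InstructGPT demos | 13K | Human-written | High quality, internal to OpenAI |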

2.4 Implementation (SFTDataset)

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class SFTDataset(Dataset):
    """SFT dataset: the loss is computed on the response only"""
    
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data  # List of {"instruction": ..., "output": ...}
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Build the prompt (a shortened Alpaca-style template without the preamble)
        prompt = f"### Instruction:\n{item['instruction']}\n\n### Response:\n"
        
        # Tokenize prompt and response separately so we know where the response starts.
        # BOS is prepended by hand; EOS is appended as an id so the model learns to stop.
        bos = [self.tokenizer.bos_token_id] if self.tokenizer.bos_token_id is not None else []
        prompt_ids = bos + self.tokenizer.encode(prompt, add_special_tokens=False)
        response_ids = self.tokenizer.encode(item['output'], add_special_tokens=False)
        response_ids = response_ids + [self.tokenizer.eos_token_id]
        
        # Concatenate and truncate
        input_ids = (prompt_ids + response_ids)[:self.max_length]
        
        # Labels: -100 on the prompt (ignored by the loss), token ids on the response.
        # Labels stay aligned with input_ids; the model shifts them by one internally.
        labels = ([-100] * len(prompt_ids) + response_ids)[:self.max_length]
        
        # Right-pad to max_length; padded positions are masked out and ignored by the loss
        pad_len = self.max_length - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        labels = labels + [-100] * pad_len
        
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


# Training loop
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token  # Llama-2 ships without a pad token

# Hypothetical placeholder data: replace with your own list of {"instruction", "output"} dicts
train_data = [{"instruction": "Name three primary colors.", "output": "Red, blue, and yellow."}]
sft_dataset = SFTDataset(train_data, tokenizer, max_length=512)

training_args = TrainingArguments(
    output_dir="./sft_output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    bf16=True,
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=sft_dataset,
)
trainer.train()
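The label construction in SFTDataset.__getitem__ is easiest to check on toy token ids. The following pure-Python sketch mirrors its truncate, mask, and pad steps; build_sft_example is a hypothetical name and the ids are made up (1 as BOS and 2 as EOS/pad, in the style of Llama-2).

```python
IGNORE_INDEX = -100

def build_sft_example(prompt_ids, response_ids, max_length, pad_id):
    """Concatenate, truncate, mask the prompt, and right-pad (plain lists)."""
    input_ids = (prompt_ids + response_ids)[:max_length]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + response_ids)[:max_length]
    n_real = len(input_ids)
    pad_len = max_length - n_real
    input_ids = input_ids + [pad_id] * pad_len
    labels = labels + [IGNORE_INDEX] * pad_len
    attention_mask = [1] * n_real + [0] * pad_len
    return input_ids, attention_mask, labels

ids, mask, labels = build_sft_example([1, 10, 11], [20, 21, 2], max_length=8, pad_id=2)
# ids    -> [1, 10, 11, 20, 21, 2, 2, 2]
# mask   -> [1, 1, 1, 1, 1, 1, 0, 0]
# labels -> [-100, -100, -100, 20, 21, 2, -100, -100]
```

Because pad_token is set to eos_token, the EOS that ends the response and the padding share an id. Labels are therefore built from positions rather than by comparing ids to pad_id, so the real EOS keeps its label and the model still learns to stop.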

2.5 Training Results

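| Epoch | Train Loss | Eval Loss | Eval PPL |
|---|---|---|---|
| 1 | 1.82 | 1.75 | 5.75 |
| 2 | 1.45 | 1.52 | 4.57 |
| 3 | 1.21 | 1.58 | 4.86 |

Note: in epoch 3 the eval loss rises while the train loss keeps falling, a sign of overfitting. The InstructGPT paper observed the same pattern (its SFT models overfit on validation loss after 1 epoch) but still trained for 16 epochs, because more epochs kept improving RM score and human preference ratings; the final SFT checkpoint was selected by RM score on a validation set rather than by validation loss.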
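Perplexity is the exponential of the mean per-token loss (in nats), so the PPL column can be recomputed from the eval-loss column. A quick check (perplexity is a hypothetical helper name):

```python
import math

def perplexity(mean_token_nll):
    """Perplexity = exp(mean per-token negative log-likelihood, in nats)."""
    return math.exp(mean_token_nll)

# Eval-loss column from the table above, keyed by epoch.
eval_loss = {1: 1.75, 2: 1.52, 3: 1.58}
eval_ppl = {epoch: perplexity(loss) for epoch, loss in eval_loss.items()}
```

This gives about 5.75, 4.57 and 4.85. The last value differs from the table's 4.86 only in the final digit, consistent with the losses having been rounded to two decimals.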

3. Stage 2: Reward Model

3.1 Bradley-Terry Loss (core)

Goal: learn a scalar reward function $r_\theta(x, y)$ that assigns higher scores to the responses humans prefer.

Bradley-Terry preference model: $$P(y_w \succ y_l \mid x) = \sigma\left(r_\theta(x, y_w) - r_\theta(x, y_l)\right)$$
RM loss: $$\mathcal{L}_{\text{RM}} = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(r_\theta(x, y_w) - r_\theta(x, y_l)\right)\right]$$
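A plain-Python sketch of the loss for a batch of reward pairs (function names are hypothetical). It computes log σ(z) directly in a numerically stable form instead of taking the log of σ(z), which can underflow to log(0) for large negative margins. It also makes two properties of the formula easy to check: equal rewards give a loss of log 2, and adding the same constant to both rewards leaves the loss unchanged.

```python
import math

def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def bt_loss(r_chosen, r_rejected):
    """Mean Bradley-Terry loss -log sigma(r_w - r_l) over a batch of pairs."""
    pairs = list(zip(r_chosen, r_rejected))
    return -sum(log_sigmoid(w - l) for w, l in pairs) / len(pairs)

def bt_prob(r_w, r_l):
    """P(y_w preferred over y_l | x) under the Bradley-Terry model."""
    return 1.0 / (1.0 + math.exp(-(r_w - r_l)))
```

Only the reward difference enters the loss, so the absolute scale of the RM output is arbitrary; that is also why a bias term in the reward head would receive no gradient from this loss.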

Intuition

3.2 RM Architecture (RewardModel)

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class RewardModel(nn.Module):
    """Reward model: transformer backbone initialized from the SFT model, LM head dropped, scalar head added"""
    
    def __init__(self, base_model_name, sft_checkpoint=None):
        super().__init__()
        # Load the bare transformer (no LM head). If sft_checkpoint is a directory
        # written by save_pretrained() from the SFT model, from_pretrained maps the
        # CausalLM weights onto the backbone and discards the LM head. (Loading a
        # CausalLM state_dict into AutoModel with strict=False would silently skip
        # every key, because the key prefixes differ.)
        self.backbone = AutoModel.from_pretrained(sft_checkpoint or base_model_name)
        hidden_size = self.backbone.config.hidden_size
        
        # Scalar reward head. bias=False: a constant offset cancels in r_w - r_l,
        # so a bias would receive no gradient from the pairwise loss anyway.
        self.reward_head = nn.Linear(hidden_size, 1, bias=False)
    
    def forward(self, input_ids, attention_mask):
        """Return the scalar reward read from the last real token of each sequence"""
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = outputs.last_hidden_state  # (B, T, H)
        
        # Hidden state of the last non-pad token.
        # sum(mask) - 1 is the last real position only for RIGHT-padded batches.
        seq_lengths = attention_mask.sum(dim=1) - 1  # (B,)
        batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
        last_hidden = hidden_states[batch_idx, seq_lengths]  # (B, H)
        
        reward = self.reward_head(last_hidden).squeeze(-1)  # (B,)
        return reward


class RMTrainer:
    """RM trainer: Bradley-Terry pairwise loss"""
    
    def compute_loss(self, model, batch):
        # The batch holds tokenized (prompt + chosen) and (prompt + rejected) sequences
        r_chosen = model(
            input_ids=batch["chosen_input_ids"],
            attention_mask=batch["chosen_attention_mask"],
        )
        r_rejected = model(
            input_ids=batch["rejected_input_ids"],
            attention_mask=batch["rejected_attention_mask"],
        )
        
        # Bradley-Terry loss; logsigmoid avoids log(0) when sigmoid underflows
        loss = -F.logsigmoid(r_chosen - r_rejected).mean()
        
        # Pairwise accuracy: fraction of pairs where the chosen response scores higher
        accuracy = (r_chosen > r_rejected).float().mean()
        
        return loss, {"accuracy": accuracy.item()}

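3.3 Key Design Choices

| Design | Choice | Reason |
|---|---|---|
| Initialization | From the SFT model | The RM must understand language, which the SFT model already does |
| Training epochs | 1 epoch only | Avoid overfitting to noise in the preference labels |
| Reward head | Linear(H, 1, bias=False) | A bias is a constant offset that cancels in r_w - r_l, so the pairwise loss never trains it |
| Readout position | Last token | The score is produced after the model has seen the entire response |
| Model size | 6B (InstructGPT) | Smaller than the 175B policy, which saves compute; InstructGPT also reports that 175B RM training could be unstable |
| Data format | Pairwise comparisons | Relative rankings are more consistent than absolute scores |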
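Reading the reward at the last token requires finding the last real (non-pad) position. The RewardModel above uses attention_mask.sum(dim=1) - 1, which is correct only when sequences are right-padded; with left padding (common for batched generation) the last real token is the final position. A hypothetical helper that works for either layout, shown on plain lists:

```python
def last_token_index(attention_mask):
    """Index of the last position with mask == 1 (works for left or right padding)."""
    for i in range(len(attention_mask) - 1, -1, -1):
        if attention_mask[i] == 1:
            return i
    raise ValueError("sequence has no real tokens")

right_padded = [1, 1, 1, 1, 0, 0]
left_padded = [0, 0, 1, 1, 1, 1]
# last_token_index(right_padded) -> 3, which equals sum(right_padded) - 1
# last_token_index(left_padded)  -> 5, but sum(left_padded) - 1 is 3
```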

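3.4 Preference Datasets

| Dataset | Size | Annotation | Characteristics |
|---|---|---|---|
| InstructGPT comparisons | 33K prompts | Labelers rank 4-9 responses | High quality, in-house labeling team |
| Anthropic HH-RLHF | 170K | Humans pick chosen/rejected | Harmlessness + helpfulness |
| OpenAssistant (OASST) | 90K | Crowdsourced rankings | Multilingual, open source |
| UltraFeedback | 64K | GPT-4 ratings | AI-labeled, low cost |
| Stanford SHP | 385K | Reddit votes | Natural preference signal |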
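InstructGPT labelers rank K = 4 to 9 responses per prompt, and a ranking expands into C(K, 2) pairwise comparisons for the Bradley-Terry loss. A minimal sketch (ranking_to_pairs is a hypothetical name; responses are assumed to be listed best first):

```python
from itertools import combinations

def ranking_to_pairs(ranked_responses):
    """Expand a best-to-worst ranking of K responses into (chosen, rejected) pairs."""
    # combinations() keeps input order, so the first element of each pair ranks higher
    return [(better, worse) for better, worse in combinations(ranked_responses, 2)]

pairs = ranking_to_pairs(["A", "B", "C"])
# -> [("A", "B"), ("A", "C"), ("B", "C")]
```

K = 4 yields 6 pairs and K = 9 yields 36. Because the pairs from one prompt are highly correlated, the InstructGPT paper reports treating all of a prompt's comparisons as a single batch element rather than shuffling them independently, to avoid overfitting.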

4. Stage 3: RLHF via PPO (core)

4.1 RLHF Objective

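RLHF objective: $$\text{objective}(\phi) = \mathbb{E}_{(x,y) \sim \pi_\phi^{RL}}\left[r_\theta(x,y) - \beta \log\frac{\pi_\phi^{RL}(y|x)}{\pi^{SFT}(y|x)}\right]$$

| Symbol | Meaning | Role |
|---|---|---|
| $\pi_\phi^{RL}$ | Current RL policy (actor) | The model being optimized |
| $\pi^{SFT}$ | SFT model (frozen reference) | KL anchor that keeps the policy from drifting |
| $r_\theta(x,y)$ | Reward-model score | Measures response quality |
| $\beta$ | KL penalty coefficient | Balances reward-seeking exploration against staying close to the reference |
| $\log\frac{\pi_\phi^{RL}}{\pi^{SFT}}$ | Log-ratio, summed over tokens; its expectation under the policy is the KL divergence | Penalizes deviation from the reference policy |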

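Equivalent form (the expected log-ratio under the policy is the KL divergence):

$$\text{objective}(\phi) = \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_\phi^{RL}(\cdot|x)}\left[r_\theta(x,y)\right] - \beta \cdot \mathbb{E}_{x \sim \mathcal{D}}\left[D_{KL}\left(\pi_\phi^{RL}(\cdot|x) \,\|\, \pi^{SFT}(\cdot|x)\right)\right]$$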
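For a single sampled response, the bracketed term of the objective is the RM score minus β times the summed per-token log-ratios; averaged over samples y ~ π, that summed log-ratio estimates the KL term of the equivalent form. A sketch with made-up log-probabilities (kl_shaped_reward is a hypothetical name):

```python
def kl_shaped_reward(rm_score, policy_logprobs, ref_logprobs, beta):
    """r(x, y) - beta * sum_t [log pi(y_t | ...) - log pi_SFT(y_t | ...)] for one sampled y."""
    log_ratio = sum(p - q for p, q in zip(policy_logprobs, ref_logprobs))
    return rm_score - beta * log_ratio

# Three response tokens; the policy is more confident than the reference on two of them.
r = kl_shaped_reward(4.0, [-1.0, -0.5, -2.0], [-1.5, -1.0, -2.0], beta=0.02)
# summed log-ratio = 0.5 + 0.5 + 0.0 = 1.0, so r = 4.0 - 0.02 * 1.0 = 3.98
```

A single sample's log-ratio can be negative even though the KL itself is never negative; only its average over many samples is guaranteed to be non-negative.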

4.2 PPO Clipped Surrogate

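PPO-Clip objective (written as a loss to minimize): $$L^{CLIP}(\phi) = -\mathbb{E}_t\left[\min\left(r_t(\phi) \hat{A}_t,\; \text{clip}\left(r_t(\phi),\; 1-\varepsilon,\; 1+\varepsilon\right) \hat{A}_t\right)\right]$$

| Symbol | Definition | Notes |
|---|---|---|
| $r_t(\phi)$ | $\frac{\pi_\phi(a_t \mid s_t)}{\pi_{\phi_{\text{old}}}(a_t \mid s_t)}$ | Importance-sampling (IS) ratio |
| $\hat{A}_t$ | GAE advantage estimate | How much better action $a_t$ is than the baseline $V(s_t)$ |
| $\varepsilon$ | 0.2 (common value) | Clip range; limits how far each update can move the policy |
| $\text{clip}(r_t, 1-\varepsilon, 1+\varepsilon)$ | Clips the ratio to [0.8, 1.2] for $\varepsilon = 0.2$ | Prevents overly large single-step updates |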
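Per token, the clipped surrogate is a pessimistic bound: once the ratio leaves [1-ε, 1+ε] in the direction that helps, the objective stops rewarding further change, while changes that hurt are never clipped. A scalar sketch (clipped_surrogate is a hypothetical name) covering the four sign/ratio cases:

```python
def clipped_surrogate(ratio, advantage, eps=0.2):
    """Per-token PPO-Clip objective min(r*A, clip(r, 1-eps, 1+eps)*A), to be maximized."""
    clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
    return min(ratio * advantage, clipped_ratio * advantage)

cases = {
    (1.5, 1.0): clipped_surrogate(1.5, 1.0),    # good action pushed too far: capped at 1.2
    (0.5, 1.0): clipped_surrogate(0.5, 1.0),    # good action made less likely: 0.5, unclipped
    (1.5, -1.0): clipped_surrogate(1.5, -1.0),  # bad action made more likely: -1.5, unclipped
    (0.5, -1.0): clipped_surrogate(0.5, -1.0),  # bad action pushed too far down: capped at -0.8
}
```

A training loss negates this value and averages it over response tokens, as $L^{CLIP}$ does.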

PPO-Clip intuition

4.3 ActorCritic Architecture

import torch
import torch.nn as nn
from transformers import AutoModel, AutoModelForCausalLM

class ActorCritic(nn.Module):
    """PPO Actor-Critic: Actor (policy) + Critic (value function)"""
    
    def __init__(self, model_name, sft_checkpoint=None):
        super().__init__()
        # Actor: the language-model policy π_φ, initialized from the SFT model
        # (sft_checkpoint: a directory written by save_pretrained)
        self.actor = AutoModelForCausalLM.from_pretrained(sft_checkpoint or model_name)
        
        # Critic: value function V(s). This sketch uses a separate backbone without an
        # LM head; sharing the actor's backbone is the memory-saving alternative.
        # (InstructGPT initializes its value function from the reward model.)
        self.critic_backbone = AutoModel.from_pretrained(model_name)
        hidden_size = self.critic_backbone.config.hidden_size
        self.value_head = nn.Linear(hidden_size, 1)
    
    def get_policy(self, input_ids, attention_mask):
        """Return token-level log-probabilities over the vocabulary, shape (B, T, V)"""
        outputs = self.actor(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        logits = outputs.logits  # (B, T, V)
        log_probs = torch.log_softmax(logits, dim=-1)
        return log_probs
    
    def get_value(self, input_ids, attention_mask):
        """Return a value estimate at every token position, shape (B, T)"""
        outputs = self.critic_backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden = outputs.last_hidden_state  # (B, T, H)
        values = self.value_head(hidden).squeeze(-1)  # (B, T)
        return values
    
    def forward(self, input_ids, attention_mask):
        log_probs = self.get_policy(input_ids, attention_mask)
        values = self.get_value(input_ids, attention_mask)
        return log_probs, values

4.4 PPO Training Loop

import torch

class PPOTrainer:
    """PPO trainer (simplified: one sequence per rollout, no minibatching)"""
    
    def __init__(self, actor_critic, ref_model, reward_model,
                 kl_coeff=0.1, clip_eps=0.2, gamma=1.0, lam=0.95, lr=1e-5):
        self.ac = actor_critic
        self.ref_model = ref_model       # frozen SFT model (AutoModelForCausalLM)
        self.rm = reward_model           # frozen reward model
        self.kl_coeff = kl_coeff         # β
        self.clip_eps = clip_eps         # ε
        self.gamma = gamma
        self.lam = lam                   # GAE lambda
        # Optimizer over actor and critic parameters (lr is an illustrative default)
        self.optimizer = torch.optim.AdamW(self.ac.parameters(), lr=lr)
    
    @torch.no_grad()
    def rollout(self, prompts, max_new_tokens=256):
        """Phase 1: sample responses from the current policy.
        Returns prompt + response token ids for each prompt."""
        sequences = []
        for prompt_ids in prompts:  # each prompt_ids: 1-D LongTensor
            output = self.ac.actor.generate(
                input_ids=prompt_ids.unsqueeze(0),
                attention_mask=torch.ones_like(prompt_ids).unsqueeze(0),
                max_new_tokens=max_new_tokens,
                do_sample=True,
                # Sample from the unmodified policy so that old_log_probs (computed at
                # temperature 1 over the full vocabulary) match the sampling distribution
                temperature=1.0,
                top_k=0,
                top_p=1.0,
            )
            sequences.append(output[0])  # prompt followed by the generated response
        return sequences
    
    @torch.no_grad()
    def compute_rewards(self, prompt_ids, response_ids):
        """Per-token reward = -β * log(π/π_ref) on response tokens, plus the RM score on
        the last token. Outputs are (1, T-1): position t scores the action that produces
        token t+1. Also returns the sampled-token log-probs and the response mask."""
        full_ids = torch.cat([prompt_ids, response_ids], dim=1)  # (1, T)
        attention_mask = torch.ones_like(full_ids)  # single unpadded sequence
        
        # RM score (sequence-level)
        rm_score = self.rm(full_ids, attention_mask)  # (1,)
        
        # Log-probs of the tokens that were actually sampled, under policy and reference
        actions = full_ids[:, 1:].unsqueeze(-1)  # (1, T-1, 1)
        policy_logprobs = self.ac.get_policy(full_ids, attention_mask)[:, :-1]
        ref_logits = self.ref_model(input_ids=full_ids, attention_mask=attention_mask).logits
        ref_logprobs = ref_logits[:, :-1].log_softmax(-1)
        policy_lp = policy_logprobs.gather(-1, actions).squeeze(-1)  # (1, T-1)
        ref_lp = ref_logprobs.gather(-1, actions).squeeze(-1)        # (1, T-1)
        
        # Only actions that generate response tokens are scored
        response_mask = torch.zeros_like(policy_lp)
        response_mask[:, prompt_ids.size(1) - 1:] = 1.0
        
        # Single-sample estimate of the per-token KL(π_φ || π_ref)
        kl_per_token = (policy_lp - ref_lp) * response_mask
        
        # Token-level reward: -β·KL at every response token, RM score added on the last one
        rewards = -self.kl_coeff * kl_per_token
        rewards[:, -1] += rm_score
        
        return rewards, policy_lp, response_mask
    
    @torch.no_grad()
    def compute_gae(self, rewards, values, mask):
        """Generalized Advantage Estimation over (B, T-1) tensors.
        values must be aligned with rewards, e.g. get_value(...)[:, :-1]."""
        advantages = torch.zeros_like(rewards)
        last_gae = 0
        
        for t in reversed(range(rewards.size(1))):
            # The value after the final token is 0 (terminal state)
            next_value = values[:, t + 1] if t < rewards.size(1) - 1 else 0
            delta = rewards[:, t] + self.gamma * next_value - values[:, t]
            last_gae = delta + self.gamma * self.lam * last_gae
            advantages[:, t] = last_gae * mask[:, t]
        
        returns = advantages + values
        return advantages, returns
    
    def ppo_update(self, batch, n_epochs=4):
        """Phase 2: PPO update. Per-token tensors in batch are (B, T-1), aligned with
        the actions input_ids[:, 1:]."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        old_log_probs = batch["old_log_probs"]
        advantages = batch["advantages"]
        returns = batch["returns"]
        response_mask = batch["response_mask"]
        n_tokens = response_mask.sum()
        
        for epoch in range(n_epochs):
            # Current policy log probs and values
            log_probs, values = self.ac(input_ids, attention_mask)
            
            # Log-probs of the tokens that were actually generated
            action_log_probs = torch.gather(
                log_probs[:, :-1], 2, input_ids[:, 1:].unsqueeze(-1)
            ).squeeze(-1)
            
            # Importance sampling ratio π_φ / π_φ_old
            ratio = torch.exp(action_log_probs - old_log_probs)
            
            # PPO clipped objective, negated into a loss and averaged over response tokens
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
            policy_loss = -torch.min(surr1, surr2)
            policy_loss = (policy_loss * response_mask).sum() / n_tokens
            
            # Value loss: MSE averaged over response tokens only
            value_loss = (((values[:, :-1] - returns) ** 2) * response_mask).sum() / n_tokens
            
            # Total loss
            loss = policy_loss + 0.5 * value_loss
            
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.ac.parameters(), 1.0)
            self.optimizer.step()
        
        return {"policy_loss": policy_loss.item(), "value_loss": value_loss.item()}
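The recursion in compute_gae can be sanity-checked on one unpadded trajectory in plain Python. This sketch (gae is a hypothetical name) uses the same terminal convention, a value of 0 after the last token. Two limiting cases are easy to verify: with λ = 1 and γ = 1 the returns equal the reward-to-go, and with λ = 0 each advantage is the one-step TD error δ_t = r_t + γV(s_{t+1}) - V(s_t).

```python
def gae(rewards, values, gamma=1.0, lam=0.95):
    """GAE over one trajectory; the value after the final step is taken as 0."""
    T = len(rewards)
    advantages = [0.0] * T
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # one-step TD error
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

# RLHF-style trajectory: reward only on the last token, values from a toy critic.
adv, ret = gae([0.0, 0.0, 1.0], [0.2, 0.5, 0.9], gamma=1.0, lam=1.0)
# with lam = 1 and gamma = 1, every return equals the reward-to-go of 1.0
```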

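4.5 The Role of the KL Penalty

Why is a KL penalty needed? To prevent reward hacking.

Without a KL constraint, the policy finds loopholes in the RM (adversarial examples) that earn high scores while producing garbage output:

| Aspect | Without KL constraint | With KL constraint |
|---|---|---|
| Output quality | Repetitive, meaningless text that the RM still scores highly | Fluent, meaningful answers |
| RM score | Extremely high (but spurious) | Moderately high (reflects real preferences) |
| KL distance | Grows without bound | Controlled (typically < 10 nats) |
| Diversity | Mode collapse | Diversity preserved |

Choosing $\beta$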

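4.6 Training Results

| Training steps | RM Score | KL (nats) | Win Rate vs SFT |
|---|---|---|---|
| 0 (SFT init) | 2.1 | 0.0 | 50% |
| 500 | 3.2 | 2.8 | 62% |
| 1000 | 3.8 | 5.1 | 68% |
| 2000 | 4.1 | 7.3 | 71% |
| 5000 | 4.3 | 9.5 | 72% |

Observation: RM score and win rate rise together, with diminishing returns, while KL keeps growing, meaning the policy keeps drifting away from the SFT model. Training should stop once the policy is "good enough".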

5. Full Hyperparameter Table

| Hyperparameter | SFT | RM | PPO |
|---|---|---|---|
| Base model | GPT-3 175B | GPT-3 6B | SFT 175B |
| Learning rate | 2e-5 | 9e-6 | 1.5e-5 (actor) / 5e-6 (critic) |
| Batch size | 32 | 64 | 512 (rollout) / 64 (update) |
| Epochs | 16 | 1 | 4 (PPO epochs per batch) |
| Scheduler | Cosine | Cosine | Cosine |
| Max seq length | 2048 | 2048 | 512 (prompt) + 256 (response) |
| Warmup ratio | 3% | 5% | 0 |
| KL coeff ($\beta$) | N/A | N/A | 0.02 (adaptive) |
| Clip $\varepsilon$ | N/A | N/A | 0.2 |
| GAE $\lambda$ | N/A | N/A | 0.95 |
| Discount $\gamma$ | N/A | N/A | 1.0 |
| Grad clip | 1.0 | 1.0 | 1.0 |
| Data | ~13K demos | ~33K ranked prompts | ~31K prompts |

6. Qualitative Comparison

Prompt: "Explain the moon landing to a 6 year old in a way that is inspiring."

| Model | Example output | Assessment |
|---|---|---|
| Base GPT-3 | "Explain the sun to a 6 year old. Explain gravity to a 6 year old. Explain..." | Continues the prompt's pattern; does not follow the instruction |
| SFT | "So, a long time ago, people wanted to go to the moon. They built a really big rocket and three brave astronauts went inside..." | Correct format, but flat and not very inspiring |
| PPO (RLHF) | "Imagine you could jump SO high that you could touch the stars! Well, some very brave people actually did something like that..." | Vivid, inspiring, suitable for a child |

Key observations

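7. Key Design Choices

| # | Design choice | Reason |
|---|---|---|
| 1 | RM initialized from SFT | Judging response quality requires language understanding |
| 2 | RM trained for only 1 epoch | Preference labels are noisy; more epochs overfit the noise |
| 3 | RM smaller than the policy | PPO keeps 4 models in memory at once; a smaller RM saves GPU memory |
| 4 | Pairwise ranking rather than pointwise scores | Humans are more consistent at relative comparisons than at absolute scores (higher inter-annotator agreement) |
| 5 | Adaptive KL coefficient | A fixed $\beta$ is hard to balance across training phases; adapting it keeps KL within a target range |
| 6 | Mixing in pretraining gradients | InstructGPT mixes a pretraining loss into the PPO stage (PPO-ptx) to avoid forgetting general capabilities |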
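Design choice 5 can be sketched as a proportional controller in the style of Ziegler et al. (2019): measure the KL after each batch, compute its relative error against a target, clip that error, and scale β multiplicatively. The class name, the gain k_beta = 0.1 and the ±0.2 clip are assumptions for illustration, not values confirmed by these notes.

```python
class ProportionalKLController:
    """Nudges beta so that the observed KL tracks a target value (hypothetical helper)."""

    def __init__(self, init_beta, target_kl, k_beta=0.1):
        self.beta = init_beta
        self.target_kl = target_kl
        self.k_beta = k_beta  # illustrative gain, not a tuned value

    def update(self, observed_kl):
        # Relative error, clipped to [-0.2, 0.2]; with k_beta = 0.1, beta moves at most 2% per call
        error = (observed_kl - self.target_kl) / self.target_kl
        error = max(-0.2, min(0.2, error))
        self.beta *= 1.0 + self.k_beta * error
        return self.beta

controller = ProportionalKLController(init_beta=0.02, target_kl=6.0)
controller.update(12.0)  # KL too high: beta rises to 0.0204
controller.update(6.0)   # KL on target: beta unchanged
```

The multiplicative update keeps β positive, and clipping the error stops a single noisy KL measurement from swinging β sharply.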

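8. Known Issues and Limitations

| Issue | Description | Mitigation |
|---|---|---|
| Reward Hacking | The policy exploits RM loopholes to get spuriously high scores (e.g. verbose answers, repeated flattering phrases) | KL penalty, larger RM, RM ensembles |
| Distribution Shift | The RM is trained on SFT outputs, but the PPO policy's outputs keep changing, so the RM's judgments become unreliable | Iterative training (re-collecting preference data), PPO-max |
| High compute cost | PPO must hold 4 models at once (actor, critic, ref, RM), roughly 4x the memory | LoRA, shared backbone, DeepSpeed ZeRO |
| Training instability | PPO is sensitive to hyperparameters, and the reward can suddenly collapse | Gradient clipping, small learning rate, reward normalization |
| Labeler bias | The RM captures labelers' preferences rather than ground-truth "correctness" | Diverse labeling team, clear guidelines |
| Alignment Tax | Performance drops on some NLP benchmarks after RLHF | Mixing in pretraining loss, conservative KL |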

9. Formula Quick Reference

1. SFT Loss: $$\mathcal{L}_{\text{SFT}} = -\sum_{t \in \text{response}} \log p_\theta(x_t \mid x_{<t})$$
2. Bradley-Terry (RM Loss): $$\mathcal{L}_{\text{RM}} = -\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\left(r_\theta(x, y_w) - r_\theta(x, y_l)\right)\right]$$
3. RLHF Objective: $$\max_\phi \; \mathbb{E}_{x, y \sim \pi_\phi}\left[r_\theta(x,y) - \beta \log\frac{\pi_\phi(y|x)}{\pi^{SFT}(y|x)}\right]$$
4. PPO Clipped Surrogate: $$L^{CLIP} = -\mathbb{E}_t\left[\min\left(r_t \hat{A}_t,\; \text{clip}(r_t, 1-\varepsilon, 1+\varepsilon)\hat{A}_t\right)\right]$$
5. Sequence-level KL (sum of per-token log-ratios): $$D_{KL}\left(\pi_\phi(\cdot|x) \,\|\, \pi^{SFT}(\cdot|x)\right) = \mathbb{E}_{y \sim \pi_\phi(\cdot|x)}\left[\sum_t \log\frac{\pi_\phi(y_t|y_{<t}, x)}{\pi^{SFT}(y_t|y_{<t}, x)}\right]$$
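RLHF code rarely computes this expectation exactly. It uses the log-ratio of the token that was actually sampled, which is an unbiased single-sample estimate of the per-token KL. A toy check on a 3-token vocabulary (function names and distributions are made up for illustration):

```python
import math
import random

def exact_kl(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i) for discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def sampled_kl(p, q, n_samples, seed=0):
    """Monte Carlo estimate: mean of log p(y) / q(y) with y drawn from p."""
    rng = random.Random(seed)
    ys = rng.choices(range(len(p)), weights=p, k=n_samples)
    return sum(math.log(p[y] / q[y]) for y in ys) / n_samples

p = [0.7, 0.2, 0.1]  # toy "policy" over 3 tokens
q = [0.4, 0.4, 0.2]  # toy "reference"
# exact_kl(p, q) is about 0.184 nats; sampled_kl(p, q, 200_000) lands close to it
```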


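Key Takeaways

  1. The three-stage pipeline is the core framework: SFT (learn the format) → RM (learn preferences) → PPO (optimize for preferences), with each stage solving a different problem
  2. SFT computes the loss only on the response: prompt labels are set to -100, so it is essentially fine-tuning a conditional language model
  3. The Bradley-Terry model turns preferences into classification: $\sigma(r_w - r_l)$ maps a scalar reward difference to a preference probability
  4. The KL penalty is key to making RLHF work: $\beta \cdot D_{KL}(\pi_\phi \| \pi^{SFT})$ guards against reward hacking and mode collapse
  5. PPO-Clip stabilizes training: it limits each policy update by clipping the probability ratio to $[1-\varepsilon, 1+\varepsilon]$
  6. PPO's four-model problem is the engineering bottleneck: actor + critic + ref + RM need roughly 4x the memory, which motivated lighter-weight alternatives such as DPO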