Press "Enter" to skip to content

从人类反馈中强化学习(RLHF)

简化说明

也许你已经听说过这个技术,但你对它并不完全理解,特别是PPO部分。这个说明可能会有所帮助。

我们将重点讨论文本到文本的语言模型,例如GPT-3、BLOOM和T5。我们不涉及仅编码器的模型,例如BERT。

本博客文章是同一作者的详细说明的改编。

基于人类反馈的强化学习(RLHF)在ChatGPT中得到了成功应用,因此其受欢迎程度大幅增加。

RLHF在两种情况下尤其有用:

  • 你无法创建一个好的损失函数(例如,你如何计算一个度量标准来衡量模型的输出是否有趣?)
  • 你想要用生产数据进行训练,但你无法轻易地对生产数据进行标记(例如,你如何从ChatGPT获取带标签的生产数据?有人需要编写ChatGPT应该回答的正确答案)

RLHF算法的步骤:

  1. 预训练语言模型(LM)
  2. 训练奖励模型
  3. 用RL进行微调语言模型

1-预训练语言模型(LM)

在这一步中,你需要从头开始训练一个语言模型,或者只是使用像GPT-3这样的预训练模型。

一旦你拥有了预训练语言模型,你还可以进行额外的可选步骤,称为监督微调(STF)。这只是获得一些人工标记的(输入,输出)文本对,并对你已有的语言模型进行微调。STF被认为是RLHF的高质量初始化。

在这一步结束时,我们得到了经过训练的主要语言模型,这是我们想要进一步用RLHF进行训练的模型。

图1:我们的预训练语言模型。

2-训练奖励模型

在这一步中,我们希望收集一个(输入文本,输出文本,奖励)三元组的数据集。

在图2中,显示了数据收集流程的表示:使用输入文本数据(最好是生产数据),将其通过模型,并由人为生成的输出文本进行奖励评定。

图2:用于奖励模型训练的数据收集流程。

奖励通常是0到5之间的整数,但也可以是简单的👍/👎体验中的0/1。

图3:ChatGPT中简单的👍/👎奖励收集方式。
图4:更完整的奖励收集体验:模型输出两个文本,人类需要选择哪个更好,并给出总体评分和评论。

通过这个新的数据集,我们将训练另一个语言模型来接收(输入、输出)文本并返回一个奖励值!这将是我们的奖励模型。

这里的主要目标是使用奖励模型来模仿人类的奖励标注,从而能够在没有人类参与的情况下进行离线的RLHF训练。

Figure 5: The trained reward model, that will mimic the rewards given by humans.

3 — 使用RL对LM进行微调

在这一步中,真正发生魔法并引入RL。

这一步的目标是利用奖励模型给出的奖励来训练主模型,你训练好的LM。然而,由于奖励不可微分,我们需要使用RL来构建一个可反向传播到LM的损失函数。

Figure 6: Fine-tuning the main LM using the reward model and the PPO loss calculation.

在管道的开始处,我们将制作我们的LM的一个精确副本并冻结其可训练权重。该模型的副本有助于防止可训练的LM完全改变其权重,并开始输出无意义的文字以欺骗奖励模型。

这就是为什么我们计算冻结和未冻结LM的文本输出概率之间的KL散度损失。

这个KL损失被添加到由奖励模型产生的奖励上。实际上,如果你在生产中训练模型(在线学习),你可以直接用人类奖励得分替换奖励模型。 💡

有了你的奖励和KL损失,现在我们可以应用RL使奖励损失可微分。

为什么奖励不可导?因为它是用接收文本输入的奖励模型计算的。这个文本是通过解码LM的输出对数概率得到的。这个解码过程是不可微分的。

为了使损失可微分,最后,Proximal Policy Optimization(PPO)开始发挥作用!让我们放大一下。

Figure 7: Zoom-in on the RL Update box — PPO loss calculation.

PPO算法计算这样一个损失(将用于对LM进行小调整):

  1. 将“初始概率”设置为“新概率”以进行初始化。
  2. 计算新输出文字概率与初始输出文字概率之间的比率。
  3. 根据公式 loss = -min(ratio * R, clip(ratio, 0.8, 1.2) * R)(其中 R 是之前计算出的reward + KL(或加权平均值,如 0.8 * reward + 0.2 * KL),clip(ratio, 0.8, 1.2) 只是将比率限制在0.8 <= ratio <= 1.2)。注意,0.8 / 1.2 只是常用的超参数值,在这里进行了简化。还要注意,我们希望最大化奖励,这就是为什么我们加入了减号-,这样我们就可以通过梯度下降来最小化损失的否定值。
  4. 通过反向传播损失来更新LM的权重。
  5. 使用新更新的LM计算“新概率”(即新的输出文字概率)。
  6. 重复步骤2至N次(通常N=4)。

就是这样,这就是您在文本到文本语言模型中使用RLHF的方法!

事情可能会变得更加复杂,因为您还可以添加其他损失函数到我所介绍的基本损失函数中,但这是核心实现。

Leave a Reply

Your email address will not be published. Required fields are marked *