目前来说主流的RLHF方向分为两大类:

  • 1 PPO类
  • 2 DPO类

PPO的训练非常耗时,一次需要加载4个模型,Actor,Reward,Critic和Ref model。主要步骤:

  • 1 SFT训练一个Actor模型
  • 2 训练一个Reward模型
  • 3 使用Actor初始化Ref模型,使用Reward初始化Critic模型,进行PPO算法训练Actor和Critic模型。

下面是OPenAI在InstructGPT中对PPO训练步骤:

下面我们说明一下每个模型的作用:

  • Actor:我们想要训练的目标语言模型,产生文本。
  • Critic:预估总收益。
  • Reward:计算即时收益。
  • Ref:参考模型,防止Actor模型训练歪。

PPO – 理论

训练Actor和Reward模型相对来说简单,我们直接进入PPO部分,查看PPO如何定义Actor和Critic的loss,如何开展训练。

原始actor loss

对于上下文St而言,生成At和概率为P(At|St),如果此时Vt>0,即总收益越大则增加P(At|St)的概率;如果Vt<0,即总收益越小则减少P(At|St)的概率。

引入优势

原始的actor loss只关注每一个时刻的总收益,这当然没什么问题,但是这样会让训练难以很快收敛。原始actor loss只要Vt大于0,模型就会优化让P(At|St)更大,但是如果Vt减少了P(At|St)增加这种情况使我们不愿意看到的

如果我们不再只关注总收益Vt,而是关注的是超前的收益趋势,即t时刻,选择At后产生的总收益如果比当前时候的总收益更大,则说明At动作更有价值。此时的loss可以选择更有有益的At的方向。

引入新Reward

原始Reward是由冻结后的Reward模型通过Actor模型产生的At进行打分得到的。

但是实际上我们可以做更多的设计:

  • 当t不等于T时,我们更加关心Actor是否在Ref的约束下生成At。
  • 当t等于T时,我们不仅关心Actor是否在Ref的约束下生成At,同时还关注At的质量,即奖励Rt。

则引入新优势和新Reward之后的actor loss如下:

引入off-policy的actor loss

以上的过程中的actor loss实际上是一种on-policy的训练方法,即模型使用实时的收益进行动作的决策。

这对于大模型来说,太heavy了(每一个原本数据的生成都需要进行4个模型的运行推理),想要训练充分很耗费时间。

off-policy下场了,如果我们可以将模型之前生成的全部数据拿到t时候使用,则增加好几倍的训练样本。

但是这又会出现新的问题,不同时刻的Actor的分布不同,生成的At的分布也不同,怎么解决这个问题呢?

Important-Sample下场了,通过重要性采样的方法使模型不过分偏移。

实际情况下当我们害怕这两个分布偏移太多导致最后的actor loss不稳定,也可以对重要性采样的分布进行裁剪,如下:

2 PPO – 代码实战

2.1 Actor loss

同样跟随上面的步骤:

  • 1 Reward如何得到
  • 2 优势Delta如何得到
  • 3 重要性采样分布Ratio如何得到
  • 4 计算最终actor loss

Reward:

def get_reward_kl(end, prob_old, prob_ref, reward):
    #prob_old -> [4, gen_lens-1]
    #prob_ref -> [4, gen_lens-1]
    #reward -> [4]

    #两份预测概率求kl散度
    #[4, gen_lens-1]
    reward_kl = -0.1 * (prob_old - prob_ref)

    #把原本的reward加在kl散度的最后一个字上
    for i, e in enumerate(end):
        if e >= reward_kl.shape[1]:
            e = -1
        reward_kl[i, e] += reward[i].clamp(-5, 5)

    #[4, gen_lens-1]
    return reward_kl

prob_old为Actor模型产生的结果概率。prob_ref为Ref模型产生结果的概率。

Delta:

def get_delta(value_old, reward_kl):
    #value_old -> [4, gen_lens-1]
    #reward_kl -> [4, gen_lens-1]

    #gen_lens-2 -> 255
    delta = []
    for i in reversed(range(255, value_old.shape[1])):
        #[4]
        value_next = 0.0
        if i != value_old.shape[1] - 1:
            value_next = value_old[:, i + 1]

        #[4]
        d = reward_kl[:, i] + value_next - value_old[:, i]
        if len(delta):
            d += 0.95 * delta[-1]
        delta.append(d)

    #[4, gen_lens-256]
    delta = torch.stack(delta[::-1], dim=1)

    return delta

此时gamma为0.95.

重要性分布Ratio:

def get_loss_actor(prob_new, prob_old, delta, generate_mask):
    prob_new = prob_new[:, 255:]
    prob_old = prob_old[:, 255:]
    generate_mask = generate_mask[:, 256:]

    #prob_new -> [4, gen_lens-256]
    #prob_old -> [4, gen_lens-256]
    #delta -> [4, gen_lens-256]
    #generate_mask -> [4, gen_lens-256]

    #对数概率,求差就是求商,所以这里求的是新旧概率的变化率
    #[4, gen_lens-256]
    ratio = ((prob_new - prob_old) * generate_mask).exp()

    #delta是估计出来的去基线Q值,以变化率来缩放Q值
    #最大化Q值,以此来寻找最优的actor
    #裁剪,防止自举
    #[4, gen_lens-256]
    loss1 = delta * ratio
    loss2 = delta * ratio.clamp(0.8, 1.2)
    loss = torch.min(loss1, loss2) * generate_mask
    loss = loss.sum() / generate_mask.sum() / 8
    return -loss

2.2 Critic loss

最初的critic loss如下,为即时总收益和预估总收益的mse loss。

为了使下一个时刻的即时总收益与上一时刻的即时总收益不相差太大,使用上一时刻的即时总收益进行约束,实际代码如下:

def get_loss_critic(value_new, value_old, delta, generate_mask):
    value_new = value_new[:, 255:]
    value_old = value_old[:, 255:]
    generate_mask = generate_mask[:, 256:]

    #value_new -> [4, gen_lens-256]
    #value_old -> [4, gen_lens-256]
    #delta -> [4, gen_lens-256]
    #generate_mask -> [4, gen_lens-256]

    #delta是估计出来的去基线Q值,加上value_old后还原为Q值
    #value_new和Q值求mse loss即可,因为value都是对Q函数的估计
    #裁剪,防止自举
    #[4, gen_lens-256]
    loss1 = (value_new - delta - value_old)**2
    value_new = value_new.clamp(value_old - 0.2, value_old + 0.2)
    loss2 = (value_new - delta - value_old)**2

    #求平均
    loss = torch.max(loss1, loss2) * generate_mask
    loss = loss.sum() / 2 / generate_mask.sum() / 8

    return loss

Related Post