登录
首页 >  科技周边 >  人工智能

强化学习十二:GRPO、DAPO、DUPO、GSPO解析

时间:2025-12-08 21:21:53 113浏览 收藏

推广推荐
免费电影APP ➜
支持 PC / 移动端,安全直达

大家好,今天本人给大家带来文章《强化学习十二:GRPO、DAPO、DUPO、GSPO详解》,文中内容主要涉及到,如果你对科技周边方面的知识点感兴趣,那就请各位朋友继续看下去吧~希望能真正帮到你们,谢谢!

在之前的强化学习系列中我们介绍了强化学习的基础知识,也在系列十和系列十一中介绍了强化学习RL在LLM中的应用。

最近我在介绍DeepResearch Agent的论文分享中讨论过从高质量数据合成,Agentic增量预训练(CPT),有监督微调(SFT)冷启动,到强化学习(RL)全流程的方法。但是介绍过程中重点在数据和论文方案思路框架上,RL算法部分都略过了。因为我发现每篇论文都在使用不同的RL方法,每个都详细介绍篇幅太长,不如将这些RL方法单独做一篇详细聊聊。

PPO在LLM的应用就不用再介绍了,系列十已经聊过,所以本文就介绍一些PPO的优化方案。

GRPO:Group Relative Policy Optimization

DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models

论文链接:https://arxiv.org/pdf/2402.03300

论文主要介绍的DeepSeekMath 7B的模型,在 MATH 和 GSM8K 等基准测试上超越了参数量更大的闭源模型。而提升的原因除了精心设计的数据工程,另一个重要原因就是引入PPO的优化算法GRPO(Group Relative Policy Optimization ),这种方法不仅提升了数学推理能力,也显著降低了训练过程中的显存占用。

强化学习系列(十二)--GRPO,DAPO,DUPO,GSPO
descript

问题和背景

前文介绍过PPO框架中,我们需要有三个模型:

Actor策略模型:负责实际生成内容,是被优化的策略模型。

Reward Model奖励模型:负责给模型生成的完整输出打分,提供训练时的奖励信号。

Critic评论模型:负责估计当前生成状态的价值,用来计算优势,稳定训练过程。

强化学习系列(十二)--GRPO,DAPO,DUPO,GSPO
descript

上面就是PPO的优化目标函数。

PPO 的问题在于:Actor 和 Critic 都是大型神经网络,二者需要同时更新,显存开销大;并且 Advantage 的计算依赖一个学习到的价值函数,训练容易不稳定。

论文方案

GRPO的解决方案就是直接去掉了Critic网络。GRPO的核心创新思路很简单:直接剔除Critic,启用群体相对优势。PPO中Critic存在主要就是为了计算优势函数,而什么是优势呢?就需要用Critic模型的值和奖励函数的值作比较得到,也就是它需要一个base基线。PPO用了模型,GRPO认为不就是要base么,我每次取一批数据,将平均值作为base不就好了。所以,它通过在一个批次(Group)的样本中对奖励进行归一化来计算优势函数。

计算公式: 优势函数 \hat{A}_{i,t} ,被计算为当前奖励与组内平均奖励之差,再除以该组奖励的标准差:

\hat{A}_{i,t} = \widetilde{r}_i = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}

核心代码:

import torchimport torch.nn.functional as Fdef compute_sequence_likelihood(model, query, response):    """    计算序列似然度 π_θ(y | x)    """    input_ids = torch.cat([query, response], dim=1)    outputs = model(input_ids=input_ids)    logits = outputs.logits        # 获取 response 部分的 logits    response_log_probs = F.log_softmax(logits, dim=-1)[:, query.size(1)-1:-1, :]    token_log_probs = torch.gather(        response_log_probs, dim=2, index=response.unsqueeze(-1)    ).squeeze(-1)        # 序列级别的对数似然度    return token_log_probs.sum(dim=1)def compute_importance_ratio(model, query, response, old_log_probs):    """    计算序列级别的重要性比率 s_i(θ) = π_θ(y_i | x) / π_θ_old(y_i | x)    """    current_log_probs = compute_sequence_likelihood(model, query, response)    return torch.exp(current_log_probs - old_log_probs)def compute_normalized_advantages(rewards):    """    计算归一化优势 Ã_i = (r_i - mean(r)) / std(r)    """    mean_reward = rewards.mean(dim=-1, keepdim=True)    std_reward = rewards.std(dim=-1, keepdim=True).clamp(min=1e-8)    return (rewards - mean_reward) / std_rewarddef gspo_loss(model, queries, responses, old_log_probs, rewards, clip_range=0.2):    """    GSPO 核心损失函数    J = E[min(s_i * A_i, clip(s_i, 1-ε, 1+ε) * A_i)]    """    # 计算重要性比率    importance_ratios = []    for i, response in enumerate(responses):        old_log_prob = old_log_probs[:, i] if old_log_probs.dim() > 1 else old_log_probs        ratio = compute_importance_ratio(model, queries, response, old_log_prob)        importance_ratios.append(ratio)    importance_ratios = torch.stack(importance_ratios, dim=1)  # [batch, group_size]        # 计算归一化优势    advantages = compute_normalized_advantages(rewards)  # [batch, group_size]        # 序列级别裁剪    clipped_ratios = torch.clamp(importance_ratios, 1.0 - clip_range, 1.0 + clip_range)        objective = torch.min(        importance_ratios * advantages,        clipped_ratios * advantages    )        return -objective.mean()

总结

今天关于《强化学习十二:GRPO、DAPO、DUPO、GSPO解析》的内容就介绍到这里了,是不是学起来一目了然!想要了解更多关于强化学习的内容请关注golang学习网公众号!

相关阅读
更多>
最新阅读
更多>
课程推荐
更多>