Commit dab77b4

[feat] REAL Loss (Rewards as Labels) for GRPO Training (#8424)
1 parent d044111 commit dab77b4

8 files changed

Lines changed: 259 additions & 1 deletion

docs/resources/real.png

85.4 KB

docs/resources/real_framework.png

49.6 KB
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
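# Rewards as Labels: Revisiting RLVR from a Classification Perspective

**Version requirement**: ms-swift>4.0

[Rewards as Labels: Revisiting RLVR from a Classification Perspective](https://arxiv.org/abs/2602.05630) proposes, for GRPO, treating rewards as labels and classifying samples within each group instead of computing advantages. This turns policy optimization into a classification problem and addresses two issues in the GRPO loss: **gradient misassignment** for positive samples and **gradient domination** for negative samples.
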
## Background and Motivation

GRPO objective:

$$
J_{\mathrm{GRPO}}(\theta)=\mathbb{E}_{q,o\sim\pi_{\mathrm{old}}(\cdot|q)}\left[\frac{1}{|o|}\sum_{t=1}^{|o|}\left(\min\left(\rho_tA_t,\mathrm{clip}(\rho_t,1-\epsilon,1+\epsilon)A_t\right)\right)\right]
$$

where $\rho_t=\frac{\pi_\theta(o_t|q)}{\pi_{\mathrm{old}}(o_t|q)}$ is the probability ratio and $A_{t}$ is the advantage. The gradient is therefore:

$$
\nabla_{\theta} J_{\mathrm{GRPO}} = \mathbb { E } \left[ \frac { 1 } { | o | } \sum _ { t = 1 } ^ { | o | } \mathbb { I } _ { \mathrm { clip } } \cdot A _ { t } e ^ { s _ { t } } \nabla _ { \theta } \log \pi _ { \theta } \left( o _ { t } | q \right) \right]
$$

where $s_t=\log\frac{\pi_\theta(o_t|q)}{\pi_{\mathrm{old}}(o_t|q)}$ is the token's relative log-probability and $\mathbb { I } _ { \mathrm { clip } }$ is the clipping indicator.

The per-token gradient weight in GRPO is therefore:

$$
|\mathcal{W}_{\mathrm{GRPO}}|=\left\{ \begin{array} {ll}\left|A\cdot e^s\right|, & \mathrm{if~}\mathbb{I}_{\mathrm{clip}}=1, \\ 0, & \text{otherwise.} \end{array}\right.
$$

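![Gradient magnitude visualizations in GRPO](../../../../resources/real.png)

- Gradient misassignment for positive samples: for positive samples, the update magnitude shrinks as the relative log-probability $s$ decreases. This is counterintuitive: correct tokens that the model is less confident about need larger updates to be reinforced, yet more of the gradient weight goes to tokens the model is already confident about, so poorly learned tokens do not get enough attention.

- Gradient domination for negative samples: for negative samples, the update magnitude $|A\cdot e^s|$ grows exponentially as the relative log-probability $s$ increases. A few incorrect tokens on which the model is "blindly confident" can therefore produce gradients large enough to drown out the signal from the other negative samples in the same group. With no upper bound in place, the model may make excessively large parameter updates on these samples, which makes training harder to control.

To address these issues, REAL treats rewards directly as labels and trains by classifying the samples within each group.

![Real Framework](../../../../resources/real_framework.png)
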
Classification logit (score) design:

$$
\bar{s}^k=\frac{1}{|o^k|}\sum_{t=1}^{|o^k|}\left(\log\frac{\pi_\theta(o_t^k\mid q)}{\pi_{\mathrm{old}}(o_t^k\mid q)}\right)
$$

- $\bar{s}^k > 0$: the sample is, overall, more likely under the current policy than under the old policy; the model tends to **promote** it.
- $\bar{s}^k < 0$: the sample is, overall, less likely under the current policy than under the old policy; the model tends to **suppress** it.

Loss function design:

$$
\mathcal{L}_{REAL}=\log\left(1+\sum_{\mathcal{O}_+}e^{-\bar{s}^i/\tau}\right)+\log\left(1+\sum_{\mathcal{O}_-}e^{\bar{s}^j/\tau}\right)
$$

Gradient properties:

$$
|\mathcal{W}_{\mathrm{REAL}}|=
\begin{cases}
\frac{1}{\tau}\frac{1}{1+C_{+}e^{\bar{s}^{k}/\tau}}, & r=1 \\
\\
\frac{1}{\tau}\frac{1}{1+C_{-}e^{-\bar{s}^{k}/\tau}}, & r=0
\end{cases}
$$

## Parameter Settings

| Parameter     | Type    | Default | Description                                                  |
|---------------|---------|---------|--------------------------------------------------------------|
| `--loss_type` | `str`   | -       | Set to `real`                                                |
| `--real_tau`  | `float` | `0.5`   | Temperature; controls the sharpness of the decision boundary |

Training script reference

[swift](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/real.sh)

## Notes

When setting parameters, make sure `per_device_train_batch_size` is divisible by `num_generations`, so that each training batch contains complete groups for classification.
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# Rewards as Labels: Revisiting RLVR from a Classification Perspective

**Version Requirement**: ms-swift>4.0

[Rewards as Labels: Revisiting RLVR from a Classification Perspective](https://arxiv.org/abs/2602.05630) reformulates GRPO by treating rewards as labels and performing **in-group classification** instead of advantage estimation. This converts the policy optimization problem into a classification problem, thereby addressing two key issues in the GRPO loss:
- **Gradient Misassignment** for positive samples
- **Gradient Domination** for negative samples

## Background and Motivation

GRPO Objective

$$
J_{\mathrm{GRPO}}(\theta)=\mathbb{E}_{q,o\sim\pi_{\mathrm{old}}(\cdot|q)}\left[\frac{1}{|o|}\sum_{t=1}^{|o|}\left(\min\left(\rho_tA_t,\mathrm{clip}(\rho_t,1-\epsilon,1+\epsilon)A_t\right)\right)\right]
$$

where:
- $\rho_t = \frac{\pi_\theta(o_t|q)}{\pi_{\mathrm{old}}(o_t|q)}$ is the probability ratio
- $A_t$ is the advantage function

The corresponding gradient is:

$$
\nabla_{\theta} J_{\mathrm{GRPO}} = \mathbb { E } \left[ \frac { 1 } { | o | } \sum _ { t = 1 } ^ { | o | } \mathbb { I } _ { \mathrm { clip } } \cdot A _ { t } e ^ { s _ { t } } \nabla _ { \theta } \log \pi _ { \theta } \left( o _ { t } | q \right) \right]
$$

where:
- $s_t = \log \frac{\pi_\theta(o_t|q)}{\pi_{\mathrm{old}}(o_t|q)}$ is the relative log-probability, so that $\rho_t = e^{s_t}$
- $\mathbb{I}_{\mathrm{clip}}$ is the clipping indicator (1 when the token is not clipped, 0 otherwise)

Thus, the per-token gradient weight in GRPO is:

$$
|\mathcal{W}_{\mathrm{GRPO}}|=\left\{ \begin{array} {ll}\left|A\cdot e^s\right|, & \mathrm{if~}\mathbb{I}_{\mathrm{clip}}=1, \\ 0, & \text{otherwise.} \end{array}\right.
$$

![Gradient magnitude visualizations in GRPO](../../../../resources/real.png)

1. **Gradient Misassignment (Positive Samples)**
   For positive samples, as the relative log-probability $s$ decreases, the gradient magnitude also decreases.
   This is counterintuitive: tokens that are correct but that the model is less confident about should receive larger updates. Instead, GRPO assigns more weight to already confident tokens, so under-trained tokens receive insufficient learning signal.

2. **Gradient Domination (Negative Samples)**
   For negative samples, as $s$ increases, the gradient magnitude $|A\cdot e^s|$ grows exponentially.
   A few overconfident incorrect tokens can therefore dominate the gradient and overwhelm the other negative signals within the same group. Because the weight has no upper bound on this side, parameter updates may become excessively large and training unstable.
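
Both effects can be checked numerically by evaluating the weight formula above for a few values of $s$. The sketch below is illustrative only: the function name `grpo_token_weight` and the clip range `eps = 0.2` are assumptions of this example, not values taken from the paper or from ms-swift. It encodes which side of the clipped objective still passes gradient: for $A > 0$ the token is active while $\rho \le 1+\epsilon$, and for $A < 0$ while $\rho \ge 1-\epsilon$.

```python
import math


def grpo_token_weight(advantage: float, s: float, eps: float = 0.2) -> float:
    """Return |W_GRPO| for one token with relative log-probability s.

    eps is a hypothetical clip range chosen for illustration.
    """
    ratio = math.exp(s)  # rho = e^s
    if advantage > 0:
        # min(rho*A, clip(rho)*A) keeps the unclipped branch while rho <= 1 + eps
        active = ratio <= 1.0 + eps
    else:
        # for A < 0 the unclipped branch is kept while rho >= 1 - eps
        active = ratio >= 1.0 - eps
    return abs(advantage) * ratio if active else 0.0


if __name__ == "__main__":
    for s in (-1.0, -0.1, 0.0, 0.1, 1.0):
        pos = grpo_token_weight(+1.0, s)
        neg = grpo_token_weight(-1.0, s)
        print(f"s={s:+.1f}  positive={pos:.3f}  negative={neg:.3f}")
```

For positive samples the weight shrinks as $s$ falls, and it is zero once $\rho$ exceeds $1+\epsilon$. For negative samples it is zero below the lower clip and grows as $e^{s}$ above it, with nothing capping it.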

To address these issues, REAL treats rewards directly as labels and performs **group-wise classification training**.

![Real Framework](../../../../resources/real_framework.png)

The classification logit for each sample is defined as:

$$
\bar{s}^k=\frac{1}{|o^k|}\sum_{t=1}^{|o^k|}\left(\log\frac{\pi_\theta(o_t^k\mid q)}{\pi_{\mathrm{old}}(o_t^k\mid q)}\right)
$$

- $\bar{s}^k > 0$: The sample is more likely under the current policy than the old policy → the model tends to **promote** this sample
- $\bar{s}^k < 0$: The sample is less likely under the current policy → the model tends to **suppress** this sample

Loss Function

$$
\mathcal{L}_{REAL}=\log\left(1+\sum_{\mathcal{O}_+}e^{-\bar{s}^i/\tau}\right)+\log\left(1+\sum_{\mathcal{O}_-}e^{\bar{s}^j/\tau}\right)
$$
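
As a rough illustration of the score and loss above, the following sketch computes $\bar{s}^k$ and $\mathcal{L}_{REAL}$ for a single group in plain Python. The helper names (`mean_log_ratio`, `log1p_sum_exp`, `real_loss`) are hypothetical and this is not the ms-swift implementation. Returning `0.0` for a group that lacks positives or negatives is a choice of this sketch, mirroring the `valid_mask` check in the trainer change in this commit.

```python
import math


def mean_log_ratio(logp_new, logp_old):
    """Length-normalised score s_bar: mean of log pi_theta - log pi_old over tokens."""
    return sum(n - o for n, o in zip(logp_new, logp_old)) / len(logp_new)


def log1p_sum_exp(values):
    """Numerically stable log(1 + sum_i exp(v_i)): a log-sum-exp with an extra 0 logit."""
    m = max([0.0, *values])
    return m + math.log(math.exp(-m) + sum(math.exp(v - m) for v in values))


def real_loss(scores, rewards, tau=0.5):
    """REAL loss for one group; rewards are binary labels (1 = positive, 0 = negative)."""
    pos = [-s / tau for s, r in zip(scores, rewards) if r > 0]
    neg = [s / tau for s, r in zip(scores, rewards) if r <= 0]
    if not pos or not neg:
        return 0.0  # sketch assumption: a single-class group carries no signal
    return log1p_sum_exp(pos) + log1p_sum_exp(neg)


if __name__ == "__main__":
    scores = [0.10, 0.02, -0.05, -0.20]
    rewards = [1, 1, 0, 0]
    print(real_loss(scores, rewards, tau=0.5))
```

Raising the scores of positive samples or lowering those of negative samples reduces the loss, which is the classification behaviour described above.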

Gradient Properties

$$
|\mathcal{W}_{\mathrm{REAL}}|=
\begin{cases}
\frac{1}{\tau}\frac{1}{1+C_{+}e^{\bar{s}^{k}/\tau}}, & r=1 \\
\\
\frac{1}{\tau}\frac{1}{1+C_{-}e^{-\bar{s}^{k}/\tau}}, & r=0
\end{cases}
$$
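
The constants $C_{+}$ and $C_{-}$ are not defined on this page. A short derivation, assuming they collect the contributions of the other samples in the group that share the same label, recovers the weights above by differentiating $\mathcal{L}_{REAL}$ with respect to a single score $\bar{s}^k$:

$$
\begin{aligned}
k\in\mathcal{O}_+:\quad
\frac{\partial \mathcal{L}_{REAL}}{\partial \bar{s}^k}
&= -\frac{1}{\tau}\,\frac{e^{-\bar{s}^k/\tau}}{1+\sum_{i\in\mathcal{O}_+}e^{-\bar{s}^i/\tau}}
= -\frac{1}{\tau}\,\frac{1}{1+C_{+}e^{\bar{s}^k/\tau}},
&& C_{+}=1+\sum_{i\in\mathcal{O}_+,\,i\neq k}e^{-\bar{s}^i/\tau} \\
k\in\mathcal{O}_-:\quad
\frac{\partial \mathcal{L}_{REAL}}{\partial \bar{s}^k}
&= \frac{1}{\tau}\,\frac{e^{\bar{s}^k/\tau}}{1+\sum_{j\in\mathcal{O}_-}e^{\bar{s}^j/\tau}}
= \frac{1}{\tau}\,\frac{1}{1+C_{-}e^{-\bar{s}^k/\tau}},
&& C_{-}=1+\sum_{j\in\mathcal{O}_-,\,j\neq k}e^{\bar{s}^j/\tau}
\end{aligned}
$$

Under this reading $C_{\pm}\ge 1$, so both magnitudes stay below $1/\tau$. The weight on a positive sample decreases as $\bar{s}^k$ grows, so low-confidence positives receive the larger updates. The weight on a negative sample increases with $\bar{s}^k$ but saturates at $1/\tau$ instead of growing exponentially.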

## Parameter Settings

| Parameter     | Type    | Default | Description                                                   |
|---------------|---------|---------|---------------------------------------------------------------|
| `--loss_type` | `str`   | -       | Set to `real`                                                 |
| `--real_tau`  | `float` | `0.5`   | Temperature parameter controlling decision boundary sharpness |

Training Script Reference

[swift](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/real.sh)

## Important Notes

When configuring training parameters, ensure that:
- `per_device_train_batch_size` is divisible by `num_generations`

This guarantees that each training batch contains complete groups, which is required for correct in-group classification.
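
The reason for this constraint can be sketched in a few lines. The trainer change in this commit reshapes per-sample scores with `.view(-1, self.num_generations)`, which only works when the batch holds whole groups. The helper below, `split_into_groups`, is a hypothetical stand-in written for illustration; it assumes completions of the same prompt sit next to each other in the batch.

```python
def split_into_groups(values, num_generations):
    """Split a flat per-sample list into consecutive groups of size num_generations."""
    if num_generations <= 0:
        raise ValueError("num_generations must be positive")
    if len(values) % num_generations != 0:
        raise ValueError(
            f"batch of {len(values)} samples cannot be split into complete "
            f"groups of {num_generations}")
    return [values[i:i + num_generations]
            for i in range(0, len(values), num_generations)]


if __name__ == "__main__":
    # 8 samples per device with 4 generations per prompt gives 2 complete groups.
    print(split_into_groups(list(range(8)), 4))
```

With a batch size that is not a multiple of `num_generations`, the final group would be incomplete and its positive and negative samples could not be compared within the group.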
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
CUDA_VISIBLE_DEVICES=2 \
swift rollout \
    --model Qwen/Qwen3-1.7B


NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen3-1.7B \
    --dataset 'AI-MO/NuminaMath-TIR#5000' \
    --enable_thinking false \
    --reward_funcs accuracy \
    --use_vllm true \
    --vllm_mode server \
    --vllm_server_host 127.0.0.1 \
    --vllm_server_port 8000 \
    --tuner_type full \
    --torch_dtype bfloat16 \
    --load_from_cache_file true \
    --max_completion_length 4096 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --learning_rate 2e-6 \
    --gradient_accumulation_steps 1 \
    --save_total_limit 2 \
    --save_steps 500 \
    --logging_steps 1 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --num_generations 8 \
    --temperature 0.6 \
    --system """You are a helpful math assistant. Solve the problem step by step and put your final answer within \\boxed{}.""" \
    --log_completions true \
    --num_iterations 1 \
    --beta 0.001 \
    --loss_type real \
    --deepspeed zero2

swift/arguments/rlhf_args.py

Lines changed: 7 additions & 0 deletions
@@ -369,6 +369,13 @@ def _init_grpo(self):
         else:
             raise ValueError(f'Invalid advantage_estimator: {self.advantage_estimator}')

+        # disable normalization, REAL https://arxiv.org/abs/2602.05630
+        if self.loss_type == 'real':
+            self.scale_rewards = 'none'
+            logger.warning(
+                f"[REAL] scale_rewards='{self.scale_rewards}' is ignored. "
+                "It will be forced to 'none' because 'loss_type = real' does not support reward normalization.")
+
         if self.scale_rewards is None:
             if self.advantage_estimator == 'grpo':
                 self.scale_rewards = 'group'

swift/rlhf_trainers/args_mixin.py

Lines changed: 5 additions & 0 deletions
@@ -267,6 +267,8 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
         tau_neg (float): The temperature parameter for negative dominance in the SAPO algorithm, controlling the
             sharpness of the soft gating function. Typically, `tau_neg` is set > `tau_pos` to impose stronger
             constraints on negative dominance. The default value is 1.05.
+        real_tau (float): The temperature parameter. REAL induces monotonic and bounded gradient weighting with
+            magnitude upper-bounded by 1/tau. The default value is 0.5.
         advantage_estimator (Literal['grpo', 'rloo', 'reinforce_plus_plus']): The advantage estimation
             function to use. 'grpo' calculates the relative advantage within a group. Options are 'grpo', 'rloo',
             'reinforce_plus_plus'. Defaults to 'grpo'.
@@ -362,6 +364,9 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
     # If false, add KL into loss, otherwise add into reward
     kl_in_reward: Optional[bool] = None  # rloo/reinforce_plus_plus: true, grpo: false (default)

+    # REAL https://arxiv.org/abs/2602.05630
+    real_tau: float = 0.5
+
     generation_batch_size: Optional[int] = None
     steps_per_generation: Optional[int] = None
     num_generations_eval: Optional[int] = None

swift/rlhf_trainers/grpo_trainer.py

Lines changed: 39 additions & 1 deletion
@@ -1202,6 +1202,8 @@ def _compute_loss_and_metrics(self, model, inputs):
             soft_gate = torch.where(is_positive, gate_pos, gate_neg)

             per_token_loss = -soft_gate * advantages_expanded
+        elif self.loss_type == 'real':
+            per_token_loss = torch.zeros_like(per_token_logps)
         elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']:
             coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
             if self.args.delta is not None:
@@ -1240,6 +1242,39 @@ def _compute_loss_and_metrics(self, model, inputs):
         elif self.loss_type == 'dr_grpo':
             batch_size = completion_mask.shape[0]
             loss = (per_token_loss * completion_mask).sum() / (batch_size * self.max_completion_length)
+        elif self.loss_type == 'real':
+            global_scores = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+
+            group_scores = global_scores.view(-1, self.num_generations)
+            group_rewards = advantages.view(-1, self.num_generations)
+
+            pos_mask = (group_rewards > 0)
+            neg_mask = (group_rewards <= 0)
+            valid_mask = (pos_mask.sum(dim=1) != 0) & (neg_mask.sum(dim=1) != 0)
+
+            if not valid_mask.any():
+                loss = torch.tensor(0., device=global_scores.device) * global_scores.mean()
+            else:
+                batch_scores = group_scores[valid_mask]
+                batch_pos_mask = pos_mask[valid_mask]
+                batch_neg_mask = neg_mask[valid_mask]
+
+                scaled_scores = batch_scores / self.real_tau
+                zeros = torch.zeros(batch_scores.size(0), 1, device=batch_scores.device, dtype=batch_scores.dtype)
+
+                # Negative Loss: log(1 + sum(e^{S_neg}))
+                neg_input = scaled_scores.masked_fill(~batch_neg_mask, float('-inf'))
+                neg_loss = torch.logsumexp(torch.cat([neg_input, zeros], dim=1), dim=1)
+
+                # Positive Loss: log(1 + sum(e^{-S_pos}))
+                pos_input = (-scaled_scores).masked_fill(~batch_pos_mask, float('-inf'))
+                pos_loss = torch.logsumexp(torch.cat([pos_input, zeros], dim=1), dim=1)
+
+                loss = (neg_loss + pos_loss).sum() / group_rewards.size(0)
+
+            if self.beta != 0.0:
+                kl_loss = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+                loss = loss + kl_loss * self.beta
         elif self.loss_type in ['cispo', 'dapo']:
             # CISPO and DAPO: Normalize by total completion tokens across all processes
             normalizer = inputs['num_items_in_batch'] / self.accelerator.num_processes
@@ -1279,7 +1314,7 @@ def masked_batch_mean(x):
             cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
             gathered_cispo_clip_ratio = self.accelerator.gather_for_metrics(cispo_clip_ratio)
             metrics_data['clipping'] = {'cispo_clip_ratio': gathered_cispo_clip_ratio.nanmean().item()}
-        elif self.loss_type == 'sapo':
+        elif self.loss_type in ['sapo', 'real']:
             pass
         else:
             is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
@@ -2183,6 +2218,9 @@ def _prepare_algorithm_params(self):
         self.tau_pos = args.tau_pos
         self.tau_neg = args.tau_neg

+        # REAL, https://arxiv.org/abs/2602.05630
+        self.real_tau = args.real_tau
+
         # RLOO,
         self.advantage_estimator = args.advantage_estimator
         self.kl_in_reward = args.kl_in_reward
