VeRL Source Code Walkthrough

Reference article: https://zhuanlan.zhihu.com/p/27676081245
ReMax algorithm explainer: https://zhuanlan.zhihu.com/p/662191782

1. verl.trainer.ppo.ray_trainer.py

  • The apply_kl_penalty function computes PPO's token-level KL reward, i.e. token_level_rewards = token_level_scores - β · kl(π_θ, π_ref); a toy example follows the code:

def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
"""Apply KL penalty to the token-level rewards.

This function computes the KL divergence between the reference policy and current policy,
then applies a penalty to the token-level rewards based on this divergence.

Args:
data (DataProto): The data containing batched model outputs and inputs.
kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty.
kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl".

Returns:
tuple: A tuple containing:
- The updated data with token-level rewards adjusted by KL penalty
- A dictionary of metrics related to the KL penalty
"""
# marks which positions are model-generated response tokens (1) rather than prompt tokens (0)
response_mask = data.batch["response_mask"]
token_level_scores = data.batch["token_level_scores"]
batch_size = data.batch.batch_size[0]

# compute kl between ref_policy and current policy
# When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
kld = core_algos.kl_penalty(
data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
) # (batch_size, response_length)
kld = kld * response_mask
beta = kl_ctrl.value

token_level_rewards = token_level_scores - beta * kld

current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item()

# according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
data.batch["token_level_rewards"] = token_level_rewards

metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta}

return data, metrics
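To make the reward shaping above concrete, here is a toy example using the simple k1 estimator kl = old_log_prob - ref_log_prob (the default "kl" penalty type). All tensor values below are made up for illustration:

import torch

# hypothetical micro-batch: 2 responses, 4 response positions each
old_log_probs = torch.tensor([[-1.0, -0.5, -0.8, -0.3],
                              [-0.2, -1.2, -0.6, -0.9]])
ref_log_prob = torch.tensor([[-1.1, -0.6, -0.7, -0.3],
                             [-0.4, -1.0, -0.6, -1.1]])
response_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],   # last position of seq 0 is padding
                              [1.0, 1.0, 1.0, 1.0]])
token_level_scores = torch.zeros(2, 4)
token_level_scores[0, 2] = 1.0   # outcome reward placed on the last valid token
token_level_scores[1, 3] = 0.5

beta = 0.05                                            # kl_ctrl.value
kld = (old_log_probs - ref_log_prob) * response_mask   # k1 estimator, masked
token_level_rewards = token_level_scores - beta * kld

# per-sequence mean KL, then batch mean -> used to update the adaptive KL controller
current_kl = (kld.sum(-1) / response_mask.sum(-1)).mean().item()
print(token_level_rewards)
print(current_kl)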
  • The fit method of the RayPPOTrainer class implements the complete RL training loop and calls the individual workers to do the actual computation.
  1. The REINFORCE algorithm
  • Core idea: run one episode; if the total return is high, increase the probability of every action taken in that episode, otherwise decrease it. The update is scaled directly by the whole-episode return $G_t$ (playing the role of PPO's advantage).
  • Pseudocode (a minimal PyTorch sketch follows it):
Initialize policy parameters θ
Set learning rate α
for episode = 1 to M do:
    sample a trajectory from π(·|s; θ): s0, a0, r1, s1, a1, r2, ..., sT
    compute the return Gt for every time step t
    for t = 0 to T-1 do:
        θ = θ + α * γ^t * Gt * ∇_θ log π_θ(a_t | s_t)
    end for
end for
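A minimal PyTorch sketch of this update rule for a tiny tabular softmax policy. The environment, episode length, and reward below are made-up placeholders with no relation to verl; they only show the "scale the log-prob gradient by the return" idea:

import torch

torch.manual_seed(0)
n_states, n_actions, gamma, lr = 4, 3, 0.99, 0.1
theta = torch.zeros(n_states, n_actions, requires_grad=True)  # tabular softmax policy

def sample_episode():
    # placeholder environment: deterministic transitions, toy reward
    states, actions, rewards = [], [], []
    s = 0
    for _ in range(5):
        probs = torch.softmax(theta[s], dim=-1)
        a = torch.multinomial(probs, 1).item()
        states.append(s)
        actions.append(a)
        rewards.append(float(a == s % n_actions))  # reward 1 for the "right" action
        s = (s + a) % n_states
    return states, actions, rewards

for _ in range(200):
    states, actions, rewards = sample_episode()
    # compute the return G_t for every time step t
    G, returns = 0.0, []
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
    # REINFORCE: ascend gamma^t * G_t * grad log pi(a_t | s_t)
    loss = torch.zeros(())
    for t, (s, a, G_t) in enumerate(zip(states, actions, returns)):
        logp = torch.log_softmax(theta[s], dim=-1)[a]
        loss = loss - (gamma ** t) * G_t * logp
    theta.grad = None
    loss.backward()
    with torch.no_grad():
        theta -= lr * theta.grad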
  2. The ReMax algorithm
  • Core idea: ReMax builds on REINFORCE by using the reward of a greedily generated response (greedy response, i.e. do_sample = False) as the baseline value, and taking the reward of the sampled response (sampled response, i.e. do_sample = True) minus that baseline as the advantage.
    Pseudocode: Figure 2 (image not reproduced here); a minimal sketch of the advantage computation follows.
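A minimal sketch of the ReMax advantage, assuming you already have one sampled rollout and one greedy rollout per prompt and a scalar sequence-level reward for each (the function name and numbers below are illustrative, not verl APIs):

import torch

def remax_advantages(sampled_rewards: torch.Tensor, greedy_rewards: torch.Tensor) -> torch.Tensor:
    """sampled_rewards / greedy_rewards: (batch,) scalar rewards per prompt.

    ReMax uses the reward of the greedy response (do_sample=False) of the same
    prompt as a baseline for the sampled response (do_sample=True).
    """
    return sampled_rewards - greedy_rewards

sampled = torch.tensor([0.9, 0.1, 0.6])   # rewards of sampled responses
greedy = torch.tensor([0.7, 0.4, 0.6])    # rewards of greedy baseline responses
print(remax_advantages(sampled, greedy))  # tensor([ 0.2000, -0.3000,  0.0000])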
def fit(self):
"""
The training loop of PPO.
The driver process only need to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
...

# outer loop: iterate over all epochs
# inner loop: iterate over every batch from the training dataloader
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}

with marked_timer("start_profile", timing_raw):
self._start_profiling(
not prev_step_profile and curr_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)

batch: DataProto = DataProto.from_single_dict(batch_dict)

# add uid to batch
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# pop the keys needed for response generation out of the batch dict
gen_batch = self._get_gen_batch(batch)

# pass global_steps to trace
gen_batch.meta_info["global_steps"] = self.global_steps
# for each prompt, sample n responses (i.e. num sample times). set it to values > 1 for grpo, rloo
# repeat each prompt in gen_batch n times
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)

is_last_step = self.global_steps >= self.total_training_steps

with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, color="red"):
# generate sequences with the actor model (synchronous or asynchronous mode)
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)

# for REMAX advantage estimation, additionally generate greedy (non-sampled) baseline sequences and compute their rewards
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
if self.reward_fn is None:
raise ValueError("A reward_fn is required for REMAX advantage estimation.")

with marked_timer("gen_max", timing_raw, color="purple"):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
if not self.async_rollout_mode:
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
else:
gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

batch.batch["reward_baselines"] = reward_baseline_tensor

del gen_baseline_batch, gen_baseline_output

# repeat to align with repeated responses in rollout
# repeat the remaining fields of the batch so they align with the n repeated rollouts
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

if "response_mask" not in batch.batch.keys():
batch.batch["response_mask"] = compute_response_mask(batch)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
# TODO: Decouple the DP balancing and mini-batching.
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)

# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

with marked_timer("reward", timing_raw, color="yellow"):
# compute reward model score
# use the reward model to score the batch if enabled; otherwise use the rule-based reward
if self.use_rm:
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

if self.config.reward_model.launch_reward_fn_async:
future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

# recompute old_log_probs
with marked_timer("old_log_prob", timing_raw, color="blue"):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
# policy entropy
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)

if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
from verl.utils.debug.metrics import calculate_debug_metrics

metrics.update(calculate_debug_metrics(batch))
# compute the reference model's log_prob if a reference policy is used
if self.use_reference_policy:
# compute reference log_prob
with marked_timer("ref", timing_raw, color="olive"):
# if ref_in_actor is True, the reference policy will be actor without lora applied
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

# compute values
# compute values with the critic if one is used
if self.use_critic:
with marked_timer("values", timing_raw, color="cyan"):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)

with marked_timer("adv", timing_raw, color="brown"):
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
# token-level scores obtained from the reward function
batch.batch["token_level_scores"] = reward_tensor

if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

# compute rewards. apply_kl_penalty if available
# if enabled, add the KL-divergence penalty to obtain the token-level rewards
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

# compute advantages, executed on the driver process

norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor

batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
config=self.config.algorithm,
)

# update critic
# update the critic if one is used
if self.use_critic:
with marked_timer("update_critic", timing_raw, color="pink"):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)

# implement critic warmup
# only update the actor after the critic warmup has finished
# early in training, the critic is first trained on its own for a while so that it can provide reasonably accurate value estimates
# at the very beginning the critic's predictions may be highly inaccurate; updating the actor with them right away can make training unstable or even diverge
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with marked_timer("update_actor", timing_raw, color="red"):
batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
...
  • PPO update loop flow (Figure 3; image not reproduced here); a summary of the driver-side dataflow follows.

  • DeepSpeed PPO data loop flow (Figure 4; image not reproduced here).
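Since the two flow diagrams are not reproduced here, the list below summarizes the order of one fit() step as it appears in the code above; PPO_STEP_FLOW is just an illustrative name, while the step labels match the marked_timer labels in the code:

# order of operations in one RayPPOTrainer.fit step, driven from the driver process
PPO_STEP_FLOW = [
    ("gen",           "actor_rollout_wg.generate_sequences(gen_batch)"),
    ("gen_max",       "greedy baseline rollout + reward, REMAX only"),
    ("reward",        "rm_wg.compute_rm_score and/or rule-based reward_fn"),
    ("old_log_prob",  "actor_rollout_wg.compute_log_prob(batch)"),
    ("ref",           "ref_policy_wg.compute_ref_log_prob(batch), only with a reference policy"),
    ("values",        "critic_wg.compute_values(batch), only with a critic"),
    ("adv",           "apply_kl_penalty (optional) + compute_advantage, on the driver"),
    ("update_critic", "critic_wg.update_critic(batch), only with a critic"),
    ("update_actor",  "actor_rollout_wg.update_actor(batch), after critic warmup"),
]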

2. verl.workers.actor.dp_actor.py

  • The _forward_micro_batch method of the DataParallelPPOActor class computes the entropy and log_probs for each micro_batch of data.
def _forward_micro_batch(
self, micro_batch, temperature, calculate_entropy=False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
entropy: # (bs, response_len)
log_probs: # (bs, response_len)
"""
...

else: # not using rmpad and no ulysses sp
extra_args = {}
if self.use_fused_kernels:
extra_args["temperature"] = temperature
extra_args["return_dict"] = True

output = self.actor_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**multi_modal_inputs,
use_cache=False,
**extra_args,
) # prevent model thinks we are generating

if self.use_fused_kernels:
log_probs = output.log_probs[:, -response_length - 1 : -1]
entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length)

else:
logits = output.logits

logits.div_(temperature)
logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size)
# Compute per-token log-probabilities for the given labels.
# logp = F.log_softmax(logits, dim=-1)
# log_probs = gather_from_labels(logp, labels)
log_probs = logprobs_from_logits(logits, micro_batch["responses"]) # (bsz, response_length)
if calculate_entropy:
if not self.config.entropy_checkpointing:
# pd = torch.nn.functional.softmax(logits, dim=-1)
# entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
# H(P) = -∑ [ pd * (logits - log(Z)) ]
# = -∑ [ pd * logits ] + log(Z) * ∑ pd
# = -∑ [ pd * logits ] + log(Z) [since ∑ pd = 1]
entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
else:
entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)

return entropy, log_probs
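The entropy identity in the comments above (H(P) = log Z - Σ pd · logits) can be checked numerically with plain PyTorch; this is only a reference check, not the fused verl kernel:

import torch

def entropy_from_logits_reference(logits: torch.Tensor) -> torch.Tensor:
    # H(P) = logsumexp(logits) - sum(softmax(logits) * logits)
    pd = torch.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)

logits = torch.randn(2, 5, 11)                            # (bsz, response_length, vocab_size)
log_pd = torch.log_softmax(logits, dim=-1)
entropy_direct = -(log_pd.exp() * log_pd).sum(dim=-1)     # -sum p * log p

assert torch.allclose(entropy_from_logits_reference(logits), entropy_direct, atol=1e-4)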
  • The update_policy method of the DataParallelPPOActor class updates the policy model (actor) with the collected experience via multiple epochs of mini-batch gradient descent.
def update_policy(self, data: DataProto):
# make sure we are in training mode
self.actor_module.train()

temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error
# use data.select() to keep only the fields needed for training, avoiding unnecessary data transfer and memory usage
select_keys = [
"responses",
"response_mask",
"input_ids",
"attention_mask",
"position_ids",
"old_log_probs",
"advantages",
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")

has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []

data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
# split the data into mini-batches for multiple updates
# if the whole dataset is a single mini-batch and only one epoch is trained, the current policy is identical to the policy that collected the data, i.e. the classic on-policy setting
mini_batches = data.split(self.config.ppo_mini_batch_size)
# check whether the update is on-policy
on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1

metrics = {}
# outer loop: iterate over the whole dataset self.config.ppo_epochs times
# inner loop: update the model with one mini_batch at a time
for _ in range(self.config.ppo_epochs):
for batch_idx, mini_batch in enumerate(mini_batches):
# split the mini_batch further into micro_batches
# if use_dynamic_bsz is enabled, the mini_batch is split dynamically by max_token_len (e.g. 4096 tokens) so that no micro-batch exceeds the token budget; this helps when sequence lengths vary a lot and improves GPU memory utilization
# otherwise the mini_batch is split evenly into fixed-size micro_batches of self.config.ppo_micro_batch_size_per_gpu, and the number of gradient-accumulation steps is computed
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)
else:
self.gradient_accumulation = (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
)
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

self.actor_optimizer.zero_grad()
# micro_batch loop: data is sharded at the device level
for micro_batch in micro_batches:
micro_batch = micro_batch.to(get_device_id())
micro_batch_metrics = {}
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
advantages = model_inputs["advantages"]

entropy_coeff = self.config.entropy_coeff
loss_agg_mode = self.config.loss_agg_mode

if self.config.use_dynamic_bsz:
loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
else:
loss_scale_factor = 1 / self.gradient_accumulation

# all return: (bsz, response_length)
calculate_entropy = False
if entropy_coeff != 0:
calculate_entropy = True
# call _forward_micro_batch to compute, under the current actor, the log probabilities (log_prob) and entropy of the generated responses for the given inputs
entropy, log_prob = self._forward_micro_batch(
model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
)

if on_policy:
old_log_prob = log_prob.detach()
else:
old_log_prob = model_inputs["old_log_probs"]

loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
# vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
# gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
# clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
# compute pg_loss
policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
)
# compute entropy_loss if the entropy bonus is enabled
if entropy_coeff != 0:
entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

# compute policy loss
policy_loss = pg_loss - entropy_loss * entropy_coeff
else:
policy_loss = pg_loss
# compute kl_loss if the KL loss is enabled
if self.config.use_kl_loss:
ref_log_prob = model_inputs["ref_log_prob"]
# compute kl loss
kld = kl_penalty(
logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type
)
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor
micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef

if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = policy_loss * loss_scale_factor
else:
loss = policy_loss * loss_scale_factor
# backward pass to accumulate gradients
loss.backward()

micro_batch_metrics.update(
{
"actor/pg_loss": pg_loss.detach().item() * loss_scale_factor,
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
"actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
}
)
append_to_dict(metrics, micro_batch_metrics)
# optimizer step: update the model parameters
grad_norm = self._optimizer_step()
mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, mini_batch_metrics)
self.actor_optimizer.zero_grad()
return metrics
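To make the loss scaling above concrete: with fixed micro-batches, every micro-batch loss is divided by the number of gradient-accumulation steps, and with dynamic batching it is weighted by the fraction of the mini-batch's sequences it contains, so in both cases the accumulated gradient corresponds to one pass over the full mini-batch. The config values below are made up:

# fixed micro-batch size (use_dynamic_bsz = False)
ppo_mini_batch_size = 256
ppo_micro_batch_size_per_gpu = 32
gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu   # 8 micro-batches
fixed_scale = 1 / gradient_accumulation                                        # 0.125 each

# dynamic batching (use_dynamic_bsz = True): micro-batches contain uneven numbers of sequences
micro_batch_sizes = [40, 56, 72, 88]                     # made-up sizes, sum to 256
dynamic_scales = [n / ppo_mini_batch_size for n in micro_batch_sizes]

assert abs(8 * fixed_scale - 1.0) < 1e-9                 # weights sum to one mini-batch
assert abs(sum(dynamic_scales) - 1.0) < 1e-9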

3. verl.workers.critic.dp_critic.py

  • The _forward_micro_batch method of the DataParallelPPOCritic class computes the values for each micro_batch of data.
def _forward_micro_batch(self, micro_batch):
response_length = micro_batch["responses"].size(-1)
multi_modal_inputs = {}
if "multi_modal_inputs" in micro_batch.keys():
for key in micro_batch["multi_modal_inputs"][0].keys():
multi_modal_inputs[key] = torch.cat(
[inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
)

with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
input_ids = micro_batch["input_ids"]
batch, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1)

if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)

# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
)

# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.critic_module(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
**multi_modal_inputs,
use_cache=False,
) # prevent model thinks we are generating

if hasattr(self.critic_module, "v_head"):
# For trl.AutoModelForCausalLMWithValueHead
values_rmpad = output[2].squeeze(0).unsqueeze(-1)
else:
values_rmpad = output.logits
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)

# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
values_rmpad = gather_outputs_and_unpad(
values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
)

# pad it back
values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
values = values[:, -response_length - 1 : -1]
else:
output = self.critic_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**multi_modal_inputs,
use_cache=False,
) # prevent model thinks we are generating
if hasattr(self.critic_module, "v_head"):
# For trl.AutoModelForCausalLMWithValueHead
values = output[2]
else:
values = output.logits
values = values[:, -response_length - 1 : -1].squeeze(-1)
return values
  • The update_critic method of the DataParallelPPOCritic class updates the value model (critic) with the collected experience via multiple epochs of mini-batch gradient descent.
def update_critic(self, data: DataProto):
# make sure we are in training mode
self.critic_module.train()
metrics = {}
# use data.select() to keep only the fields needed for training, avoiding unnecessary data transfer and memory usage
select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids", "values", "returns"]
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []

data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
# split the data into mini-batches for multiple updates
# if the whole dataset is a single mini-batch and only one epoch is trained, the current policy is identical to the policy that collected the data, i.e. the classic on-policy setting
mini_batches = data.split(self.config.ppo_mini_batch_size)
# outer loop: iterate over the whole dataset self.config.ppo_epochs times
# inner loop: update the model with one mini_batch at a time
for _ in range(self.config.ppo_epochs):
for batch_idx, mini_batch in enumerate(mini_batches):
# split the mini_batch further into micro_batches
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)
else:
self.gradient_accumulation = (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
)
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

self.critic_optimizer.zero_grad()
# micro_batch loop: data is sharded at the device level
for micro_batch in micro_batches:
micro_batch = micro_batch.to(get_device_id())
micro_batch_metrics = {}
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
values = model_inputs["values"]
returns = model_inputs["returns"]

vpreds = self._forward_micro_batch(model_inputs)
vf_loss, vf_clipfrac = core_algos.compute_value_loss(
vpreds=vpreds,
values=values,
returns=returns,
response_mask=response_mask,
cliprange_value=self.config.cliprange_value,
loss_agg_mode=self.config.loss_agg_mode,
)
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
loss = vf_loss * loss_scale_factor
else:
loss_scale_factor = 1 / self.gradient_accumulation
loss = vf_loss * loss_scale_factor

loss.backward()

micro_batch_metrics.update(
{
"critic/vf_loss": vf_loss.detach().item() * loss_scale_factor,
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
}
)

append_to_dict(metrics, micro_batch_metrics)

grad_norm = self._optimizer_step()
mini_batch_metrics = {"critic/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, mini_batch_metrics)
self.critic_optimizer.zero_grad()
return metrics

4. trl.models.modeling_value_head.py

  • The ValueHead class defines the value head as nn.Linear(hidden_size, 1).
  • The forward method of the AutoModelForCausalLMWithValueHead class defines the forward pass of a causal LM with a value head attached.
class ValueHead(nn.Module):
r"""
The ValueHead class implements a head for GPT2 that returns a scalar for each output token.
"""

def __init__(self, config, **kwargs):
super().__init__()
if not hasattr(config, "summary_dropout_prob"):
summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1)
else:
summary_dropout_prob = config.summary_dropout_prob

self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()

# some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m
if hasattr(config, "hidden_size"):
hidden_size = config.hidden_size
if hasattr(config, "word_embed_proj_dim"):
hidden_size = config.word_embed_proj_dim
elif hasattr(config, "is_encoder_decoder"):
if config.is_encoder_decoder and hasattr(config, "decoder"):
if hasattr(config.decoder, "hidden_size"):
hidden_size = config.decoder.hidden_size

self.summary = nn.Linear(hidden_size, 1)

self.flatten = nn.Flatten()

def forward(self, hidden_states):
output = self.dropout(hidden_states)

# For now force upcast in fp32 if needed. Let's keep the
# output in fp32 for numerical stability.
if output.dtype != self.summary.weight.dtype:
output = output.to(self.summary.weight.dtype)

output = self.summary(output)
return output

class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
def __init__(self, pretrained_model, **kwargs):

super().__init__(pretrained_model, **kwargs)
v_head_kwargs, _, _ = self._split_kwargs(kwargs)
self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
self._init_weights(**v_head_kwargs)

def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
return_past_key_values=False,
**kwargs,
):

kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples
kwargs["past_key_values"] = past_key_values

if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
kwargs.pop("past_key_values")

base_model_output = self.pretrained_model(
input_ids=input_ids,
attention_mask=attention_mask,
**kwargs,
)

last_hidden_state = base_model_output.hidden_states[-1]
lm_logits = base_model_output.logits
loss = base_model_output.loss

if last_hidden_state.device != self.v_head.summary.weight.device:
last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)

value = self.v_head(last_hidden_state).squeeze(-1)

# force upcast in fp32 if logits are in half-precision
if lm_logits.dtype != torch.float32:
lm_logits = lm_logits.float()

if return_past_key_values:
return (lm_logits, loss, value, base_model_output.past_key_values)
else:
return (lm_logits, loss, value)


5. verl.workers.fsdp_workers.py

  • The _forward_micro_batch method of the RewardModelWorker class computes the reward-model score and extracts the score at the last valid token as the reward of the whole sequence (a toy example of the index extraction follows the code).
def _forward_micro_batch(self, micro_batch):
if is_cuda_available:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
elif is_npu_available:
from transformers.integrations.npu_flash_attention import (
index_first_axis,
pad_input,
rearrange,
unpad_input,
)

from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs

with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)

if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)

# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
)

# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.reward_module(
input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False
)
reward_rmpad = output.logits
reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz)

# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
reward_rmpad = gather_outputs_and_unpad(
reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
)

# pad it back
rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
else:
output = self.reward_module(
input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
)
rm_score = output.logits # (batch_size, seq_len, 1)
rm_score = rm_score.squeeze(-1)

# extract the result of the last valid token
eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,)
rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]
return rm_score
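A toy illustration of the index extraction at the end of the function: argmax(position_ids * attention_mask) picks the last non-padded position of each sequence, and the score at that position becomes the sequence-level reward. Shapes and values below are made up:

import torch

# 2 right-padded sequences of length 6; seq 0 has 4 valid tokens, seq 1 has 6
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])
position_ids = torch.arange(6).expand(2, 6)                            # (bsz, seqlen)
rm_score = torch.tensor([[0.1, 0.2, 0.3, 0.9, 0.0, 0.0],
                         [0.0, 0.1, 0.2, 0.3, 0.4, 0.7]])

eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)     # tensor([3, 5])
seq_reward = rm_score[torch.arange(2), eos_mask_idx]                   # tensor([0.9000, 0.7000])
print(eos_mask_idx, seq_reward)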

6. verl.trainer.ppo.core_algos.py

  • The compute_gae_advantage_return function computes GAE advantages and returns.
def compute_gae_advantage_return(
token_level_rewards: torch.Tensor,
values: torch.Tensor,
response_mask: torch.Tensor,
gamma: torch.Tensor,
lam: torch.Tensor,
):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

Args:
token_level_rewards: `(torch.Tensor)`
shape is (bs, response_length)
values: `(torch.Tensor)`
shape is (bs, response_length)
response_mask: `(torch.Tensor)`
shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
gamma is `(float)`
discounted factor used in RL
lam: `(float)`
lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)

"""
with torch.no_grad():
nextvalues = 0
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]

for t in reversed(range(gen_len)):

# TD error
# δ_t = r_t + γ * V(s_{t+1}) - V(s_t)
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]

# GAE is an exponentially weighted average of k-step advantage estimates
# A_t^GAE(γ,λ) = Σ_{l=0}^{∞} (γλ)^l δ_{t+l}
# which can be written recursively as A_t = δ_t + γλ * A_{t+1}
lastgaelam_ = delta + gamma * lam * lastgaelam

# skip values and TD-error on observation tokens
# let m_t ∈ {0, 1} be the response-token mask (observation tokens have m_t = 0)
# V_{t+1} = m_t * V_t + (1 - m_t) * V_{t+1}
# A_{t+1} = m_t * A_t + (1 - m_t) * A_{t+1}
# when m_t = 1 (response token): compute the TD error and advantage as usual
# when m_t = 0 (observation token): carry the next state's value and advantage through unchanged
nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues
lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam

advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)

returns = advantages + values
advantages = verl_F.masked_whiten(advantages, response_mask)
return advantages, returns
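A small reference computation of the recursion above for a single 3-token response (before the final whitening step); the rewards, values, γ, and λ are made up:

import torch

rewards = torch.tensor([[0.0, 0.0, 1.0]])   # outcome reward only on the last token
values = torch.tensor([[0.5, 0.6, 0.7]])
mask = torch.ones(1, 3)
gamma, lam = 1.0, 0.95

nextvalues, lastgaelam, adv_rev = 0.0, 0.0, []
for t in reversed(range(3)):
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]              # TD error δ_t
    lastgaelam_ = delta + gamma * lam * lastgaelam                         # A_t = δ_t + γλ A_{t+1}
    nextvalues = values[:, t] * mask[:, t] + (1 - mask[:, t]) * nextvalues
    lastgaelam = lastgaelam_ * mask[:, t] + (1 - mask[:, t]) * lastgaelam
    adv_rev.append(lastgaelam)

advantages = torch.stack(adv_rev[::-1], dim=1)   # ≈ [[0.466, 0.385, 0.300]]
returns = advantages + values                    # ≈ [[0.966, 0.985, 1.000]]
print(advantages, returns)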
  • The compute_grpo_outcome_advantage function computes GRPO / Dr.GRPO advantages.
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
norm_adv_by_std_in_grpo: bool = True,
config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).

Args:
token_level_rewards: `(torch.Tensor)`
shape is (bs, response_length)
response_mask: `(torch.Tensor)`
shape is (bs, response_length)
index: `(np.ndarray)`
index array for grouping
epsilon: `(float)`
small value to avoid division by zero
norm_adv_by_std_in_grpo: `(bool)`
whether to scale the GRPO advantage
config: `(Optional[AlgoConfig])`
algorithm configuration object

Note:
If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.
If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).

Returns:
advantages: `(torch.Tensor)`
shape is (bs, response_length)
Returns: `(torch.Tensor)`
shape is (bs, response_length)
"""
scores = token_level_rewards.sum(dim=-1)

id2score = defaultdict(list)
id2mean = {}
id2std = {}

with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
scores_tensor = torch.stack(id2score[idx])
id2mean[idx] = torch.mean(scores_tensor)
id2std[idx] = torch.std(scores_tensor)
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
if norm_adv_by_std_in_grpo:
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
else:
scores[i] = scores[i] - id2mean[index[i]]
scores = scores.unsqueeze(-1) * response_mask

return scores, scores
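A toy example of the group normalization: four sampled responses for the same prompt (same uid) with outcome rewards [1, 0, 1, 1]; every response token then carries that response's normalized score. Numbers are illustrative only:

import torch

rewards = torch.tensor([1.0, 0.0, 1.0, 1.0])   # one group: 4 responses to the same prompt
mean, std = rewards.mean(), rewards.std()       # 0.75 and 0.5 (unbiased std)
eps = 1e-6

grpo_adv = (rewards - mean) / (std + eps)       # ≈ [ 0.5, -1.5,  0.5,  0.5]
dr_grpo_adv = rewards - mean                    # =  [ 0.25, -0.75, 0.25, 0.25]  (no std scaling)
print(grpo_adv, dr_grpo_adv)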
  • The compute_grpo_passk_outcome_advantage function computes Pass@k-GRPO advantages.
def compute_grpo_passk_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
norm_adv_by_std_in_grpo: bool = True,
config: Optional[AlgoConfig] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for Pass@k using a GRPO-style outcome reward formulation.
Only the best response per group gets a non-zero advantage: r_max - r_second_max.

Implemented as described in https://arxiv.org/abs/2503.19595.

Args:
token_level_rewards: (bs, response_length)
response_mask: (bs, response_length)
index: (bs,) → group ID per sample
epsilon: float for numerical stability
config: (AlgoConfig) algorithm settings, which contains "norm_adv_by_std_in_grpo"

Returns:
advantages: (bs, response_length)
returns: (bs, response_length)
"""
assert config is not None
# if True, normalize advantage by std within group
norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True)
scores = token_level_rewards.sum(dim=-1) # (bs,)
advantages = torch.zeros_like(scores)

id2scores = defaultdict(list)
id2indices = defaultdict(list)

with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
idx = index[i]
id2scores[idx].append(scores[i])
id2indices[idx].append(i)

for idx in id2scores:
rewards = torch.stack(id2scores[idx]) # (k,)
if rewards.numel() < 2:
raise ValueError(
f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}."
)
topk, topk_idx = torch.topk(rewards, 2)
r_max, r_second_max = topk[0], topk[1]
i_max = id2indices[idx][topk_idx[0].item()]
advantage = r_max - r_second_max
if norm_adv_by_std_in_grpo:
std = torch.std(rewards)
advantage = advantage / (std + epsilon)
advantages[i_max] = advantage

advantages = advantages.unsqueeze(-1) * response_mask
return advantages, advantages
  • The compute_reinforce_plus_plus_baseline_outcome_advantage function computes REINFORCE++-baseline advantages.
def compute_reinforce_plus_plus_baseline_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: torch.Tensor,
epsilon: float = 1e-6,
config: Optional[AlgoConfig] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward
(with only one scalar reward for each response).

Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
config: (AlgoConfig) algorithm config

Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)

id2score = defaultdict(list)
id2mean = {}

with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.stack(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = scores[i] - id2mean[index[i]]

scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
scores = verl_F.masked_whiten(scores, response_mask) * response_mask

return scores, scores
  • The compute_rloo_outcome_advantage function computes RLOO advantages.
def compute_rloo_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
config: Optional[AlgoConfig] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740

Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
config: (AlgoConfig) algorithm config

Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
scores = token_level_rewards.sum(dim=-1)

id2score = defaultdict(list)
id2mean = {}

with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.stack(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
response_num = len(id2score[index[i]])
if response_num > 1:
scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (
response_num - 1
)
scores = scores.unsqueeze(-1) * response_mask

return scores, scores
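The in-place update above is the leave-one-out baseline written as a single expression: r_i · n/(n−1) − mean · n/(n−1) equals r_i − mean(r_j, j ≠ i). A quick numeric check with made-up rewards:

import torch

rewards = torch.tensor([1.0, 0.0, 0.5, 1.0])    # one group of n = 4 responses
n = rewards.numel()
mean = rewards.mean()

rloo_formula = rewards * n / (n - 1) - mean * n / (n - 1)
leave_one_out = torch.stack([r - (rewards.sum() - r) / (n - 1) for r in rewards])

assert torch.allclose(rloo_formula, leave_one_out)
print(rloo_formula)   # ≈ tensor([ 0.5000, -0.8333, -0.1667,  0.5000])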
  • The compute_opo_outcome_advantage function computes OPO advantages.
def compute_opo_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
config: Optional[AlgoConfig] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585

Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
config: (AlgoConfig) algorithm config

Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = response_mask.sum(dim=-1)
scores = token_level_rewards.sum(dim=-1)

id2score = defaultdict(list)
id2len = defaultdict(list)
id2bsl = {}

with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
id2len[index[i]].append(response_length[i])

for idx in id2score:
if len(id2score[idx]) == 1:
id2bsl[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
score_tensor = torch.stack(id2score[idx])
len_tensor = torch.stack(id2len[idx])
id2bsl[idx] = (len_tensor * score_tensor).sum() / len_tensor.sum()
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = scores[i] - id2bsl[index[i]]
scores = scores.unsqueeze(-1) * response_mask

return scores, scores
  • The compute_reinforce_plus_plus_outcome_advantage function computes REINFORCE++ advantages.
def compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262

Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
config: (AlgoConfig) algorithm config

Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
assert config is not None
gamma = config.gamma
with torch.no_grad():
returns = torch.zeros_like(token_level_rewards)
running_return = 0

for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * response_mask[:, t]

advantages = verl_F.masked_whiten(returns, response_mask)
advantages = advantages * response_mask

return advantages, returns
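A toy check of the backward accumulation above (before whitening): with the outcome reward on the last token and γ < 1, earlier tokens receive discounted returns. Values are made up:

import torch

token_level_rewards = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
response_mask = torch.ones(1, 4)
gamma = 0.9

returns = torch.zeros_like(token_level_rewards)
running_return = 0.0
for t in reversed(range(token_level_rewards.shape[1])):
    running_return = token_level_rewards[:, t] + gamma * running_return
    returns[:, t] = running_return
    running_return = running_return * response_mask[:, t]   # reset beyond EOS / padding

print(returns)   # tensor([[0.7290, 0.8100, 0.9000, 1.0000]])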
  • The compute_remax_outcome_advantage function computes ReMax advantages.
def compute_remax_outcome_advantage(
token_level_rewards: torch.Tensor,
reward_baselines: torch.Tensor,
response_mask: torch.Tensor,
config: Optional[AlgoConfig] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for ReMax, operating only on Outcome reward
This implementation is based on the paper: https://arxiv.org/abs/2310.10505
(with only one scalar reward for each response).

Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
reward_baselines: `(torch.Tensor)`
shape: (bs,)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
config: (AlgoConfig) algorithm config

Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""

with torch.no_grad():
returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1) * response_mask

return advantages, returns
  • The compute_gpg_outcome_advantage function computes GPG advantages.
def compute_gpg_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
f_norm: float = 1.0,
alpha: float = 1.0,
config=None,
**kwargs,
):
"""
Compute advantage for GPG, operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
index: `(np.ndarray)`
shape: (bs,)
epsilon: (float)
f_norm: (float)
alpha: (float)
config: (dict) algorithm config

Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
scores = token_level_rewards.sum(dim=-1)

id2score = defaultdict(list)
id2mean = {}
id2std = {}

with torch.no_grad():
bsz = scores.shape[0]
m = torch.count_nonzero(scores)
alpha = bsz / m.clamp(min=1)

for i in range(bsz):
id2score[index[i]].append(scores[i])

for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
scores_tensor = torch.stack(id2score[idx])
id2mean[idx] = torch.mean(scores_tensor)
id2std[idx] = torch.std(scores_tensor)
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm)
scores = scores.unsqueeze(-1) * response_mask

return scores, scores
  • The compute_rewards function computes token-level rewards with a KL-divergence penalty.
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
"""Compute token-level rewards with KL penalty.

Args:
token_level_scores (torch.Tensor): Token-level reward scores.
old_log_prob (torch.Tensor): Log probabilities from current policy.
ref_log_prob (torch.Tensor): Log probabilities from reference policy.
kl_ratio (float): KL penalty coefficient.

Returns:
torch.Tensor: Token-level rewards with KL penalty applied.
"""
kl = old_log_prob - ref_log_prob
return token_level_scores - kl * kl_ratio
  • The agg_loss function aggregates the loss matrix under different normalization schemes (a comparison example follows the code).
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
"""
Aggregate the loss matrix into a scalar.

Args:
loss_mat: `(torch.Tensor)`:
shape: (bs, response_length)
loss_mask: `(torch.Tensor)`:
shape: (bs, response_length)
loss_agg_mode: (str) choices:
method to aggregate the loss matrix into a scalar.
Returns:
loss: `a scalar torch.Tensor`
aggregated loss
"""
if loss_agg_mode == "token-mean":
loss = verl_F.masked_mean(loss_mat, loss_mask)
elif loss_agg_mode == "seq-mean-token-sum":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum
loss = torch.mean(seq_losses) # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean
loss = torch.mean(seq_losses) # seq-mean
elif loss_agg_mode == "seq-mean-token-sum-norm":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor
# (loss_mask.shape[-1]) should ideally be constant
# throughout training to well-replicate the DrGRPO paper.
# TODO: Perhaps add user-defined normalizer argument to
# agg_loss to ensure divisor stays constant throughout.
else:
raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")

return loss
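A small comparison of the aggregation modes on a 2-sequence batch with different valid lengths (all values made up):

import torch

loss_mat = torch.tensor([[1.0, 2.0, 3.0, 0.0],
                         [4.0, 0.0, 0.0, 0.0]])
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                          [1.0, 0.0, 0.0, 0.0]])

token_mean = (loss_mat * loss_mask).sum() / loss_mask.sum()               # 10 / 4  = 2.5
seq_sums = (loss_mat * loss_mask).sum(dim=-1)                             # [6, 4]
seq_mean_token_sum = seq_sums.mean()                                      # 5.0
seq_mean_token_mean = (seq_sums / loss_mask.sum(dim=-1)).mean()           # (2 + 4) / 2 = 3.0
seq_mean_token_sum_norm = seq_sums.sum() / loss_mask.shape[-1]            # 10 / 4  = 2.5 (Dr.GRPO style)
print(token_mean, seq_mean_token_sum, seq_mean_token_mean, seq_mean_token_sum_norm)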
  • The compute_policy_loss function computes the PPO policy loss (with dual-clip).
def compute_policy_loss(
old_log_prob,
log_prob,
advantages,
response_mask,
cliprange=None,
cliprange_low=None,
cliprange_high=None,
clip_ratio_c=3.0,
loss_agg_mode: str = "token-mean",
):
"""
Compute the clipped policy objective and related metrics for PPO.

Adapted from
https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
cliprange (float, optional):
Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
Defaults to None (must be provided).
cliprange_low (float, optional):
Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
cliprange_high (float, optional):
Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
clip_ratio_c (float, optional):
Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
Defaults to 3.0.
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
"""
assert clip_ratio_c > 1.0, (
"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+ f" but get the value: {clip_ratio_c}."
)

negative_approx_kl = log_prob - old_log_prob
# Clamp negative_approx_kl for stability
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

pg_losses1 = -advantages * ratio
if cliprange_low is None:
cliprange_low = cliprange
if cliprange_high is None:
cliprange_high = cliprange
pg_losses2 = -advantages * torch.clamp(
ratio, 1 - cliprange_low, 1 + cliprange_high
) # - clip(ratio, 1-cliprange, 1+cliprange) * A
clip_pg_losses1 = torch.maximum(
pg_losses1, pg_losses2
) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_clipfrac_lower = verl_F.masked_mean(
torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
)

pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
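A numeric illustration of the dual-clip branch: for a token with negative advantage and a very large ratio, the standard clipped objective alone would grow without bound, while dual-clip caps the per-token loss at −A · clip_ratio_c. Values are made up (cliprange = 0.2, clip_ratio_c = 3.0):

import torch

advantage = torch.tensor(-1.0)
ratio = torch.tensor(5.0)          # exp(log_prob - old_log_prob), far above 1 + cliprange
cliprange, clip_ratio_c = 0.2, 3.0

pg_loss1 = -advantage * ratio                                                  # 5.0
pg_loss2 = -advantage * torch.clamp(ratio, 1 - cliprange, 1 + cliprange)       # 1.2
clip_pg_loss1 = torch.maximum(pg_loss1, pg_loss2)                              # 5.0 (standard PPO clip)

pg_loss3 = -advantage * clip_ratio_c                                           # 3.0
clip_pg_loss2 = torch.min(pg_loss3, clip_pg_loss1)                             # 3.0 (dual-clip cap)

# the dual-clip branch is only used where advantages < 0
final_loss = torch.where(advantage < 0, clip_pg_loss2, clip_pg_loss1)
print(final_loss)   # tensor(3.)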
  • The compute_policy_loss_vanilla function computes the vanilla PPO policy loss.
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.

Adapted from
https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
"""

assert config is not None
assert not isinstance(config, AlgoConfig)
clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
clip_ratio_c = config.get( # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
"clip_ratio_c", 3.0
)

cliprange = clip_ratio
cliprange_low = clip_ratio_low
cliprange_high = clip_ratio_high

assert clip_ratio_c > 1.0, (
"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+ f" but get the value: {clip_ratio_c}."
)

negative_approx_kl = log_prob - old_log_prob
# Clamp negative_approx_kl for stability
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

pg_losses1 = -advantages * ratio
if cliprange_low is None:
cliprange_low = cliprange
if cliprange_high is None:
cliprange_high = cliprange
pg_losses2 = -advantages * torch.clamp(
ratio, 1 - cliprange_low, 1 + cliprange_high
) # - clip(ratio, 1-cliprange, 1+cliprange) * A
clip_pg_losses1 = torch.maximum(
pg_losses1, pg_losses2
) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_clipfrac_lower = verl_F.masked_mean(
torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
)

pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower

VeRL Source Code Walkthrough · https://cosmoliu2002.github.io/posts/verl-detail/ · Author: LiuYu · Published August 22, 2025