Token-level rewards #1302
Changes from 4 commits
@@ -605,14 +605,34 @@ def _step_safety_checker(
         scores = [tensor.to(self.current_device) for tensor in scores]
         masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None

-        # squeeze scores if needed
-        for i, score in enumerate(scores):
-            if score.dim() > 1:
+        # format scores to token-level scores if needed
+        for i, (score, response) in enumerate(zip(scores, responses)):
+            # make score 1-dimensional

Review comment: This comment seems outdated.

+            if score.dim() > 2:
                 raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
-            elif score.dim() == 1:
-                scores[i] = score.squeeze()
+            elif score.dim() == 2:
+                if score.shape[0] != 1 or score.shape[1] != 1:
+                    raise ValueError(f"Scores must be 1-dimensional - got {score.shape} for {score}")
+                else:
+                    score = score.squeeze(1)
+            elif score.dim() == 0:
+                score = score.unsqueeze(0)
+            # make score token-level
+            if score.shape[0] != 1:
+                if score.shape[0] != response.shape[0]:
+                    raise ValueError(
+                        f"Score and response must have the same length if score not scalar- got {score.shape[0]} and {response.shape[0]}"
+                    )
+                else:
+                    scores[i] = score
+                    token_level_score = True
+            else:
+                token_score = torch.zeros_like(response, dtype=float).squeeze().to(self.current_device)
+                token_score[-1] = score
+                scores[i] = token_score
+                token_level_score = False

Comment on lines +608 to +633

Reviewer: This seems like way too many if-else statements. The end-of-sequence reward is a special case of token-level rewards. Maybe we could just always convert the end-of-sequence reward to token-level rewards internally?

Author: @vwxyzjn happy to change things, but let me quickly explain what I'm doing here in a bit more detail to see if it makes sense first, as this block is doing two things, which maybe makes the number of if/else statements a bit more palatable. I also added the token_level_score flag to inform later manipulations of the score, which I think should depend on which kind of score got passed to trainer.step().

Reviewer: Why do we need a token_level_score flag?

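To make the conversion concrete, here is a minimal, self-contained sketch of what happens to a scalar end-of-sequence reward: it is spread over a zero vector of the response length, with the whole score placed on the final token. The helper name and the toy tensors are illustrative, not part of the PR.

import torch

def to_token_level(score: torch.Tensor, response: torch.Tensor) -> torch.Tensor:
    # Zero reward for every response token except the last one,
    # which receives the full end-of-sequence score.
    token_score = torch.zeros(response.shape[0], dtype=torch.float)
    token_score[-1] = score.squeeze()
    return token_score

response = torch.tensor([11, 42, 7, 3])         # four generated token ids
scalar_reward = torch.tensor(0.8)               # one reward for the whole sequence
print(to_token_level(scalar_reward, response))  # tensor([0.0000, 0.0000, 0.0000, 0.8000])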

-        return queries, responses, scores, masks
+        return queries, responses, scores, masks, token_level_score

     @PPODecorators.empty_device_cache()
     def step(

@@ -640,29 +660,42 @@ def step(
         """
         bs = self.config.batch_size

-        queries, responses, scores, response_masks = self._step_safety_checker(
+        queries, responses, scores, response_masks, token_level_score = self._step_safety_checker(
             bs, queries, responses, scores, response_masks
         )
-        scores = torch.tensor(scores, device=self.current_device)

+        # we pad to one tensor to better handle scaling and clipping.
+        # different pad values as a token level score of 0 should be ignored
+        # if step was called with a scalar score, but not if token level
+        padding_value = float("-inf") if token_level_score else 0
+        max_length = max(score.size(0) for score in scores)
+        padded_scores = torch.stack(
+            [F.pad(score, (0, max_length - score.size(0)), value=padding_value) for score in scores]
+        )

         if self.config.use_score_scaling:
             # Score scaling
-            scores_mean, scores_std = self.running.update(scores)
-            tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device)
-            score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
+            scores_mean, scores_std = self.running.update(padded_scores[padded_scores != padding_value])
+            tensor_to_kwargs = dict(dtype=padded_scores.dtype, device=padded_scores.device)
+            score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(padded_scores.dtype).eps
             if self.config.use_score_norm:
-                scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
+                padded_scores[padded_scores != padding_value] = (
+                    padded_scores[padded_scores != padding_value] - self.running.mean.to(**tensor_to_kwargs)
+                ) / score_scaling_factor
             else:
-                scores /= score_scaling_factor
+                padded_scores /= score_scaling_factor

         if self.config.score_clip is not None:
             # Score clipping
-            scores_dtype = scores.dtype
-            scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
+            scores_dtype = padded_scores.dtype
+            padded_scores = torch.clip(padded_scores.float(), -self.config.score_clip, self.config.score_clip).to(
+                dtype=scores_dtype
+            )

         # if we want to push best model to the hub
         if hasattr(self, "highest_reward"):
             if self.compare_step % self.config.compare_steps == 0:
-                curr_mean_reward = scores.mean()
+                curr_mean_reward = padded_scores[padded_scores != padding_value].sum() / bs
                 # if the best reward ever seen
                 if curr_mean_reward > self.highest_reward:
                     self.highest_reward = curr_mean_reward
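As a rough illustration of the padding-and-masked-scaling pattern used above: variable-length score tensors are padded into one batch tensor, and scaling only touches the entries that are not the pad value. The toy scores, the plain mean/std in place of the trainer's running statistics, and the fixed eps are assumptions for the sketch, not the PR's code.

import torch
import torch.nn.functional as F

# Two token-level score vectors with different response lengths.
scores = [torch.tensor([0.0, 0.0, 1.0]), torch.tensor([0.5, -0.5, 0.0, 2.0])]

padding_value = float("-inf")  # a value that can never be a real score
max_length = max(score.size(0) for score in scores)
padded_scores = torch.stack(
    [F.pad(score, (0, max_length - score.size(0)), value=padding_value) for score in scores]
)

# Normalization only touches the real entries; padded slots stay at -inf.
valid = padded_scores != padding_value
mean, std = padded_scores[valid].mean(), padded_scores[valid].std()
padded_scores[valid] = (padded_scores[valid] - mean) / (std + 1e-8)
print(padded_scores)  # first row ends in -inf, untouched by the normalization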

@@ -734,10 +767,10 @@ def step(
                 ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)

                 rewards, non_score_reward, kls = self.compute_rewards(
-                    scores, active_full_logprobs, ref_full_logprobs, masks
+                    padded_scores, active_full_logprobs, ref_full_logprobs, masks
                 )
             else:
-                rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
+                rewards, non_score_reward, kls = self.compute_rewards(padded_scores, all_logprobs, ref_logprobs, masks)
             timing["time/ppo/compute_rewards"] = time.time() - t

             t = time.time()

@@ -821,8 +854,12 @@ def step(
         train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
         train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)

+        total_scores = torch.tensor([score[score != padding_value].sum() for score in padded_scores]).to(
+            self.current_device
+        )
+

Reviewer: We should not only sum it up; the discount factor gamma also comes in. The total_scores should be the discounted episodic return. See https://spinningup.openai.com/en/latest/spinningup/rl_intro.html

Author: @vwxyzjn ah yes, good point. I've added that calculation.

         stats = self.record_step_stats(
-            scores=scores,
+            scores=total_scores,
             logprobs=all_logprobs,
             ref_logprobs=ref_logprobs,
             non_score_reward=non_score_reward,
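Following the reviewer's point, a discounted episodic return for one response's token-level rewards could look like the sketch below. The value of gamma, the helper name, and the toy rewards are illustrative; the summed total_scores above corresponds to the special case gamma = 1.

import torch

def discounted_return(token_rewards: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    # G = sum over t of gamma**t * r_t across the response tokens.
    discounts = gamma ** torch.arange(token_rewards.shape[0], dtype=torch.float)
    return (discounts * token_rewards).sum()

token_rewards = torch.tensor([0.1, 0.0, 0.2, 1.0])
print(discounted_return(token_rewards))  # 0.1 + 0.99**2 * 0.2 + 0.99**3 * 1.0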

@@ -1087,7 +1124,7 @@ def compute_rewards(

         Args:
             scores (`torch.FloatTensor`):
-                Scores from the reward model, shape (`batch_size`)
+                Scores from the reward model, shape (`batch_size`, `max_response_length`)
             logprobs (`torch.FloatTensor`):
                 Log probabilities of the model, shape (`batch_size`, `response_length`)
             ref_logprobs (`torch.FloatTensor`):

@@ -1106,10 +1143,11 @@ def compute_rewards(
             non_score_reward = -self.kl_ctl.value * kl
             non_score_rewards.append(non_score_reward)
             reward = non_score_reward.clone()
-            last_non_masked_index = mask.nonzero()[-1]

+            # get the unpadded score
+            score = score[: mask.sum()]
             # reward is preference model score + KL penalty
-            reward[last_non_masked_index] += score
+            reward[mask.bool()] += score
             rewards.append(reward)
         return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)
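A toy walk-through of the updated reward assembly, with made-up numbers: the per-token KL penalty is the starting point, the score is trimmed to the unpadded response length, and then added on every unmasked position rather than only on the last token.

import torch

# One response of length 6, of which only the first 4 positions are real tokens.
non_score_reward = torch.tensor([-0.01, -0.02, -0.01, -0.03, 0.0, 0.0])  # -kl_ctl.value * KL
mask = torch.tensor([1, 1, 1, 1, 0, 0])
score = torch.tensor([0.0, 0.0, 0.0, 1.0, 0.0, 0.0])  # padded token-level score

reward = non_score_reward.clone()
score = score[: mask.sum()]   # drop the padded tail of the score
reward[mask.bool()] += score  # add the score on every unmasked position
print(reward)                 # tensor([-0.0100, -0.0200, -0.0100,  0.9700,  0.0000,  0.0000])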

@@ -1330,6 +1368,9 @@ def log_stats(
         """

         # all gather stats

+        # sum to episodic rewards
+        rewards = [reward.sum() for reward in rewards]

Comment on lines +1376 to +1377

Reviewer: This is not the discounted return.

         if not isinstance(rewards, torch.Tensor):
             rewards = torch.tensor(rewards).to(self.current_device)
         rewards = self.accelerator.gather(rewards).flatten()

Reviewer: Is this a breaking change requiring users to always pass in token-level rewards?

Author: @vwxyzjn I don't think this is a breaking change, but only because I'm assuming that no one is calling self.compute_rewards() outside of self.step(), as I couldn't see why that would be useful. I set it up so you can pass either token-level or end-of-sequence rewards to step(), which then all get converted to token-level in _step_safety_checker(), meaning everything that gets sent to compute_rewards() is always token-level.
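Put together, the two intended call patterns would look roughly like the sketch below. It assumes an already configured trl.PPOTrainer instance named ppo_trainer and uses toy token ids; only the shape of the score tensors differs between the two cases.

import torch

# Assumes `ppo_trainer` is an already constructed trl.PPOTrainer (model, ref model,
# tokenizer, and PPOConfig set up elsewhere); queries/responses are toy token ids.
queries = [torch.tensor([101, 2023, 2003]), torch.tensor([101, 2129, 2079])]
responses = [torch.tensor([1037, 3231, 102]), torch.tensor([2178, 3231, 1012, 102])]

# End-of-sequence rewards: one scalar per sequence (the existing behaviour).
scalar_scores = [torch.tensor(1.0), torch.tensor(-0.5)]
stats = ppo_trainer.step(queries, responses, scalar_scores)

# Token-level rewards: one score per response token (what this PR enables).
token_scores = [torch.tensor([0.0, 0.0, 1.0]), torch.tensor([0.1, 0.0, -0.2, -0.4])]
stats = ppo_trainer.step(queries, responses, token_scores)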