Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Token-level rewards #1302

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions tests/test_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def tearDown(self):
def _init_dummy_dataset(self):
# encode a query
query_txt = "This morning I went to the "
query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt")
query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt").to(
self.gpt2_model.pretrained_model.device
)
assert query_tensor.shape == (1, 7)
# get model response
response_tensor = respond_to_batch(self.gpt2_model, query_tensor)
Expand Down Expand Up @@ -457,14 +459,19 @@ def test_ppo_step_rewards_shape(self):
for query_tensor, response_tensor in dummy_dataloader:
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor([[1.0]]), torch.tensor([[0.0]])]
reward = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([[0.0, 1.0]])]
# train model - this should raise an error
with self.assertRaises(ValueError):
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)

reward = [torch.tensor([1.0]), torch.tensor([0.0])]
# train model - this should work
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)

# token-level rewards
reward = [torch.tensor([1.0] * 7), torch.tensor([0.0] * 7)]
# train model - this should work
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
break

# check if the gradients are computed for the model
Expand Down Expand Up @@ -498,7 +505,7 @@ def test_ppo_step_input_shape(self):
# train model - this should raise an error
bs = ppo_trainer.config.batch_size

queries, responses, _, _ = ppo_trainer._step_safety_checker(
queries, responses, _, _, _ = ppo_trainer._step_safety_checker(
bs, [q for q in query_tensor], [r for r in response_tensor], reward
)

Expand All @@ -516,7 +523,9 @@ def test_ppo_step_no_dataset(self):
Test if the training loop works fine without passing a dataset
"""
query_txt = "This morning I went to the "
query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt")
query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt").to(
self.gpt2_model.pretrained_model.device
)
self.ppo_config.batch_size = 1

response_tensor = respond_to_batch(self.gpt2_model, query_tensor)
Expand Down Expand Up @@ -565,7 +574,10 @@ def test_loss_trainer(self):

dummy_queries = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 4, 5, 6, 7])]
dummy_responses = [torch.tensor([5, 6, 7, 8, 9]), torch.tensor([8, 9, 10, 11, 12, 13])]
dummy_scores = torch.Tensor([1, 2])
dummy_scores = [
torch.tensor([0, 0, 0, 0, 1], device=ppo_trainer.current_device),
torch.tensor([0, 0, 0, 0, 0, 2], device=ppo_trainer.current_device),
]
Comment on lines +577 to +580
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a breaking change requiring users to always pass in token level rewards?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vwxyzjn I don't think this is a breaking change, but only because I'm assuming that no one is calling self.compute_rewards() outside of self.step(), as I couldn't see why that would be useful.

I set it up so you can pass either token-level or end-of-sequence rewards to step(), which then all get converted to token-level in _step_safety_checker(), meaning everything that gets sent to compute_rewards() would be always token-level.


ppo_trainer.config.mini_batch_size = 1
ppo_trainer.config.batch_size = 1
Expand Down Expand Up @@ -989,9 +1001,11 @@ def make_inputs_require_grad(module, input, output):
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
break

new_logits = ppo_trainer.model.compute_reward_score(dummy_inputs)
self.assertTrue(not torch.allclose(previous_rm_logits, new_logits[:, -1, :]))
self.assertTrue(torch.allclose(original_rm_logits, new_logits[:, -1, :]))
new_logits = ppo_trainer.model.compute_reward_score(
dummy_inputs.to(ppo_trainer.model.pretrained_model.device)
)
self.assertTrue(not torch.allclose(previous_rm_logits, new_logits[:, -1, :].to(previous_rm_logits.device)))
self.assertTrue(torch.allclose(original_rm_logits, new_logits[:, -1, :].to(original_rm_logits.device)))

# check gradients
for name, param in model.named_parameters():
Expand Down Expand Up @@ -1126,7 +1140,10 @@ def test_generation(self):

tokenizer.pad_token = tokenizer.eos_token

model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
model_inputs = [
tokenizer(txt, return_tensors="pt").to(model.pretrained_model.device).input_ids.squeeze()
for txt in input_texts
]

generations_batched = ppo_trainer.generate(model_inputs, batch_size=2, **generation_kwargs)
generations_batched = tokenizer.batch_decode(generations_batched)
Expand Down
89 changes: 67 additions & 22 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,14 +605,34 @@ def _step_safety_checker(
scores = [tensor.to(self.current_device) for tensor in scores]
masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None

# squeeze scores if needed
for i, score in enumerate(scores):
if score.dim() > 1:
# format scores to token-level scores if needed
for i, (score, response) in enumerate(zip(scores, responses)):
# make score 1-dimensional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment seems outdated.

if score.dim() > 2:
raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
elif score.dim() == 1:
scores[i] = score.squeeze()
elif score.dim() == 2:
if score.shape[0] != 1 or score.shape[1] != 1:
raise ValueError(f"Scores must be 1-dimensional - got {score.shape} for {score}")
else:
score = score.squeeze(1)
elif score.dim() == 0:
score = score.unsqueeze(0)
# make score token-level
if score.shape[0] != 1:
if score.shape[0] != response.shape[0]:
raise ValueError(
f"Score and response must have the same length if score not scalar- got {score.shape[0]} and {response.shape[0]}"
)
else:
scores[i] = score
token_level_score = True
else:
token_score = torch.zeros_like(response, dtype=float).squeeze().to(self.current_device)
token_score[-1] = score
scores[i] = token_score
token_level_score = False
Comment on lines +608 to +633
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like way too many if-else statements. The end-of-sequence reward is a special case of token-level rewards. Maybe we could just always convert end-of-sequence reward to token-level rewards internally?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vwxyzjn happy to change things but let me just quickly explain what I'm doing a bit here in a bit more detail to see if it makes sense first, as it's doing two things which maybe makes the number of if/else statements a bit more palatable

  1. lines 611-619 are doing most of the safety checking because to make sure its back-compatitble we need to accept scores that are either scalar e.g. [1.0] or as a vector e.g. [1.0,2.0,3.0]. I think this means we need to do a lot more dim checks as in the first case I could expect either torch.tensor(1.0) or torch.tensor([1.0]) with dim 0,1 respectively from the user. Similarly with the second either torch.tensor([1.0,2.0,3.0]) or torch.tensor([[1.0,2.0,3.0]]) with dim 1,2 respectively. So here we're just making sure we deal with dim 1, i.e. torch.tensor([1.0]) or torch.tensor([1.0,2.0,3.0]). Theoretically could simplify the code here and just get rid of the logic for turning non dim 1 tensors to dim 1 and throw an error, but it seems helpful to convert - let me know if you think this is a better strategy.
  2. lines 621-633 is doing what you suggest which is then to convert end-of-sequence rewards to token-level rewards as we check the length and if it's 1 we turn it into a token-level reward with the score at the last spot, otherwise we just check the token-level reward is the sam size as the response and if it is then everything's good

I added the token_level_score flag as well to inform later manipulations of the score which I think should depend on which one got passed to trainer.step()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a token_level_score? Maybe we can just assume score is token-level. Also the token_level_score appears not tracked for each index, so if at i, token_level_score=True, and i+1, token_level_score=False. Then isn't this information missing?


return queries, responses, scores, masks
return queries, responses, scores, masks, token_level_score

@PPODecorators.empty_device_cache()
def step(
Expand Down Expand Up @@ -640,29 +660,42 @@ def step(
"""
bs = self.config.batch_size

queries, responses, scores, response_masks = self._step_safety_checker(
queries, responses, scores, response_masks, token_level_score = self._step_safety_checker(
bs, queries, responses, scores, response_masks
)
scores = torch.tensor(scores, device=self.current_device)

# we pad to one tensor to better handle scaling and clipping.
# different pad values as a token level score of 0 should be ignored
# if step was called with a scalar score, but not if token level
padding_value = float("-inf") if token_level_score else 0
max_length = max(score.size(0) for score in scores)
padded_scores = torch.stack(
[F.pad(score, (0, max_length - score.size(0)), value=padding_value) for score in scores]
)

if self.config.use_score_scaling:
# Score scaling
scores_mean, scores_std = self.running.update(scores)
tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device)
score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
scores_mean, scores_std = self.running.update(padded_scores[padded_scores != padding_value])
tensor_to_kwargs = dict(dtype=padded_scores.dtype, device=padded_scores.device)
score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(padded_scores.dtype).eps
if self.config.use_score_norm:
scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
padded_scores[padded_scores != padding_value] = (
padded_scores[padded_scores != padding_value] - self.running.mean.to(**tensor_to_kwargs)
) / score_scaling_factor
else:
scores /= score_scaling_factor
padded_scores /= score_scaling_factor

if self.config.score_clip is not None:
# Score clipping
scores_dtype = scores.dtype
scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
scores_dtype = padded_scores.dtype
padded_scores = torch.clip(padded_scores.float(), -self.config.score_clip, self.config.score_clip).to(
dtype=scores_dtype
)

# if we want to push best model to the hub
if hasattr(self, "highest_reward"):
if self.compare_step % self.config.compare_steps == 0:
curr_mean_reward = scores.mean()
curr_mean_reward = padded_scores[padded_scores != padding_value].sum() / bs
# if the best reward ever seen
if curr_mean_reward > self.highest_reward:
self.highest_reward = curr_mean_reward
Expand Down Expand Up @@ -734,10 +767,10 @@ def step(
ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)

rewards, non_score_reward, kls = self.compute_rewards(
scores, active_full_logprobs, ref_full_logprobs, masks
padded_scores, active_full_logprobs, ref_full_logprobs, masks
)
else:
rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
rewards, non_score_reward, kls = self.compute_rewards(padded_scores, all_logprobs, ref_logprobs, masks)
timing["time/ppo/compute_rewards"] = time.time() - t

t = time.time()
Expand Down Expand Up @@ -821,8 +854,16 @@ def step(
train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)

discount = torch.tensor(
[self.config.gamma**i for i in range(padded_scores.shape[1])], device=padded_scores.device
).unsqueeze(0)
discounted_scores = padded_scores * discount
total_scores = torch.tensor([score[score != padding_value].sum() for score in discounted_scores]).to(
self.current_device
)

stats = self.record_step_stats(
scores=scores,
scores=total_scores,
logprobs=all_logprobs,
ref_logprobs=ref_logprobs,
non_score_reward=non_score_reward,
Expand Down Expand Up @@ -1087,7 +1128,7 @@ def compute_rewards(

Args:
scores (`torch.FloatTensor`):
Scores from the reward model, shape (`batch_size`)
Scores from the reward model, shape (`batch_size`, `max_response_length')
logprobs (`torch.FloatTensor`):
Log probabilities of the model, shape (`batch_size`, `response_length`)
ref_logprobs (`torch.FloatTensor`):
Expand All @@ -1106,10 +1147,11 @@ def compute_rewards(
non_score_reward = -self.kl_ctl.value * kl
non_score_rewards.append(non_score_reward)
reward = non_score_reward.clone()
last_non_masked_index = mask.nonzero()[-1]

# get the unpadded score
score = score[: mask.sum()]
# reward is preference model score + KL penalty
reward[last_non_masked_index] += score
reward[mask.bool()] += score
rewards.append(reward)
return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)

Expand Down Expand Up @@ -1330,6 +1372,9 @@ def log_stats(
"""

# all gather stats

# sum to episodic rewards
rewards = [reward.sum() for reward in rewards]
Comment on lines +1376 to +1377
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the discounted return

if not isinstance(rewards, torch.Tensor):
rewards = torch.tensor(rewards).to(self.current_device)
rewards = self.accelerator.gather(rewards).flatten()
Expand Down
Loading