Skip to content

Commit

Permalink
really don't modify args
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Jan 22, 2025
1 parent 7b8e66f commit 57cc809
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ def __init__(
raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
elif args.stop_token:
if args.stop_token == "eos":
self.policy_model.generation_config.eos_token_id = processing_class.eos_token_id
self.policy_model.generation_config.eos_token_id = self.eos_token_id = processing_class.eos_token_id
else:
raise ValueError(
f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
)
else:
self.policy_model.generation_config.eos_token_id = args.stop_token_id # either None or an integer
self.policy_model.generation_config.eos_token_id = self.eos_token_id = args.stop_token_id # None or int

# peft support
if not is_peft_available() and peft_config is not None:
Expand Down Expand Up @@ -455,9 +455,9 @@ def repeat_generator():

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, processing_class.pad_token_id, response
self.stop_token_id, processing_class.pad_token_id, response
)

# Response Processing 2. run reward model on the truncated responses
Expand Down Expand Up @@ -712,9 +712,9 @@ def generate_completions(self, sampling: bool = False):
)
response = query_response[:, context_length:]
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, processing_class.pad_token_id, response
self.stop_token_id, processing_class.pad_token_id, response
)
table["query"].extend(
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
Expand Down

0 comments on commit 57cc809

Please sign in to comment.