diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 4ebcda3190..8facddaf15 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -143,13 +143,13 @@ def __init__(
             raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
         elif args.stop_token:
             if args.stop_token == "eos":
-                self.policy_model.generation_config.eos_token_id = processing_class.eos_token_id
+                self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
             else:
                 raise ValueError(
                     f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
                 )
         else:
-            self.policy_model.generation_config.eos_token_id = args.stop_token_id  # either None or an integer
+            self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id  # None or int
 
         # peft support
         if not is_peft_available() and peft_config is not None:
@@ -455,9 +455,9 @@ def repeat_generator():
 
                     # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                     postprocessed_response = response
-                    if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
+                    if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                         postprocessed_response = truncate_response(
-                            args.stop_token_id, processing_class.pad_token_id, response
+                            self.stop_token_id, processing_class.pad_token_id, response
                         )
 
                     # Response Processing 2. run reward model on the truncated responses
@@ -712,9 +712,9 @@ def generate_completions(self, sampling: bool = False):
                     )
                     response = query_response[:, context_length:]
                     postprocessed_response = response
-                    if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
+                    if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                         postprocessed_response = truncate_response(
-                            args.stop_token_id, processing_class.pad_token_id, response
+                            self.stop_token_id, processing_class.pad_token_id, response
                         )
                     table["query"].extend(
                         gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
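
For context, the pattern this patch adopts is to resolve the stop token once during construction, cache it on the trainer instance, and have every later consumer read that cached attribute instead of re-deriving it from `args`. Below is a minimal self-contained sketch of that pattern; the names (`SimpleTrainer`, `truncate`) are hypothetical illustrations, not the TRL API, and `truncate` only approximates what TRL's `truncate_response` helper does.

    from typing import Optional

    class SimpleTrainer:
        """Hypothetical sketch (not the TRL API): resolve the stop token once
        in __init__ and cache it on the instance as the single source of truth."""

        def __init__(self, stop_token: Optional[str], stop_token_id: Optional[int], eos_token_id: int):
            if stop_token and stop_token_id:
                raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
            if stop_token == "eos":
                self.stop_token_id: Optional[int] = eos_token_id
            else:
                self.stop_token_id = stop_token_id  # None or int; 0 is a legal id

        def truncate(self, response: list[int], pad_token_id: int) -> list[int]:
            # Check `is not None` rather than truthiness, so a stop token id of 0 still truncates.
            if self.stop_token_id is not None and self.stop_token_id in response:
                cut = response.index(self.stop_token_id) + 1
                return response[:cut] + [pad_token_id] * (len(response) - cut)
            return response

The `is not None` guard matters because a stop token id of 0 is falsy but still a valid token id: for example, `SimpleTrainer(stop_token=None, stop_token_id=0, eos_token_id=2).truncate([5, 0, 9], pad_token_id=0)` returns `[5, 0, 0]`, truncating after the first 0.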