diff --git a/examples/scripts/ppo.py b/examples/scripts/ppo.py
index f1f42ea0b2..d84c108382 100644
--- a/examples/scripts/ppo.py
+++ b/examples/scripts/ppo.py
@@ -190,14 +190,21 @@ def collator(data):
     query_tensors = batch["input_ids"]
 
     # Get response from gpt2
-    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
+    response_tensors, ref_response_tensors = ppo_trainer.generate(
+        query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs
+    )
     batch["response"] = tokenizer.batch_decode(response_tensors)
+    batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)
 
     # Compute sentiment score
     texts = [q + r for q, r in zip(batch["query"], batch["response"])]
     pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
     rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
+    ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
+    ref_pipe_outputs = sentiment_pipe(ref_texts, **sent_kwargs)
+    ref_rewards = [torch.tensor(output[1]["score"]) for output in ref_pipe_outputs]
+    batch["ref_rewards"] = ref_rewards
 
     # Run PPO step
     stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
-    ppo_trainer.log_stats(stats, batch, rewards)
+    ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"])
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index b091d11232..07f7814367 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -17,6 +17,7 @@
 import time
 import typing
 import warnings
+from contextlib import nullcontext
 from typing import Callable, List, Optional, Union
 
 import datasets
@@ -235,6 +236,11 @@ def __init__(
                 f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
                 f"architectures are: {SUPPORTED_ARCHITECTURES} "
             )
+        self.optional_peft_ctx = (
+            self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
+            if self.is_peft_model
+            else nullcontext
+        )
 
         if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
             raise ValueError(
@@ -423,6 +429,7 @@ def generate(
         length_sampler: Callable = None,
         batch_size: int = 4,
         return_prompt: bool = True,
+        generate_ref_response: bool = False,
         **generation_kwargs,
     ):
         """
@@ -440,19 +447,33 @@ def generate(
                 Batch size used for generation, defaults to `4`.
             return_prompt (`bool`, *optional*):
                 If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
+            generate_ref_response (`bool`, *optional*):
+                If set to `True` the reference response is also generated, defaults to `False`.
 
         Returns:
             `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
         """
-
+        if generate_ref_response:
+            ref_model = self.model if self.is_peft_model else self.ref_model
         if isinstance(query_tensor, List):
-            return self._generate_batched(
+            response = self._generate_batched(
+                self.model,
                 query_tensor,
                 length_sampler=length_sampler,
                 batch_size=batch_size,
                 return_prompt=return_prompt,
                 **generation_kwargs,
             )
+            if generate_ref_response:
+                with self.optional_peft_ctx():
+                    ref_response = self._generate_batched(
+                        ref_model,
+                        query_tensor,
+                        length_sampler=length_sampler,
+                        batch_size=batch_size,
+                        return_prompt=return_prompt,
+                        **generation_kwargs,
+                    )
 
         else:
             if len(query_tensor.shape) == 2:
@@ -465,13 +486,22 @@ def generate(
             response = self.accelerator.unwrap_model(self.model).generate(
                 input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
             )
+            if generate_ref_response:
+                with self.optional_peft_ctx():
+                    ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
 
             if not return_prompt and not self.is_encoder_decoder:
-                return response[:, query_tensor.shape[0] :]
-            return response
+                response = response[:, query_tensor.shape[0] :]
+                if generate_ref_response:
+                    ref_response = ref_response[:, query_tensor.shape[0] :]
+
+        if generate_ref_response:
+            return response, ref_response
+        return response
 
     def _generate_batched(
         self,
+        model: PreTrainedModelWrapper,
         query_tensors: List[torch.Tensor],
         length_sampler: Callable = None,
         batch_size: int = 4,
@@ -508,7 +538,7 @@ def _generate_batched(
                 return_tensors="pt",
             ).to(self.current_device)
 
-            generations = self.accelerator.unwrap_model(self.model).generate(**padded_inputs, **generation_kwargs)
+            generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
 
             for generation, mask in zip(generations, padded_inputs["attention_mask"]):
                 if not self.is_encoder_decoder:
@@ -681,23 +711,13 @@ def step(
                 response_masks=response_masks,
                 return_logits=full_kl_penalty,
             )
-            # for when the model is a peft model
-            if self.is_peft_model and hasattr(
-                self.accelerator.unwrap_model(self.model).pretrained_model,
-                "disable_adapter",
-            ):
-                with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
-                    ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
-                        self.model, queries, responses, model_inputs, return_logits=full_kl_penalty
-                    )
-            elif self.is_peft_model and not hasattr(self.model.pretrained_model, "disable_adapter"):
-                raise ValueError(
-                    "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
-                )
-
-            else:
+            with self.optional_peft_ctx():
                 ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
-                    self.ref_model, queries, responses, model_inputs, return_logits=full_kl_penalty
+                    self.model if self.is_peft_model else self.ref_model,
+                    queries,
+                    responses,
+                    model_inputs,
+                    return_logits=full_kl_penalty,
                 )
 
         timing["time/ppo/forward_pass"] = time.time() - t