diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 337c99af6a..10e0d0f343 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -22,6 +22,7 @@ import torch from datasets import Dataset, load_dataset +from peft import LoraConfig from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments from trl import DPOTrainer @@ -51,6 +52,10 @@ class ScriptArguments: ) label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"}) max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"}) + # lora parameters + use_peft: Optional[bool] = field(default=True, metadata={"help": "Wether to use PEFT or not to train adapters"}) + peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"}) + peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"}) # instrumentation sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"}) report_to: Optional[str] = field( @@ -163,6 +168,16 @@ def split_prompt_and_responses(sample) -> Dict[str, str]: # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs, ) + if script_args.use_peft: + peft_config = LoraConfig( + r=script_args.peft_lora_r, + lora_alpha=script_args.peft_lora_alpha, + bias="none", + task_type="CAUSAL_LM", + ) + else: + peft_config = None + # 5. initialize the DPO trainer dpo_trainer = DPOTrainer( model, @@ -176,6 +191,7 @@ def split_prompt_and_responses(sample) -> Dict[str, str]: max_target_length=script_args.max_target_length, max_prompt_length=script_args.max_prompt_length, generate_during_eval=True, + peft_config=peft_config, ) # 6. train diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index ef628497a3..edc691e691 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -51,6 +51,8 @@ def _init_dummy_dataset(self): "Which is the best programming language?", "Which is the best programming language?", "Which is the best programming language?", + "[INST] How is the stock price? [/INST]", + "[INST] How is the stock price? 
[/INST] ", ], "chosen": [ "hi nice to meet you", @@ -60,6 +62,8 @@ def _init_dummy_dataset(self): "Python", "Python", "Python", + "$46 as of 10am EST", + "46 as of 10am EST", ], "rejected": [ "leave me alone", @@ -69,15 +73,24 @@ def _init_dummy_dataset(self): "Javascript", "C++", "Java", + " $46 as of 10am EST", + " 46 as of 10am EST", ], } # fmt: on return Dataset.from_dict(dummy_dataset_dict) @parameterized.expand( - [["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"], ["gpt2", "kto"], ["t5", "kto"]] + [ + ["gpt2", "sigmoid", True], + ["t5", "hinge", False], + ["gpt2", "ipo", False], + ["t5", "ipo", True], + ["gpt2", "kto", True], + ["t5", "kto", False], + ] ) - def test_dpo_trainer(self, name, loss_type): + def test_dpo_trainer(self, name, loss_type, pre_compute): with tempfile.TemporaryDirectory() as tmp_dir: training_args = TrainingArguments( output_dir=tmp_dir, @@ -109,6 +122,7 @@ def test_dpo_trainer(self, name, loss_type): tokenizer=tokenizer, train_dataset=dummy_dataset, eval_dataset=dummy_dataset, + precompute_ref_log_probs=pre_compute, ) previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} @@ -146,6 +160,7 @@ def test_dpo_trainer_without_providing_ref_model(self): tokenizer=self.tokenizer, train_dataset=dummy_dataset, eval_dataset=dummy_dataset, + precompute_ref_log_probs=True, ) previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} @@ -196,6 +211,7 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self): train_dataset=dummy_dataset, eval_dataset=dummy_dataset, peft_config=lora_config, + precompute_ref_log_probs=True, ) previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} @@ -283,6 +299,7 @@ def test_dpo_lora_save(self): train_dataset=dummy_dataset, eval_dataset=dummy_dataset, peft_config=lora_config, + precompute_ref_log_probs=True, ) # train the model diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 0015f42af0..8d8822751a 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -16,13 +16,15 @@ import random import warnings from collections import defaultdict +from contextlib import nullcontext from copy import deepcopy from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from accelerate.utils import is_deepspeed_available +from accelerate.utils import is_deepspeed_available, tqdm from datasets import Dataset from torch.utils.data import DataLoader from transformers import ( @@ -76,7 +78,7 @@ class DPOTrainer(Trainer): label_pad_token_id (`int`, defaults to `-100`): The label pad token id. This argument is required if you want to use the default data collator. padding_value (`int`, defaults to `0`): - The padding value. This argument is required if you want to use the default data collator. + The padding value if it is different to the tokenizer's pad_token_id. truncation_mode (`str`, defaults to `keep_end`): The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. train_dataset (`datasets.Dataset`): @@ -110,11 +112,13 @@ class DPOTrainer(Trainer): compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. 
+ precompute_ref_log_probs (`bool`, defaults to `False`): + Flag to precompute reference model log probabilities and evaluation datasets. This is useful if you want to train + without the reference model and reduce the total GPU memory needed. model_init_kwargs: (`Optional[Dict]`, *optional*): Dict of Optional kwargs to pass when instantiating the model from a string ref_model_init_kwargs: (`Optional[Dict]`, *optional*): Dict of Optional kwargs to pass when instantiating the ref model from a string - """ def __init__( @@ -127,17 +131,14 @@ def __init__( args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, label_pad_token_id: int = -100, - padding_value: int = 0, + padding_value: int = None, truncation_mode: str = "keep_end", train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, callbacks: Optional[List[TrainerCallback]] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( - None, - None, - ), + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, max_length: Optional[int] = None, max_prompt_length: Optional[int] = None, @@ -147,6 +148,7 @@ def __init__( disable_dropout: bool = True, generate_during_eval: bool = False, compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + precompute_ref_log_probs: bool = False, model_init_kwargs: Optional[Dict] = None, ref_model_init_kwargs: Optional[Dict] = None, ): @@ -243,7 +245,7 @@ def make_inputs_require_grad(module, input, output): if ref_model: self.ref_model = ref_model - elif self.is_peft_model: + elif self.is_peft_model or precompute_ref_log_probs: # The `model` with adapters turned off will be used as the reference model self.ref_model = None else: @@ -278,14 +280,9 @@ def make_inputs_require_grad(module, input, output): max_target_length = 128 data_collator = DPODataCollatorWithPadding( - tokenizer, - max_length=max_length, - max_prompt_length=max_prompt_length, + pad_token_id=tokenizer.pad_token_id, label_pad_token_id=label_pad_token_id, - padding_value=padding_value, - truncation_mode=truncation_mode, is_encoder_decoder=self.is_encoder_decoder, - max_target_length=max_target_length, ) if args.remove_unused_columns: @@ -309,7 +306,17 @@ def make_inputs_require_grad(module, input, output): self.max_length = max_length self.generate_during_eval = generate_during_eval self.label_pad_token_id = label_pad_token_id - self.padding_value = padding_value + self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = truncation_mode + self.max_target_length = max_target_length + self.tokenizer = tokenizer + self.precompute_ref_log_probs = precompute_ref_log_probs + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False if loss_type in ["hinge", "ipo", "kto"] and label_smoothing > 0: warnings.warn( @@ -322,6 +329,11 @@ def make_inputs_require_grad(module, input, output): self._stored_metrics = defaultdict(lambda: defaultdict(list)) + # tokenize the dataset + train_dataset = 
train_dataset.map(self.tokenize_row) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row) + super().__init__( model=model, args=args, @@ -341,10 +353,17 @@ def make_inputs_require_grad(module, input, output): "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." ) + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + if self.ref_model is None: - if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"): + if not (self.is_peft_model or self.precompute_ref_log_probs): raise ValueError( - "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version." + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" ) else: if self.is_deepspeed_enabled: @@ -383,30 +402,340 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper): model.eval() return model - def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. + """ + + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) + + reference_chosen_logps = [] + reference_rejected_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch) + reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics( + (reference_chosen_logp, reference_rejected_logp) + ) + reference_chosen_logps.append(reference_chosen_logp.cpu()) + reference_rejected_logps.append(reference_rejected_logp.cpu()) + + all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy() + all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy() + + self.train_dataset = self.train_dataset.add_column( + name="reference_chosen_logps", column=all_reference_chosen_logps + ) + self.train_dataset = self.train_dataset.add_column( + name="reference_rejected_logps", column=all_reference_rejected_logps + ) + + self._precomputed_train_ref_log_probs = True + + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. 
It must implement `__len__`.
+        """
+        if eval_dataset is None and self.eval_dataset is None:
+            raise ValueError("Trainer: evaluation requires an eval_dataset.")
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+        if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
+            dataloader_params = {
+                "batch_size": self.args.per_device_eval_batch_size,
+                "collate_fn": self.data_collator,
+                "num_workers": self.args.dataloader_num_workers,
+                "pin_memory": self.args.dataloader_pin_memory,
+                "shuffle": False,
+            }
+
+            # prepare dataloader
+            data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
+
+            reference_chosen_logps = []
+            reference_rejected_logps = []
+            for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
+                reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
+                reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
+                    (reference_chosen_logp, reference_rejected_logp)
+                )
+                reference_chosen_logps.append(reference_chosen_logp.cpu())
+                reference_rejected_logps.append(reference_rejected_logp.cpu())
+
+            all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
+            all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()
+
+            eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
+            eval_dataset = eval_dataset.add_column(
+                name="reference_rejected_logps", column=all_reference_rejected_logps
+            )
+
+            self._precomputed_eval_ref_log_probs = True
+
+        return super().get_eval_dataloader(eval_dataset=eval_dataset)
+
+    def build_tokenized_answer(self, prompt, answer):
+        """
+        The Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
+        It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
+        Reference:
+            https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+        """
+
+        full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
+        prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
+
+        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
+        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
+
+        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
+        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
+
+        # Prepare input tokens for token by token comparison
+        full_input_ids = np.array(full_tokenized["input_ids"])
+
+        if len(full_input_ids) != len(full_concat_input_ids):
+            raise ValueError("Prompt input ids and answer input ids should have the same length.")
+
+        # On some tokenizers, like the Llama-2 tokenizer, tokens can be merged together
+        # when tokenizing prompt+answer. This can result in the last token of the prompt
+        # being different when tokenized on its own vs as part of prompt+answer.
+        response_token_ids_start_idx = len(prompt_input_ids)
+
+        # If the tokenized prompt differs from the prompt slice of prompt+answer, then the
+        # last prompt token changed due to merging.
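+        # Illustrative example (hypothetical token ids, not from a real tokenizer):
+        # enc(prompt) might end in token 1234, while enc(prompt + answer) has 5678 at
+        # that position because the last prompt piece merged with the start of the
+        # answer. In that case we move the split point one token back so the merged
+        # token is attributed to the answer rather than the prompt.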
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict: + """Tokenize a single row from a DPO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + chosen or prompt + rejected responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to + the sum of the length of the prompt and the chosen/rejected response, with + label_pad_token_id for the prompt tokens. + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not self.is_encoder_decoder: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + # add BOS token to head of prompt + prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"] + chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"] + rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"] + + prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"] + chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"] + rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"] + + # add EOS token to end of answer + chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) + chosen_tokens["attention_mask"].append(1) + + rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) + rejected_tokens["attention_mask"].append(1) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + if len(answer_tokens["prompt_input_ids"]) + 
longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] + + # Create labels + chosen_sequence_tokens = { + k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(chosen_tokens["prompt_input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(rejected_tokens["prompt_input_ids"]) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.tokenizer( + chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True + ) + rejected_tokens = self.tokenizer( + rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True + ) + prompt_tokens = self.tokenizer( + prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=batch["rejected_labels"] + ) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=batch["chosen_labels"] + ) + + return batch + + def compute_reference_log_probs(self, padded_batch: Dict) -> Dict: + """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" + # compute reference logps + with torch.no_grad(): + if self.ref_model is None: + with self.accelerator.unwrap_model( + self.model + ).disable_adapter() if self.is_peft_model else nullcontext(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.model, padded_batch) + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.ref_model, padded_batch) + + return reference_chosen_logps, reference_rejected_logps + + @staticmethod + def concatenated_inputs( + batch: Dict[str, Union[List, torch.LongTensor]], + 
is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + ) -> Dict[str, torch.LongTensor]: """Concatenate the chosen and rejected inputs into a single tensor. Args: batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + is_encoder_decoder: Whether the model is an encoder-decoder model. + label_pad_token_id: The label pad token id. + padding_value: The padding value to use for the concatenated inputs_ids. + device: The device for the concatenated inputs. Returns: A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. """ concatenated_batch = {} - if self.is_encoder_decoder: + if is_encoder_decoder: max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) else: max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) for k in batch: if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): - pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 concatenated_key = k.replace("chosen", "concatenated") concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) for k in batch: if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): - pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 concatenated_key = k.replace("rejected", "concatenated") concatenated_batch[concatenated_key] = torch.cat( ( @@ -414,11 +743,13 @@ def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) - pad_to_length(batch[k], max_length, pad_value=pad_value), ), dim=0, - ).to(self.accelerator.device) + ).to(device=device) - if self.is_encoder_decoder: - concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1) - concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1) + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) return concatenated_batch @@ -490,11 +821,13 @@ def dpo_loss( return losses, chosen_rewards, rejected_rewards - def _get_batch_logps( - self, + @staticmethod + def get_batch_logps( logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, ) -> torch.FloatTensor: """Compute the log probabilities of the given labels under the given logits. 
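The hunk above turns `_get_batch_logps` into the public `@staticmethod` `get_batch_logps`, and the hunk below rewires its body to read `label_pad_token_id` and `is_encoder_decoder` from arguments instead of `self`. Reassembled for reference, the helper amounts to the sketch below (written as a free function; the final masked reduction comes from unchanged context lines that this excerpt does not show):

import torch

def get_batch_logps_sketch(
    logits: torch.FloatTensor,    # (batch, seq_len, vocab)
    labels: torch.LongTensor,     # (batch, seq_len)
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    if not is_encoder_decoder:
        # decoder-only models predict token t from tokens < t, so shift labels by one
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id

    # dummy token; the losses on these positions are masked out below
    labels[labels == label_pad_token_id] = 0

    # log-probability assigned to each label token
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)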
@@ -509,13 +842,13 @@ def _get_batch_logps( if logits.shape[:-1] != labels.shape: raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") - if not self.is_encoder_decoder: + if not is_encoder_decoder: labels = labels[:, 1:].clone() logits = logits[:, :-1, :] - loss_mask = labels != self.label_pad_token_id + loss_mask = labels != label_pad_token_id # dummy token; we'll ignore the losses on these tokens later - labels[labels == self.label_pad_token_id] = 0 + labels[labels == label_pad_token_id] = 0 per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) @@ -531,7 +864,13 @@ def concatenated_forward( We do this to avoid doing two forward passes, because it's faster for FSDP. """ - concatenated_batch = self.concatenated_inputs(batch) + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) len_chosen = batch["chosen_labels"].shape[0] model_kwargs = ( @@ -546,12 +885,14 @@ def concatenated_forward( concatenated_batch["concatenated_input_ids"], attention_mask=concatenated_batch["concatenated_attention_mask"], **model_kwargs, - ).logits.to(torch.float32) + ).logits - all_logps = self._get_batch_logps( + all_logps = self.get_batch_logps( all_logits, concatenated_batch["concatenated_labels"], average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, ) chosen_logps = all_logps[:len_chosen] @@ -562,7 +903,7 @@ def concatenated_forward( return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) - def get_batch_metrics( + def get_batch_loss_metrics( self, model, batch: Dict[str, Union[List, torch.LongTensor]], @@ -577,22 +918,28 @@ def get_batch_metrics( policy_chosen_logits, policy_rejected_logits, ) = self.concatenated_forward(model, batch) - with torch.no_grad(): - if self.ref_model is None: - with self.accelerator.unwrap_model(self.model).disable_adapter(): + + # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model + if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch: + reference_chosen_logps = batch["reference_chosen_logps"] + reference_rejected_logps = batch["reference_rejected_logps"] + else: + with torch.no_grad(): + if self.ref_model is None: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.model, batch) + else: ( reference_chosen_logps, reference_rejected_logps, _, _, - ) = self.concatenated_forward(self.model, batch) - else: - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - ) = self.concatenated_forward(self.ref_model, batch) + ) = self.concatenated_forward(self.ref_model, batch) losses, chosen_rewards, rejected_rewards = self.dpo_loss( policy_chosen_logps, @@ -625,7 +972,7 @@ def compute_loss( "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " "DPODataCollatorWithPadding - you might see unexpected behavior. 
Alternatively, you can implement your own prediction_step method if you are using a custom data collator" ) - loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train") + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") # force log the metrics if self.accelerator.is_main_process: @@ -646,23 +993,27 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[ pad_token_id=self.tokenizer.pad_token_id, ) - if self.ref_model is None: - with self.accelerator.unwrap_model(self.model).disable_adapter(): - reference_output = self.model.generate( + # if reference_output in batch use that otherwise use the reference model + if "reference_output" in batch: + reference_output = batch["reference_output"] + else: + if self.ref_model is None: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( input_ids=batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id, ) - else: - reference_output = self.ref_model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.tokenizer.pad_token_id, - ) policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id) policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) @@ -691,7 +1042,7 @@ def prediction_step( ignore_keys = [] with torch.no_grad(): - loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval") + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") # force log the metrics if self.accelerator.is_main_process: diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 3c29bedfb2..554346cfd1 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -22,7 +22,7 @@ import torch from torch.nn.utils.rnn import pad_sequence from torch.utils.data import IterableDataset -from transformers import DataCollatorForLanguageModeling, PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback +from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, TrainerCallback class AdaptiveKLController: @@ -272,175 +272,29 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: @dataclass class DPODataCollatorWithPadding: r""" - DPO DataCollator class that pads the inputs to the maximum length of the batch. + DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. Args: - tokenizer (`PreTrainedTokenizerBase`): - The tokenizer used for encoding the data. - model (Optional[`PreTrainedModel`]): - The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to - prepare the *decoder_input_ids*. - padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): - padding_strategy to pass to the tokenizer. - max_length (`Optional[int]`, `optional`, defaults to `None`): - The maximum length of the sequence to be processed. - max_prompt_length (`Optional[int]`, `optional`, defaults to `None`): - The maximum length of the prompt to be processed. 
+ pad_token_id (`int` defaults to 0): + The tokenizer's pad_token_id. label_pad_token_id (`int`, defaults to -100): The label used for masking. - padding_value (`int`, defaults to 0): - The value used for padding. is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): Whether or not you model has an encoder_decoder architecture. - max_target_length (`Optional[int]`, `optional`, defaults to `None`): - The maximum length of the target to be processed. Only useful for encoder-decoder architectures. - truncation_mode: (`str`, defaults to "keep_end"): - The truncation mode to use when truncating the prompt. """ - tokenizer: PreTrainedTokenizerBase - model: Optional[PreTrainedModel] = None - padding: Union[bool, str] = True - max_length: Optional[int] = None - max_prompt_length: Optional[int] = None + pad_token_id: int = 0 label_pad_token_id: int = -100 - padding_value: int = 0 - truncation_mode: str = "keep_end" is_encoder_decoder: Optional[bool] = False - max_target_length: Optional[int] = None - - def tokenize_batch_element( - self, - prompt: str, - chosen: str, - rejected: str, - ) -> Dict: - """Tokenize a single batch element. - - At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation - in case the prompt + chosen or prompt + rejected responses is/are too long. First - we truncate the prompt; if we're still too long, we truncate the chosen/rejected. - - We also create the labels for the chosen/rejected responses, which are of length equal to - the sum of the length of the prompt and the chosen/rejected response, with - label_pad_token_id for the prompt tokens. - """ - batch = {} - - if not self.is_encoder_decoder: - chosen_tokens = self.tokenizer(chosen, add_special_tokens=False) - rejected_tokens = self.tokenizer(rejected, add_special_tokens=False) - prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) - - eos_token_id = self.tokenizer.eos_token_id - # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0) - eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id] - # attention mask these indices to eos_token_id - new_attention_mask = [ - 0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"]) - ] - prompt_tokens["attention_mask"] = new_attention_mask - - # do the same for chosen and rejected - eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id] - new_attention_mask_c = [ - 0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"]) - ] - chosen_tokens["attention_mask"] = new_attention_mask_c - - eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id] - new_attention_mask_r = [ - 0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"]) - ] - rejected_tokens["attention_mask"] = new_attention_mask_r - - # add EOS token to end of prompt - chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) - chosen_tokens["attention_mask"].append(1) - - rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) - rejected_tokens["attention_mask"].append(1) - - longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) - - # if combined sequence is too long, truncate the prompt - if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: - if self.truncation_mode == "keep_start": - prompt_tokens = {k: v[: 
self.max_prompt_length] for k, v in prompt_tokens.items()} - elif self.truncation_mode == "keep_end": - prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()} - else: - raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") - - # if that's still too long, truncate the response - if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: - chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()} - rejected_tokens = { - k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items() - } - # Create labels - chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens} - rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens} - chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] - chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( - prompt_tokens["input_ids"] - ) - rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] - rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( - prompt_tokens["input_ids"] - ) - - for k, toks in { - "chosen": chosen_sequence_tokens, - "rejected": rejected_sequence_tokens, - "prompt": prompt_tokens, - }.items(): - for type_key, tokens in toks.items(): - if type_key == "token_type_ids": - continue - batch[f"{k}_{type_key}"] = tokens - - else: - chosen_tokens = self.tokenizer( - chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True - ) - rejected_tokens = self.tokenizer( - rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True - ) - prompt_tokens = self.tokenizer( - prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True - ) - - batch["chosen_labels"] = chosen_tokens["input_ids"] - batch["rejected_labels"] = rejected_tokens["input_ids"] - batch["prompt_input_ids"] = prompt_tokens["input_ids"] - batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] - - if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"): - batch["rejected_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels( - labels=batch["rejected_labels"] - ) - batch["chosen_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels( - labels=batch["chosen_labels"] - ) - - batch["prompt"] = prompt - batch["chosen"] = prompt + chosen - batch["rejected"] = prompt + rejected - batch["chosen_response_only"] = chosen - batch["rejected_response_only"] = rejected - - return batch - - def collate(self, batch): + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: # first, pad everything to the same length padded_batch = {} - for k in batch[0].keys(): + for k in features[0].keys(): if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): if self.is_encoder_decoder: - to_pad = [torch.LongTensor(ex[k]) for ex in batch] + to_pad = [torch.LongTensor(ex[k]) for ex in features] if (k.startswith("prompt")) and (k.endswith("input_ids")): - padding_value = self.tokenizer.pad_token_id + padding_value = self.pad_token_id elif k.endswith("_attention_mask"): padding_value = 0 elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): @@ -451,15 +305,15 @@ def collate(self, batch): else: # adapted from https://stackoverflow.com/questions/73256206 if "prompt" in k: - 
to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] + to_pad = [torch.LongTensor(ex[k][::-1]) for ex in features] else: - to_pad = [torch.LongTensor(ex[k]) for ex in batch] + to_pad = [torch.LongTensor(ex[k]) for ex in features] if k.endswith("_input_ids"): - padding_value = self.tokenizer.pad_token_id + padding_value = self.pad_token_id elif k.endswith("_labels"): padding_value = self.label_pad_token_id elif k.endswith("_attention_mask"): - padding_value = self.padding_value + padding_value = 0 else: raise ValueError(f"Unexpected key in batch '{k}'") @@ -467,25 +321,14 @@ def collate(self, batch): # for the prompt, flip back so padding is on left side if "prompt" in k: padded_batch[k] = padded_batch[k].flip(dims=[1]) + elif k.endswith("_logps"): + # the cached reference model logprobs + padded_batch[k] = torch.tensor([ex[k] for ex in features]) else: - padded_batch[k] = [ex[k] for ex in batch] + padded_batch[k] = [ex[k] for ex in features] return padded_batch - def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: - tokenized_batch = [] - - for feature in features: - prompt = feature["prompt"] - chosen = feature["chosen"] - rejected = feature["rejected"] - - batch_element = self.tokenize_batch_element(prompt, chosen, rejected) - tokenized_batch.append(batch_element) - - # return collated batch - return self.collate(tokenized_batch) - class ConstantLengthDataset(IterableDataset): """
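With these changes the collator is padding-only: tokenization and truncation live in `DPOTrainer.tokenize_row`, which the trainer maps over the raw `prompt`/`chosen`/`rejected` columns in `__init__`, and reference log-probs can be cached once up front with `precompute_ref_log_probs=True`. A minimal end-to-end sketch of the resulting API (the model name, dataset contents, and hyperparameters below are placeholders, not part of the patch):

from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

# raw text columns; DPOTrainer.tokenize_row handles tokenization and truncation
train_dataset = Dataset.from_dict(
    {
        "prompt": ["[INST] How is the stock price? [/INST]"],
        "chosen": ["$46 as of 10am EST"],
        "rejected": ["I have no idea"],
    }
)

trainer = DPOTrainer(
    model,
    ref_model=None,  # with a PEFT adapter, the model with adapters disabled serves as the reference
    args=TrainingArguments(
        output_dir="dpo-out",
        per_device_train_batch_size=1,
        max_steps=1,
        report_to="none",
    ),
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    max_length=128,
    max_prompt_length=64,
    precompute_ref_log_probs=True,  # reference log-probs are cached when the train dataloader is first built
    peft_config=LoraConfig(r=16, lora_alpha=16, bias="none", task_type="CAUSAL_LM"),
)
trainer.train()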