diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index eff62650bb4..0fcb12b36fa 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -955,6 +955,20 @@ def test_with_extra_column(self): dataset = truncate_dataset(dataset, max_length) assert dataset.to_dict() == expected_output + def test_with_specified_columns(self): + examples = { + "prompt_ids": [[1, 2, 3], [6, 7], [12]], + "completion_ids": [[4, 5], [8, 9, 10, 11], [13, 14]], + } + dataset = Dataset.from_dict(examples) + max_length = 2 + expected_output = { + "prompt_ids": [[1, 2], [6, 7], [12]], + "completion_ids": [[4, 5], [8, 9, 10, 11], [13, 14]], + } + dataset = truncate_dataset(dataset, max_length, columns=["prompt_ids"]) + assert dataset.to_dict() == expected_output + class TestMaybeConvertToChatML(TrlTestCase): def test_with_conversations_key(self): diff --git a/trl/data_utils.py b/trl/data_utils.py index 454dd24af15..33723420025 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -714,7 +714,10 @@ def pack_dataset( def truncate_dataset( - dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None + dataset: DatasetType, + max_length: int, + columns: Union[str, list[str]] = "all", + map_kwargs: Optional[dict[str, Any]] = None, ) -> DatasetType: r""" Truncate sequences in a dataset to a specified `max_length`. @@ -724,6 +727,8 @@ def truncate_dataset( Dataset to truncate. max_length (`int`): Maximum sequence length to truncate to. + columns (`str` or `list[str]`, *optional*, defaults to `"all"`): + Which columns to truncate. If `"all"` (default), all columns are truncated. map_kwargs (`dict`, *optional*): Additional keyword arguments to pass to the dataset's map method when truncating examples. @@ -749,32 +754,30 @@ def truncate_dataset( map_kwargs = {} if isinstance(dataset, Dataset): # Fast truncation with pyarrow - def truncate(examples): + def truncate(examples, columns): truncated_columns = [] for column in examples.columns: - if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type): - column = pc.list_slice(column, 0, max_length) + if columns == "all" or column._name in columns: + if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type): + column = pc.list_slice(column, 0, max_length) truncated_columns.append(column) return pa.Table.from_arrays(truncated_columns, names=examples.column_names) dataset = dataset.with_format("arrow") - dataset = dataset.map(truncate, batched=True, **map_kwargs) + dataset = dataset.map(truncate, batched=True, **map_kwargs, fn_kwargs={"columns": columns}) dataset = dataset.with_format(None) else: - def truncate(examples): + def truncate(examples, columns): truncated_examples = {} for key, column in examples.items(): - if column and isinstance(column[0], list): - column = [val[:max_length] for val in column] + if columns == "all" or key in columns: + if column and isinstance(column[0], list): + column = [val[:max_length] for val in column] truncated_examples[key] = column return truncated_examples - dataset = dataset.map( - truncate, - batched=True, - **map_kwargs, - ) + dataset = dataset.map(truncate, batched=True, **map_kwargs, fn_kwargs={"columns": columns}) return dataset diff --git a/trl/experimental/dpo/__init__.py b/trl/experimental/dpo/__init__.py new file mode 100644 index 00000000000..5bc1ec33016 --- /dev/null +++ b/trl/experimental/dpo/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dpo_config import DPOConfig +from .dpo_trainer import DPOTrainer diff --git a/trl/experimental/dpo/dpo_config.py b/trl/experimental/dpo/dpo_config.py new file mode 100644 index 00000000000..3f460d9e68a --- /dev/null +++ b/trl/experimental/dpo/dpo_config.py @@ -0,0 +1,212 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from transformers import TrainingArguments + + +@dataclass +class DPOConfig(TrainingArguments): + r""" + Configuration class for the [`DPOTrainer`]. + + This class includes only the parameters that are specific to DPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`DPOTrainer`] is provided as a string. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + + > Parameters that control the data preprocessing + + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + pad_token (`str`, *optional*): + Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, + it falls back to `processing_class.eos_token`. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt part of the sequence. If `None`, no truncation is applied. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the completion part of the sequence. If `None`, no truncation is applied. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. + If `None`, no truncation is applied. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and + `"keep_start"`. + padding_free (`bool`, *optional*, defaults to `False`): + Whether to perform forward passes without padding by flattening all sequences in the batch into a single + continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only + supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + precompute_ref_log_probs (`bool`, *optional*, defaults to `True`): + Whether to precompute the reference model log probabilities for the entire training dataset before + training. This allows to save memory during training, as the reference model does not need to be kept in + memory. + + > Parameters that control the training + + loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`): + Type of loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"hinge"`: hinge loss on the normalized likelihood from the + [SLiC](https://huggingface.co/papers/2305.10425) paper. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. + activation_offloading (`bool`, *optional*, defaults to `False`): + Whether to offload the activations to the CPU. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + + # Parameters that control the model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `DPOTrainer` is provided as a string." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, + ) + + # Parameters that control the data preprocessing + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + pad_token: Optional[str] = field( + default=None, + metadata={ + "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " + "is also `None`, it falls back to `processing_class.eos_token`." + }, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={"help": "Maximum length of the prompt part of the sequence. If `None`, no truncation is applied."}, + ) + max_completion_length: Optional[int] = field( + default=None, + metadata={ + "help": "Maximum length of the completion part of the sequence. If `None`, no truncation is applied." + }, + ) + max_length: Optional[int] = field( + default=1024, + metadata={ + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from " + "the right. If `None`, no truncation is applied." + }, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` " + "and `'keep_start'`.", + "choices": ["keep_end", "keep_start"], + }, + ) + padding_free: bool = field( + default=False, + metadata={ + "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " + "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this " + "is only supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch " + "structure." + }, + ) + pad_to_multiple_of: Optional[int] = field( + default=None, + metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, + ) + precompute_ref_log_probs: bool = field( + default=True, + metadata={ + "help": "Whether to precompute the reference model log probabilities for the entire training dataset " + "before training. This allows to save memory during training, as the reference model does not need to be " + "kept in memory." + }, + ) + + # Parameters that control the training + loss_type: list[str] = field( + default_factory=lambda: ["sigmoid"], + metadata={ + "help": "Type of loss to use. Possible values are: `'sigmoid'`, `'hinge'`.", + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from " + "the reference model." + }, + ) + activation_offloading: bool = field( + default=False, + metadata={"help": "Whether to offload the activations to the CPU."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + # Normalize loss_type to string format for internal use + if hasattr(self.loss_type, "__len__") and len(self.loss_type) == 1: + self.loss_type = self.loss_type[0] + super().__post_init__() diff --git a/trl/experimental/dpo/dpo_trainer.py b/trl/experimental/dpo/dpo_trainer.py new file mode 100644 index 00000000000..c010fcaa69f --- /dev/null +++ b/trl/experimental/dpo/dpo_trainer.py @@ -0,0 +1,871 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import textwrap +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from accelerate import PartialState, logging +from accelerate.utils import is_peft_model +from datasets import Dataset, IterableDataset +from datasets.fingerprint import Hasher +from transformers import ( + AutoConfig, + AutoProcessor, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.data.data_collator import DataCollatorMixin +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from ...data_utils import extract_prompt, is_conversational, prepare_multimodal_messages, truncate_dataset +from ...models import get_act_offloading_ctx_manager, prepare_deepspeed, prepare_fsdp, prepare_peft_model +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import ( + disable_dropout_in_model, + entropy_from_logits, + flush_left, + flush_right, + hash_module, + pad, + remove_none_values, + selective_log_softmax, +) +from .dpo_config import DPOConfig + + +if is_peft_available(): + from peft import PeftConfig, PeftModel + + +logger = logging.get_logger(__name__) + + +FLASH_ATTENTION_VARIANTS = { + "flash_attention_2", + "flash_attention_3", + "kernels-community/flash-attn", + "kernels-community/vllm-flash-attn3", + "kernels-community/flash-attn3", +} + + +def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]: + return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names + + +@dataclass +class DataCollatorForPreference(DataCollatorMixin): + """ + Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch. + + This collator expects each example in the input list to be a dictionary containing the keys `"prompt_ids"`, + `"chosen_ids"` and `"rejected_input_ids"`. The collator returns a dictionary containing the following keys: + - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch + corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`. + - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. + - `"completion_mask"`: Tensor indicating the positions of the completion tokens, padded to the maximum length of + the batch. + + Optionally, the examples can contain a `"ref_chosen_logps"` and `"ref_rejected_logps"` keys, in which case the + returned dictionary will also contain these keys with the corresponding tensors. + + Args: + pad_token_id (`int`): + Token ID to use for padding. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl.trainer.dpo_trainer import DataCollatorForPreference + + >>> collator = DataCollatorForPreference(pad_token_id=0) + >>> examples = [{"prompt_ids": [1, 2, 3], {"chosen_ids": [4, 5], "rejected_ids": [6]}] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3, 4, 5], + [ 1, 2, 3, 6, 0]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1], + [1, 1, 1, 1, 0]]), + 'completion_mask': tensor([[0, 0, 0, 1, 1], + [0, 0, 0, 1, 0]])} + ``` + """ + + pad_token_id: int + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: + prompt_chosen_ids = [example["prompt_ids"] + example["chosen_ids"] for example in examples] + prompt_rejected_ids = [example["prompt_ids"] + example["rejected_ids"] for example in examples] + chosen_attention_mask = [[1] * len(example["prompt_ids"] + example["chosen_ids"]) for example in examples] + rejected_attention_mask = [[1] * len(example["prompt_ids"] + example["rejected_ids"]) for example in examples] + chosen_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["chosen_ids"]) for example in examples] + rejected_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["rejected_ids"]) for example in examples] + input_ids = prompt_chosen_ids + prompt_rejected_ids + attention_mask = chosen_attention_mask + rejected_attention_mask + completion_mask = chosen_mask + rejected_mask + + # Convert to tensor + input_ids = [torch.tensor(ids) for ids in input_ids] + attention_mask = [torch.tensor(m, dtype=torch.long) for m in attention_mask] + completion_mask = [torch.tensor(m, dtype=torch.long) for m in completion_mask] + if "ref_chosen_logps" in examples[0]: + ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples]) + if "ref_rejected_logps" in examples[0]: + ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples]) + + # Pad + output = {} + output["input_ids"] = pad( + input_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + output["attention_mask"] = pad( + attention_mask, + padding_value=0, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + output["completion_mask"] = pad( + completion_mask, + padding_value=0, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + if "ref_chosen_logps" in examples[0]: + output["ref_chosen_logps"] = ref_chosen_logps + if "ref_rejected_logps" in examples[0]: + output["ref_rejected_logps"] = ref_rejected_logps + return output + + +class DPOTrainer(BaseTrainer): + """ + Trainer for Direct Preference Optimization (DPO) method. + + This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods. + + Example: + + ```python + from datasets import load_dataset + from trl import DPOTrainer + + dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + + trainer = DPOTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) + trainer.train() + ``` + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `.from_pretrained` (where `` is derived from the model + config) with the keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. + If you're training a model with an MoE architecture and want to include the load balancing/auxilliary loss + as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`. + args ([`DPOConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator ([`~transformers.DataCollator`], *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~trainer.dpo_trainer.DataCollatorForPreference`] if the model is a language model and + [`~trainer.dpo_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. DPO supports both [language modeling](#language-modeling) type and + [prompt-completion](#prompt-completion) type. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. + If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. + compute_loss_func (`Callable`, *optional*): + A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated + batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss + function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) + used by [`Trainer`]. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a + [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing + [`DPOConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean + `compute_result` argument. This will be triggered after the last eval batch to signal that the function + needs to calculate and return the global summary statistics rather than accumulating the batch-level + statistics. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in + `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before + initializing the Trainer. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "dpo"] + _name = "DPO" + _paper = { + "title": "Direct Preference Optimization: Your Language Model is Secretly a Reward Model", + "id": "2305.18290", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{rafailov2023direct, + title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}}, + author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn}, + year = 2023, + booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023}, + url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html}, + editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine}, + }"""), + } + + def __init__( + self, + model: Union[str, PreTrainedModel], + args: Optional[DPOConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, + compute_loss_func: Optional[Callable] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = DPOConfig(f"{model_name}-DPO") + + # Models + # Trained model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + elif isinstance(dtype, str): # it's a str, but not "auto" + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype + else: + raise ValueError( + "Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." + ) + # Disable caching if gradient checkpointing is enabled (not supported) + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + model = architecture.from_pretrained(model_id, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): + model = prepare_peft_model(model, peft_config, args) + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained(model.config._name_or_path) + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if self._is_vlm and args.padding_free: + raise ValueError( + "Padding-free training is yet not supported for vision-language models. Please set " + "`padding_free=False` in the `DPOConfig`." + ) + + # Data collator + self.padding_free = args.padding_free + use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS + if self.padding_free: + raise NotImplementedError("Padding-free training is not yet implemented.") + if data_collator is not None: + raise ValueError("Passing a custom data collator is not supported when using padding-free.") + if not use_flash_attention: + logger.warning( + "Padding-free training is enabled, but the attention implementation is not set to a supported " + "flash attention variant. Padding-free training flattens batches into a single sequence, and only " + "the following implementations are known to reliably support this: " + f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to " + "unexpected behavior. To ensure compatibility, set `attn_implementation` in the model " + "configuration to one of these supported options or verify that your attention mechanism can " + "handle flattened sequences." + ) + + if args.per_device_train_batch_size == 1: + logger.warning( + "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " + "of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size " + "to at least 2." + ) + + dataset_sample = next(iter(train_dataset)) + self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample + if self._is_vision_dataset and not self._is_vlm: + raise ValueError( + "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided " + "model does not seem to be a vision-language model. Please check your model and dataset." + ) + + if data_collator is None and not self._is_vision_dataset: + # Get the pad token: if not provided, use the one from the processing class or the eos token + # if the processing class does not have a pad token. + pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token + pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) + if pad_token_id is None: + raise ValueError( + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." + ) + data_collator = DataCollatorForPreference( + pad_token_id=pad_token_id, + pad_to_multiple_of=args.pad_to_multiple_of, + ) + elif data_collator is None and self._is_vision_dataset: + raise NotImplementedError("VLM training is not yet implemented.") + + # Training arguments + self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type] + self.beta = args.beta + + # Dataset + # Skip dataset preparation if it's a VLM, where preprocessing (e.g., image-to-pixel conversion) is too costly + # and done on the fly instead. + skip_prepare_dataset = self._is_vision_dataset + if not skip_prepare_dataset: + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + + # Initialize the Trainer. Parent class will handle: + # - DeepSpeed configuration (through create_accelerator_and_postprocess) + # - FSDP setup + # - Distributed training setup + # - Optimizer and scheduler creation + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_loss_func=compute_loss_func, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Reference model + self.beta = args.beta + if self.beta == 0.0: + # If beta is 0.0, the reference model is not needed + self.ref_model = None + elif is_peft_model(model): + # If PEFT is used, the reference model is not needed since the adapter can be disabled + # to revert to the initial model. + self.ref_model = None + else: + # For deepspeed, fsdp or non-distributed models, create a reference model from scratch + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) + + # Disable dropout in the models + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Disable dropout in the models + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for model + self.model.add_model_tags(self._tag_names) + + if self.ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + elif self.is_fsdp_enabled: + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + # Initialize activation offloading context + if self.args.activation_offloading: + self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) + else: + self.maybe_activation_offload_context = contextlib.nullcontext() + + def _prepare_dataset( + self, + dataset: Union[Dataset, IterableDataset], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], + args: DPOConfig, + dataset_name: str, + ) -> Union[Dataset, IterableDataset]: + # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from + # sampled data. + if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` + dataset = dataset.with_transform(remove_none_values) + + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc + + with PartialState().main_process_first(): + # Extract the prompt if needed + first_example = next(iter(dataset)) + if "prompt" not in first_example: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Extracting prompt from {dataset_name} dataset" + dataset = dataset.map(extract_prompt, **map_kwargs) + + # Apply the chat template if needed + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if not example["chosen"].endswith(eos_token): + example["chosen"] = example["chosen"] + eos_token + if not example["rejected"].endswith(eos_token): + example["rejected"] = example["rejected"] + eos_token + return example + + dataset = dataset.map(add_eos, fn_kwargs={"eos_token": processing_class.eos_token}, **map_kwargs) + + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize_fn(example, processing_class): + output = {} + if is_conversational(example): + if self._is_vlm: + prepare_multimodal_messages(example["prompt"], num_images=0) + prepare_multimodal_messages(example["completion"], num_images=0) + prompt_ids = processing_class.apply_chat_template( + example["prompt"], + tokenize=True, + add_generation_prompt=True, + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + prompt_chosen_processed = processing_class.apply_chat_template( + example["prompt"] + example["chosen"], + return_dict=True, + tokenize=True, + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + prompt_rejected_processed = processing_class.apply_chat_template( + example["prompt"] + example["rejected"], + return_dict=True, + tokenize=True, + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists + # even for single examples, while for LLMs it returns lists of ints. + prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids + prompt_chosen_processed = { + k: v[0] if isinstance(v[0], list) else v for k, v in prompt_chosen_processed.items() + } + prompt_rejected_processed = { + k: v[0] if isinstance(v[0], list) else v for k, v in prompt_rejected_processed.items() + } + prompt_chosen_ids = prompt_chosen_processed["input_ids"] + prompt_rejected_ids = prompt_rejected_processed["input_ids"] + else: + prompt_ids = processing_class(text=example["prompt"])["input_ids"] + prompt_chosen_ids = processing_class(text=example["prompt"] + example["chosen"])["input_ids"] + prompt_rejected_ids = processing_class(text=example["prompt"] + example["rejected"])["input_ids"] + + # Check if the tokenized prompt starts with the tokenized prompt+completion + if not prompt_chosen_ids[: len(prompt_ids)] == prompt_ids: + logger.warning( + "Mismatch between tokenized prompt and the start of tokenized prompt+chosen. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + if not prompt_rejected_ids[: len(prompt_ids)] == prompt_ids: + logger.warning( + "Mismatch between tokenized prompt and the start of tokenized prompt+rejected. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + + output["prompt_ids"] = prompt_ids + output["chosen_ids"] = prompt_chosen_ids[len(prompt_ids) :] + output["rejected_ids"] = prompt_rejected_ids[len(prompt_ids) :] + return output + + dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs) + + # Truncate + if args.max_prompt_length is not None: + raise NotImplementedError("Prompt truncation is not yet implemented.") + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Truncating prompt in {dataset_name} dataset" + dataset = truncate_dataset( + dataset, args.max_prompt_length, columns=["prompt_ids"], map_kwargs=map_kwargs + ) + if args.max_completion_length is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Truncating completions in {dataset_name} dataset" + dataset = truncate_dataset( + dataset, args.max_completion_length, columns=["chosen_ids", "rejected_ids"], map_kwargs=map_kwargs + ) + # For Liger kernel, ensure only the essential columns + if args.use_liger_kernel: + collator_expected_keys = {"input_ids", "completion_mask"} + column_names = get_dataset_column_names(dataset) + dataset = dataset.select_columns(collator_expected_keys.intersection(column_names)) + + return dataset + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). + if self._signature_columns is None: + if self._is_vision_dataset: + self._signature_columns = ["prompt", "chosen", "rejected"] + else: + self._signature_columns = [ + "prompt_ids", + "chosen_ids", + "rejected_ids", + "ref_chosen_logps", + "ref_rejected_logps", + ] + + def train(self, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None, **kwargs): + if self.args.precompute_ref_log_probs: + self.train_dataset = self._precompute_ref_logps( + self.train_dataset, self.args.per_device_train_batch_size, "train" + ) + if self.eval_dataset is not None: + if isinstance(self.eval_dataset, dict): + self.eval_dataset = { + key: self._precompute_ref_logps(dataset, self.args.per_device_eval_batch_size, key) + for key, dataset in self.eval_dataset.items() + } + else: + self.eval_dataset = self._precompute_ref_logps( + self.eval_dataset, self.args.per_device_eval_batch_size, "eval" + ) + return super().train() + + def _precompute_ref_logps( + self, dataset: Union[Dataset, IterableDataset], batch_size: int, dataset_name: str + ) -> None: + def compute_ref_logps(examples, collator, max_length, truncation_mode): + examples = [dict(zip(examples.keys(), v)) for v in zip(*examples.values())] # dict[list] to list[dict] + inputs = collator(examples) + input_ids = inputs["input_ids"].to(self.model.device) + attention_mask = inputs["attention_mask"].to(self.model.device) + completion_mask = inputs["completion_mask"].to(self.model.device) + + # Truncate inputs + if max_length is not None: + if truncation_mode == "keep_start": + input_ids = input_ids[:, :max_length] + attention_mask = attention_mask[:, :max_length] + completion_mask = completion_mask[:, :max_length] + elif truncation_mode == "keep_end": + attention_mask, input_ids, completion_mask = flush_right( + attention_mask, input_ids, completion_mask + ) + input_ids = input_ids[:, -max_length:] + attention_mask = attention_mask[:, -max_length:] + completion_mask = completion_mask[:, -max_length:] + attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask) + else: + raise ValueError( + f"Unsupported truncation mode: {truncation_mode}, expected 'keep_start' or 'keep_end'" + ) + + outputs = self.model(input_ids, attention_mak=attention_mask, use_cache=False) + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = input_ids[..., 1:].contiguous() + shift_completion_mask = completion_mask[..., 1:].contiguous() + per_token_logps = selective_log_softmax(shift_logits, shift_labels) + per_token_logps[shift_completion_mask == 0] = 0.0 # mask out non-completion tokens + logps = per_token_logps.sum(dim=1) # sum over sequence length + chosen_logps, rejected_logps = logps.chunk(2, dim=0) # batch is [chosen, rejected] + return {"ref_chosen_logps": chosen_logps.tolist(), "ref_rejected_logps": rejected_logps.tolist()} + + # Normally, `map` creates a fingerprint based on the transform function and its arguments. However, the model’s + # produces a different fingerprint on each run, which prevents the cache from being used. To fix this, we + # manually compute a stable fingerprint for the model instead. + fn_kwargs = { + "collator": self.data_collator, + "max_length": self.args.max_length, + "truncation_mode": self.args.truncation_mode, + } + model_hash = hash_module(self.model) + dataset = dataset.map( + compute_ref_logps, + batched=True, + batch_size=batch_size, + fn_kwargs=fn_kwargs, + desc=f"Computing reference logps for {dataset_name} dataset", + new_fingerprint=Hasher.hash((dataset._fingerprint, fn_kwargs, model_hash)), + ) + return dataset + + def compute_loss( + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs: bool = False, + num_items_in_batch: Optional[torch.Tensor] = None, + ): + """ + Compute training loss and additionally compute token accuracies + """ + mode = "train" if self.model.training else "eval" + + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + completion_mask = inputs["completion_mask"] + + # Truncate inputs + if self.args.max_length is not None: + if self.args.truncation_mode == "keep_start": + input_ids = input_ids[:, : self.args.max_length] + attention_mask = attention_mask[:, : self.args.max_length] + completion_mask = completion_mask[:, : self.args.max_length] + elif self.args.truncation_mode == "keep_end": + attention_mask, input_ids, completion_mask = flush_right(attention_mask, input_ids, completion_mask) + input_ids = input_ids[:, -self.args.max_length :] + attention_mask = attention_mask[:, -self.args.max_length :] + completion_mask = completion_mask[:, -self.args.max_length :] + attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask) + else: + raise ValueError( + f"Unsupported truncation mode: {self.args.truncation_mode}, expected 'keep_start' or 'keep_end'" + ) + + outputs = model(input_ids, attention_mak=attention_mask, use_cache=False) + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = input_ids[..., 1:].contiguous() + shift_completion_mask = completion_mask[..., 1:].contiguous() + per_token_logps = selective_log_softmax(shift_logits, shift_labels) + per_token_logps[shift_completion_mask == 0] = 0.0 # mask out non-completion tokens + logps = per_token_logps.sum(dim=1) # sum over sequence length + chosen_logps, rejected_logps = logps.chunk(2, dim=0) # batch is [chosen, rejected] + ref_chosen_logps, ref_rejected_logps = inputs["ref_chosen_logps"], inputs["ref_rejected_logps"] + + # Get the log ratios for the chosen and rejected responses + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps + + loss = 0 + + for loss_type in self.loss_type: + if loss_type == "sigmoid": + per_sequence_loss = -F.logsigmoid(self.beta * chosen_logratios - self.beta * rejected_logratios) + + elif loss_type == "hinge": + per_sequence_loss = torch.relu(1 - (self.beta * chosen_logratios - self.beta * rejected_logratios)) + + loss += per_sequence_loss.mean() + + # Log the metrics + # Entropy + per_token_entropy = entropy_from_logits(shift_logits.detach()) + entropy = per_token_entropy[shift_completion_mask.bool()].mean() + entropy = self.accelerator.gather_for_metrics(entropy).mean().item() + self._metrics[mode]["entropy"].append(entropy) + + # Number of tokens + if mode == "train": + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() + self._total_train_tokens += num_tokens_in_batch + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + + # Average logits for chosen and rejected completions + chosen_logits, rejected_logits = shift_logits.detach().chunk(2, dim=0) + chosen_mask, rejected_mask = shift_completion_mask.chunk(2, dim=0) + total_chosen_logits = chosen_logits[chosen_mask.bool()].mean(-1) + total_chosen_tokens = chosen_mask.sum() + total_rejected_logits = rejected_logits[rejected_mask.bool()].mean(-1) + total_rejected_tokens = rejected_mask.sum() + total_chosen_logits = self.accelerator.gather_for_metrics(total_chosen_logits).sum().item() + total_chosen_tokens = self.accelerator.gather_for_metrics(total_chosen_tokens).sum().item() + total_rejected_logits = self.accelerator.gather_for_metrics(total_rejected_logits).sum().item() + total_rejected_tokens = self.accelerator.gather_for_metrics(total_rejected_tokens).sum().item() + avg_chosen_logits = total_chosen_logits / total_chosen_tokens if total_chosen_tokens > 0 else 0.0 + avg_rejected_logits = total_rejected_logits / total_rejected_tokens if total_rejected_tokens > 0 else 0.0 + self._metrics[mode]["logits/chosen"].append(avg_chosen_logits) + self._metrics[mode]["logits/rejected"].append(avg_rejected_logits) + + # Token accuracy for the chosen completions + predictions = chosen_logits.argmax(dim=-1) + chosen_mask = shift_completion_mask[: len(shift_completion_mask) // 2].bool() + chosen_labels = shift_labels[: len(shift_labels) // 2] + correct_predictions = (predictions == chosen_labels) & chosen_mask + total_tokens = chosen_mask.sum() + correct_tokens = correct_predictions.sum() + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) + total_sum = total_tokens.sum() + accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 + self._metrics[mode]["mean_token_accuracy"].append(accuracy) + + # Rewards for chosen and rejected completions + chosen_rewards = self.beta * (chosen_logps.detach() - ref_chosen_logps) + rejected_rewards = self.beta * (rejected_logps.detach() - ref_rejected_logps) + agg_chosen_rewards = self.accelerator.gather(chosen_rewards) + agg_rejected_rewards = self.accelerator.gather(rejected_rewards) + self._metrics[mode]["rewards/chosen"].append(agg_chosen_rewards.mean().item()) + self._metrics[mode]["rewards/rejected"].append(agg_rejected_rewards.mean().item()) + + # Reward accuracy + reward_accuracies = (chosen_rewards > rejected_rewards).float() + agg_reward_accuracies = self.accelerator.gather(reward_accuracies) + self._metrics[mode]["rewards/accuracies"].append(agg_reward_accuracies.mean().item()) + + # Reward margins + margins = chosen_rewards - rejected_rewards + agg_margins = self.accelerator.gather(margins) + self._metrics[mode]["rewards/margins"].append(agg_margins.mean().item()) + + # Average log probabilities for chosen and rejected completions + self._metrics[mode]["logps/chosen"].append(self.accelerator.gather(chosen_logps).mean().item()) + self._metrics[mode]["logps/rejected"].append(self.accelerator.gather(rejected_logps).mean().item()) + + return (loss, outputs) if return_outputs else loss + + # Override training step to add activation offloading context. + def training_step(self, *args, **kwargs): + with self.maybe_activation_offload_context: + return super().training_step(*args, **kwargs) + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs.update(metrics) + super().log(logs, start_time) + self._metrics[mode].clear() + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index bfcc4b4c53e..19b13c0d72e 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -51,7 +51,7 @@ from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_liger_kernel_available, is_peft_available -from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt +from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_extract_prompt from ..models import create_reference_model, prepare_deepspeed from ..models.utils import prepare_fsdp from .base_trainer import BaseTrainer @@ -649,6 +649,8 @@ def _prepare_dataset( map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" dataset = dataset.map(maybe_extract_prompt, **map_kwargs) + is_chat = is_conversational(next(iter(dataset))) + # Apply the chat template if needed if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" @@ -669,6 +671,7 @@ def _prepare_dataset( "max_completion_length": args.max_completion_length, # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) "add_special_tokens": False, + "is_chat": is_chat, }, **map_kwargs, ) @@ -682,6 +685,7 @@ def tokenize_row( max_prompt_length: Optional[int] = None, max_completion_length: Optional[int] = None, add_special_tokens: bool = True, + is_chat: bool = False, ) -> dict[str, list[int]]: """ Tokenize a row of the dataset. @@ -728,8 +732,9 @@ def tokenize_row( prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids if tokenizer.eos_token_id is not None: prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] - chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] - rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] + if not is_chat: + chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] + rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] # Truncate prompt and completion sequences if max_prompt_length is not None: diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index cee2fb82ede..8f3c2876967 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -1057,7 +1057,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo elif args.max_length is not None: if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Truncating {dataset_name} dataset" - dataset = truncate_dataset(dataset, args.max_length, map_kwargs) + dataset = truncate_dataset(dataset, args.max_length, map_kwargs=map_kwargs) # For Liger kernel, ensure only the essential columns if args.use_liger_kernel: collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"} diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 00c71c8c4c8..a796860b52f 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -19,6 +19,7 @@ import random import socket import warnings +import zlib from collections.abc import Mapping, Sequence, Sized from dataclasses import dataclass, field from importlib.metadata import version @@ -29,7 +30,6 @@ import pandas as pd import torch import torch.nn.functional as F -import torch.utils.data import transformers from accelerate import Accelerator, PartialState, logging from accelerate.state import AcceleratorState @@ -1990,3 +1990,14 @@ def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel: architecture = getattr(transformers, config.architectures[0]) model = architecture.from_pretrained(model_id, **kwargs) return model + + +def hash_module(module: torch.nn.Module) -> str: + h = zlib.adler32(b"") + for _, tensor in sorted(module.state_dict().items()): + tensor = tensor.cpu() + h = zlib.adler32(str(tensor.dtype).encode(), h) + if tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2): + tensor = tensor.to(torch.float32) + h = zlib.adler32(tensor.numpy().tobytes(), h) + return f"{h:08x}"