From 6fe29e5e617ddd5647f211ffe59c5e99a6bcd02d Mon Sep 17 00:00:00 2001
From: Nikita Surnachev
Date: Fri, 26 Apr 2024 23:14:10 +0300
Subject: [PATCH 01/17] =?UTF-8?q?=F0=9F=A4=AB=20TR-DPO=20implementation=20?=
 =?UTF-8?q?baseline?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 trl/trainer/dpo_trainer.py | 18 +++++++++++++++-
 trl/trainer/utils.py       | 43 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 128a7b343a..3a4d2c9fd8 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -25,7 +25,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from accelerate import PartialState
+from accelerate import Accelerator, AcceleratorState, PartialState
 from accelerate.utils import is_deepspeed_available, tqdm
 from datasets import Dataset
 from torch.utils.data import DataLoader
@@ -44,6 +44,7 @@
 from ..models import PreTrainedModelWrapper, create_reference_model
 from .utils import (
     DPODataCollatorWithPadding,
+    SyncRefModelCallback,
     disable_dropout_in_model,
     pad_to_length,
     peft_module_casting_to_bf16,
@@ -176,6 +177,9 @@ def __init__(
         ref_adapter_name: Optional[str] = None,
         reference_free: bool = False,
         force_use_ref_model: bool = False,
+        sync_ref_model: bool = False,
+        mixup_alpha: float = 1.0,
+        ref_model_sync_steps: int = 2,
     ):
         if model_init_kwargs is None:
             model_init_kwargs = {}
@@ -415,12 +419,24 @@ def make_inputs_require_grad(module, input, output):
                 raise ValueError(
                     "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                 )
+            if sync_ref_model:
+                raise ValueError("Adapters are not supported")
         else:
             if self.is_deepspeed_enabled:
                 self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
+        if sync_ref_model:
+            if precompute_ref_log_probs:
+                raise ValueError("Does not work with precomputed log probs")
+
+        if sync_ref_model:
+            self.add_callback(SyncRefModelCallback(self.accelerator,
+                                                   self.ref_model,
+                                                   mixup_alpha=mixup_alpha,
+                                                   ref_model_sync_steps=ref_model_sync_steps))
+
     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
         # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
         deepspeed_plugin = self.accelerator.state.deepspeed_plugin

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 31e11f84a4..74c489558a 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -56,6 +56,49 @@ def update(self, current, n_steps):
         self.value *= mult
 
 
+class SyncRefModelCallback(TrainerCallback):
+
+    def __init__(
+        self,
+        ref_model: Union[PreTrainedModel, nn.Module],
+        accelerator: Optional[Accelerator],
+        mixup_alpha: float,
+        ref_model_sync_steps: int,
+    ):
+        self.mixup_alpha = mixup_alpha
+        self.ref_model_sync_steps = ref_model_sync_steps
+        self.accelerator = accelerator
+        self.ref_model = ref_model
+
+    @staticmethod
+    def _sync_target_model(model, target_model, alpha):
+        for target_param, copy_param in zip(
+            target_model.parameters(), model.parameters()
+        ):
+            target_param.data.copy_(
+                (alpha * copy_param.data) + (1.0 - alpha) * target_param.data
+            )
+
+    @classmethod
+    def sync_target_model(cls, model, target_model, alpha):
+        if AcceleratorState().deepspeed_plugin.zero_stage == 3:
+            with deepspeed.zero.GatheredParameters(
+                list(model.parameters()), modifier_rank=0
+            ):
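Between patches: the rule this baseline wires in is the TR-DPO soft sync, ref <- alpha * policy + (1 - alpha) * ref, applied every ref_model_sync_steps optimizer steps. A minimal self-contained sketch of just that update on toy modules (all names here are illustrative, not part of the patch):

import torch
import torch.nn as nn

def soft_sync(policy: nn.Module, ref: nn.Module, alpha: float) -> None:
    # TR-DPO weight mixing: ref <- alpha * policy + (1 - alpha) * ref
    with torch.no_grad():
        for ref_p, pol_p in zip(ref.parameters(), policy.parameters()):
            ref_p.data.copy_(alpha * pol_p.data + (1.0 - alpha) * ref_p.data)

policy, ref = nn.Linear(4, 4), nn.Linear(4, 4)
for step in range(1, 7):
    # ... one optimizer step on `policy` would happen here ...
    if step % 2 == 0:  # ref_model_sync_steps=2, the patch's initial default
        soft_sync(policy, ref, alpha=1.0)  # alpha=1.0 degenerates to a hard copy

With alpha=1.0 (the initial default above) every sync replaces the reference outright; a fractional alpha keeps the reference trailing the policy instead.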
From a98e2ecd56efb6dd7b21bf263cacaab417117521 Mon Sep 17 00:00:00 2001
From: Nikita Surnachev
Date: Fri, 26 Apr 2024 23:25:55 +0300
Subject: [PATCH 02/17] fix comments

---
 trl/trainer/dpo_trainer.py |  4 ++--
 trl/trainer/utils.py       | 12 +++---------
 2 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 3a4d2c9fd8..3aeba77370 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -420,7 +420,7 @@ def make_inputs_require_grad(module, input, output):
                 "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
             )
             if sync_ref_model:
-                raise ValueError("Adapters are not supported")
+                raise ValueError("You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`.")
         else:
             if self.is_deepspeed_enabled:
                 self.ref_model = self._prepare_deepspeed(self.ref_model)
@@ -429,7 +429,7 @@ def make_inputs_require_grad(module, input, output):
         if sync_ref_model:
             if precompute_ref_log_probs:
-                raise ValueError("Does not work with precomputed log probs")
+                raise ValueError("You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`.")

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 74c489558a..9e58cac06f 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -72,19 +72,13 @@ def __init__(
 
     @staticmethod
     def _sync_target_model(model, target_model, alpha):
-        for target_param, copy_param in zip(
-            target_model.parameters(), model.parameters()
-        ):
-            target_param.data.copy_(
-                (alpha * copy_param.data) + (1.0 - alpha) * target_param.data
-            )
+        for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
+            target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)
 
     @classmethod
     def sync_target_model(cls, model, target_model, alpha):
         if AcceleratorState().deepspeed_plugin.zero_stage == 3:
-            with deepspeed.zero.GatheredParameters(
-                list(model.parameters()), modifier_rank=0
-            ):
+            with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
                 if deepspeed.comm.get_rank() == 0:
                     cls._sync_target_model(model, target_model, alpha)
         else:
             cls._sync_target_model(model, target_model, alpha)

From 687a2e848b86c61a4d598ddd100b9e67cbf3c7c7 Mon Sep 17 00:00:00 2001
From: Nikita Surnachev
Date: Fri, 26 Apr 2024 23:32:53 +0300
Subject: [PATCH 03/17] docs

---
 trl/trainer/dpo_trainer.py | 10 ++++++++--
 trl/trainer/utils.py       |  8 ++++----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 3aeba77370..2ae6844ef4 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -138,6 +138,12 @@ class DPOTrainer(Trainer):
             If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
         force_use_ref_model (`bool`, defaults to `False`):
             In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
+        sync_ref_model ('bool', defaults to `False`):
+            The flag for syncing reference model during training from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper
+        ref_model_mixup_alpha ('float', defaults to 1.0):
+            The alpha parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper
+        ref_model_sync_steps ('int', defaults to 2):
+            The tau parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper
     """
 
     _tag_names = ["trl", "dpo"]
@@ -178,7 +184,7 @@ def __init__(
         reference_free: bool = False,
         force_use_ref_model: bool = False,
         sync_ref_model: bool = False,
-        mixup_alpha: float = 1.0,
+        ref_model_mixup_alpha: float = 1.0,
         ref_model_sync_steps: int = 2,
     ):
         if model_init_kwargs is None:
@@ -434,7 +440,7 @@ def make_inputs_require_grad(module, input, output):
         if sync_ref_model:
             self.add_callback(SyncRefModelCallback(self.accelerator,
                                                    self.ref_model,
-                                                   mixup_alpha=mixup_alpha,
+                                                   ref_model_mixup_alpha=ref_model_mixup_alpha,
                                                    ref_model_sync_steps=ref_model_sync_steps))
 
     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 9e58cac06f..210d5ec970 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -62,10 +62,10 @@ def __init__(
         self,
         ref_model: Union[PreTrainedModel, nn.Module],
         accelerator: Optional[Accelerator],
-        mixup_alpha: float,
+        ref_model_mixup_alpha: float,
         ref_model_sync_steps: int,
     ):
-        self.mixup_alpha = mixup_alpha
+        self.ref_model_mixup_alpha = ref_model_mixup_alpha
         self.ref_model_sync_steps = ref_model_sync_steps
         self.accelerator = accelerator
         self.ref_model = ref_model
@@ -88,9 +88,9 @@ def on_step_end(self, args, state, control, model):
         if ref_model is not None and self.global_step % self.ref_model_sync_steps == 0:
             if self.accelerator:
                 unwrapped_model = accelerator.unwrap_model(model)
-                self.sync_target_model(unwrapped_model, self.ref_model, self.mixup_alpha)
+                self.sync_target_model(unwrapped_model, self.ref_model, self.ref_model_mixup_alpha)
             else:
-                self.sync_target_model(model, self.ref_model, self.mixup_alpha)
+                self.sync_target_model(model, self.ref_model, self.ref_model_mixup_alpha)
 
 
 class FixedKLController:
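At this stage of the series the TR-DPO knobs are plain DPOTrainer keyword arguments (they move into DPOConfig in patch 06). A sketch of how they would be passed as of this commit; the model name and the tiny dataset are placeholders, not part of the patches:

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")      # placeholder policy model
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")  # explicit ref model: TR-DPO requires one
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict(
    {"prompt": ["Hi"], "chosen": [" there!"], "rejected": [" never."]}
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=TrainingArguments(output_dir="out", remove_unused_columns=False, max_steps=4),
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    sync_ref_model=True,        # installs SyncRefModelCallback
    ref_model_mixup_alpha=0.5,  # alpha in ref <- alpha * policy + (1 - alpha) * ref
    ref_model_sync_steps=2,     # sync every tau = 2 global steps
)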
From 9f6d7955085b6981396af3de36fe5bede3a0686b Mon Sep 17 00:00:00 2001
From: Nikita Surnachev
Date: Fri, 26 Apr 2024 23:41:20 +0300
Subject: [PATCH 04/17] fix linters

---
 trl/trainer/dpo_trainer.py | 22 +++++++++++++++-------
 trl/trainer/utils.py       | 20 +++++++++++++-------
 2 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 2ae6844ef4..1ceacd2cdd 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -25,7 +25,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from accelerate import Accelerator, AcceleratorState, PartialState
+from accelerate import PartialState
 from accelerate.utils import is_deepspeed_available, tqdm
 from datasets import Dataset
 from torch.utils.data import DataLoader
@@ -426,7 +426,9 @@ def make_inputs_require_grad(module, input, output):
                 raise ValueError(
                     "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                 )
             if sync_ref_model:
-                raise ValueError("You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`.")
+                raise ValueError(
+                    "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`."
+                )
         else:
             if self.is_deepspeed_enabled:
                 self.ref_model = self._prepare_deepspeed(self.ref_model)
@@ -435,13 +437,19 @@ def make_inputs_require_grad(module, input, output):
         if sync_ref_model:
             if precompute_ref_log_probs:
-                raise ValueError("You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`.")
+                raise ValueError(
+                    "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
+                )
 
         if sync_ref_model:
-            self.add_callback(SyncRefModelCallback(self.accelerator,
-                                                   self.ref_model,
-                                                   ref_model_mixup_alpha=ref_model_mixup_alpha,
-                                                   ref_model_sync_steps=ref_model_sync_steps))
+            self.add_callback(
+                SyncRefModelCallback(
+                    self.accelerator,
+                    self.ref_model,
+                    ref_model_mixup_alpha=ref_model_mixup_alpha,
+                    ref_model_sync_steps=ref_model_sync_steps,
+                )
+            )
 
     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
         # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
         deepspeed_plugin = self.accelerator.state.deepspeed_plugin

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 210d5ec970..245e003d8d 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -19,14 +19,15 @@
 import numpy as np
 import torch
-from accelerate import PartialState
+from accelerate import Accelerator, AcceleratorState, PartialState
+from accelerate.utils import is_deepspeed_available
 from rich.console import Console, Group
 from rich.live import Live
 from rich.panel import Panel
 from rich.progress import Progress
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import IterableDataset
-from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
+from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedModel, PreTrainedTokenizerBase
 from transformers.trainer import TrainerCallback
 from transformers.trainer_utils import has_length
@@ -38,6 +39,10 @@
     from peft import LoraConfig, PeftConfig
 
 
+if is_deepspeed_available():
+    import deepspeed
+
+
 class AdaptiveKLController:
     """
     Adaptive KL controller described in the paper:
@@ -57,10 +62,9 @@ def update(self, current, n_steps):
 
 
 class SyncRefModelCallback(TrainerCallback):
-
     def __init__(
         self,
-        ref_model: Union[PreTrainedModel, nn.Module],
+        ref_model: Union[PreTrainedModel, torch.nn.Module],
         accelerator: Optional[Accelerator],
         ref_model_mixup_alpha: float,
         ref_model_sync_steps: int,
@@ -84,10 +88,12 @@ def sync_target_model(cls, model, target_model, alpha):
         else:
             cls._sync_target_model(model, target_model, alpha)
 
-    def on_step_end(self, args, state, control, model):
-        if ref_model is not None and self.global_step % self.ref_model_sync_steps == 0:
+    def on_step_end(self, args, state, control, **kwargs):
+        model: PreTrainedModel = kwargs["model"]
+
+        if self.ref_model is not None and self.global_step % self.ref_model_sync_steps == 0:
             if self.accelerator:
-                unwrapped_model = accelerator.unwrap_model(model)
+                unwrapped_model = self.accelerator.unwrap_model(model)
                 self.sync_target_model(unwrapped_model, self.ref_model, self.ref_model_mixup_alpha)
             else:
                 self.sync_target_model(model, self.ref_model, self.ref_model_mixup_alpha)
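A note on the ZeRO-3 branch that patch 04 guards with is_deepspeed_available(): under DeepSpeed ZeRO stage 3 every parameter is sharded across ranks, so reading param.data directly would only see a local shard (or an empty placeholder). deepspeed.zero.GatheredParameters temporarily materializes the full tensors, and modifier_rank=0 restricts modification to rank 0. A condensed sketch of the pattern, mirroring the callback above and assuming an already-initialized DeepSpeed environment (illustrative only):

import deepspeed

def sync_under_zero3(model, target_model, alpha):
    # Gather the sharded policy weights so their full values are visible,
    # then mix them into the reference model on rank 0 only.
    with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
        if deepspeed.comm.get_rank() == 0:
            for ref_p, pol_p in zip(target_model.parameters(), model.parameters()):
                ref_p.data.copy_(alpha * pol_p.data + (1.0 - alpha) * ref_p.data)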
From 4e0adbcdd900b01937cb75352a562ea652d7a402 Mon Sep 17 00:00:00 2001
From: Nikita Surnachev
Date: Fri, 26 Apr 2024 23:52:24 +0300
Subject: [PATCH 05/17] test added

---
 tests/test_dpo_trainer.py | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 3b68334cd7..c4312b1016 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -299,6 +299,45 @@ def test_dpo_trainer_w_dataset_num_proc(self):
 
         trainer.train()
 
+    def test_tr_dpo_trainer(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = TrainingArguments(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=4,
+                learning_rate=9e-1,
+                evaluation_strategy="steps",
+            )
+
+            dummy_dataset = self._init_dummy_dataset()
+
+            trainer = DPOTrainer(
+                model=self.model,
+                ref_model=None,
+                beta=0.1,
+                args=training_args,
+                tokenizer=self.tokenizer,
+                train_dataset=dummy_dataset,
+                eval_dataset=dummy_dataset,
+                precompute_ref_log_probs=False,
+                sync_ref_model=True,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    assert not torch.equal(param, new_param)
+
     @require_no_wandb
     def test_dpo_trainer_generate_during_eval_no_wandb(self):
         with tempfile.TemporaryDirectory() as tmp_dir:

From f9184808cd5608f8f0c3120adb7cd0143d19544b Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 13 May 2024 09:20:00 +0200
Subject: [PATCH 06/17] move configs to DPOConfig

---
 tests/test_dpo_trainer.py  |  6 +++---
 trl/trainer/dpo_config.py  |  9 +++++++++
 trl/trainer/dpo_trainer.py | 12 ++++--------
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index f535042b79..a87af54943 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -305,7 +305,7 @@ def test_dpo_trainer_w_dataset_num_proc(self):
 
     def test_tr_dpo_trainer(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = TrainingArguments(
+            training_args = DPOConfig(
                 output_dir=tmp_dir,
                 per_device_train_batch_size=2,
                 max_steps=3,
@@ -313,6 +313,8 @@ def test_tr_dpo_trainer(self):
                 gradient_accumulation_steps=4,
                 learning_rate=9e-1,
                 evaluation_strategy="steps",
+                precompute_ref_log_probs=False,
+                sync_ref_model=True,
             )
 
             dummy_dataset = self._init_dummy_dataset()
@@ -325,8 +327,6 @@ def test_tr_dpo_trainer(self):
                 tokenizer=self.tokenizer,
                 train_dataset=dummy_dataset,
                 eval_dataset=dummy_dataset,
-                precompute_ref_log_probs=False,
-                sync_ref_model=True,
             )

diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index 3b4f392f51..8538b103f7 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -65,6 +65,12 @@ class DPOConfig(TrainingArguments):
         If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
     force_use_ref_model (`bool`, defaults to `False`):
         In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
+    sync_ref_model ('bool', defaults to `False`):
+        The flag for syncing reference model during training from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
+    ref_model_mixup_alpha ('float', defaults to 1.0):
+        The alpha parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
+    ref_model_sync_steps ('int', defaults to 2):
+        The tau parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
     """
 
     beta: float = 0.1
@@ -87,3 +93,6 @@
     ref_adapter_name: Optional[str] = None
     reference_free: bool = False
     force_use_ref_model: bool = False
+    sync_ref_model: bool = False
+    ref_model_mixup_alpha: float = 1.0
+    ref_model_sync_steps: int = 2

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 2aa388552d..abbbc77b48 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -164,9 +164,6 @@ def __init__(
         ref_adapter_name: Optional[str] = None,
         reference_free: bool = False,
         force_use_ref_model: bool = False,
-        sync_ref_model: bool = False,
-        ref_model_mixup_alpha: float = 1.0,
-        ref_model_sync_steps: int = 2,
     ):
         if model_init_kwargs is not None:
             warnings.warn(
@@ -532,7 +529,7 @@ def make_inputs_require_grad(module, input, output):
             raise ValueError(
                 "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
             )
-            if sync_ref_model:
+            if args.sync_ref_model:
                 raise ValueError(
                     "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`."
                 )
@@ -542,19 +539,18 @@ def make_inputs_require_grad(module, input, output):
         else:
             if self.is_deepspeed_enabled:
                 self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
-        if sync_ref_model:
+        if args.bf16sync_ref_model:
             if precompute_ref_log_probs:
                 raise ValueError(
                     "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
                 )
 
-        if sync_ref_model:
             self.add_callback(
                 SyncRefModelCallback(
                     self.accelerator,
                     self.ref_model,
                     ref_model_mixup_alpha=args.ref_model_mixup_alpha,
                     ref_model_sync_steps=args.ref_model_sync_steps,
                 )
             )
         if self.loss_type == "bco_pair":
             self.running = RunningMoments(self.accelerator)

From f3b409ea5dec8dde74079309c187572b6ed43a14 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 13 May 2024 09:22:17 +0200
Subject: [PATCH 07/17] fix typo

---
 trl/trainer/dpo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index abbbc77b48..4d4782482b 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -539,7 +539,7 @@ def make_inputs_require_grad(module, input, output):
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
-        if args.bf16sync_ref_model:
+        if args.sync_ref_model:
             if precompute_ref_log_probs:
                 raise ValueError(
                     "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
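With the configuration moved into DPOConfig (patches 06-07, including the bf16sync_ref_model typo fix), the earlier setup becomes the following; this sketch reuses the placeholder model/tokenizer/dataset objects from the sketch after patch 03:

from trl import DPOConfig, DPOTrainer

training_args = DPOConfig(
    output_dir="out",
    max_steps=4,
    remove_unused_columns=False,
    precompute_ref_log_probs=False,  # must stay False when sync_ref_model=True
    sync_ref_model=True,
    ref_model_mixup_alpha=0.5,
    ref_model_sync_steps=2,
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,  # still required: sync_ref_model raises if ref_model is None
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)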
From d087bcaf70c458ad0096b9a151fd0ed9e8edf56c Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 13 May 2024 09:30:42 +0200
Subject: [PATCH 08/17] add docs

---
 docs/source/dpo_trainer.mdx | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx
index 6436cce315..d0de0e171a 100644
--- a/docs/source/dpo_trainer.mdx
+++ b/docs/source/dpo_trainer.mdx
@@ -113,6 +113,8 @@ The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier wh
 
 The [SPPO](https://arxiv.org/abs/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation using loss_type="sppo_hard" approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser.
 
+The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model` flag in the `DPOConfig`.
+
 ## Logging
 
 While training and evaluating we record the following reward metrics:

From fb01df7f3ad34ede53ea18077fb940b85b0063cc Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 13 May 2024 09:38:08 +0200
Subject: [PATCH 09/17] fix import

---
 trl/trainer/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index fa08199a9a..1b4881753f 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -20,7 +20,8 @@
 import numpy as np
 import pandas as pd
 import torch
-from accelerate import Accelerator, AcceleratorState, PartialState
+from accelerate import Accelerator
+from accelerate.state import AcceleratorState, PartialState
 from accelerate.utils import is_deepspeed_available
 from rich.console import Console, Group
 from rich.live import Live

From b9aabc2766b79730db6aea0480a5947ae229c4db Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 13 May 2024 09:48:09 +0200
Subject: [PATCH 10/17] use state.global_step

---
 trl/trainer/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 1b4881753f..edff625eb5 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -99,7 +99,7 @@ def sync_target_model(cls, model, target_model, alpha):
     def on_step_end(self, args, state, control, **kwargs):
         model: PreTrainedModel = kwargs["model"]
 
-        if self.ref_model is not None and self.global_step % self.ref_model_sync_steps == 0:
+        if self.ref_model is not None and state.global_step % self.ref_model_sync_steps == 0:
             if self.accelerator:
                 unwrapped_model = self.accelerator.unwrap_model(model)
                 self.sync_target_model(unwrapped_model, self.ref_model, self.ref_model_mixup_alpha)
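A small numeric illustration of what the two knobs in the docs above mean together: each sync keeps a (1 - alpha) share of the old reference, so after k syncs (that is, k * ref_model_sync_steps global steps) the original reference weights retain a (1 - alpha)^k share. The values below are hypothetical, not defaults:

alpha, sync_steps = 0.5, 2
for k in range(1, 5):
    residual = (1 - alpha) ** k  # share of the *original* reference left after k syncs
    print(f"{k} syncs = {k * sync_steps} steps: original ref share {residual:.4f}")
# 1 sync  = 2 steps: 0.5000
# 2 syncs = 4 steps: 0.2500  -> the reference trails the policy geometrically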
From 3b63bdd0bd3ac06ef34e2d16a3ebcd6a031d3b01 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 13 May 2024 10:19:04 +0200
Subject: [PATCH 11/17] fix order of arguments

---
 trl/trainer/dpo_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 4d4782482b..4233cc9974 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -547,8 +547,8 @@ def make_inputs_require_grad(module, input, output):
 
             self.add_callback(
                 SyncRefModelCallback(
-                    self.accelerator,
-                    self.ref_model,
+                    ref_model=self.ref_model,
+                    accelerator=self.accelerator,
                     ref_model_mixup_alpha=args.ref_model_mixup_alpha,
                     ref_model_sync_steps=args.ref_model_sync_steps,
                 )

From 247a88b4e2ff5c90501a44bc1c0f3feddfb038b0 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 13 May 2024 10:43:31 +0200
Subject: [PATCH 12/17] make sure plugins are not none

---
 trl/trainer/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index edff625eb5..64bcc3b3ed 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -89,7 +89,8 @@ def _sync_target_model(model, target_model, alpha):
     @classmethod
     def sync_target_model(cls, model, target_model, alpha):
-        if AcceleratorState().deepspeed_plugin.zero_stage == 3:
+        deepspeed_plugin = AcceleratorState().deepspeed_plugin
+        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
             with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
                 if deepspeed.comm.get_rank() == 0:
                     cls._sync_target_model(model, target_model, alpha)
         else:

From 99324090a04464d0879003cd3675d3386d5a28bd Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 14 May 2024 12:39:20 +0200
Subject: [PATCH 13/17] Update trl/trainer/utils.py

Co-authored-by: Benjamin Bossan

---
 trl/trainer/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 64bcc3b3ed..cdd09edf78 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -85,7 +85,7 @@ def __init__(
     @staticmethod
     def _sync_target_model(model, target_model, alpha):
         for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
-            target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)
+            target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)

From e230875130c65fbc901d13e388994e47a4a44b8c Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 14 May 2024 12:50:07 +0200
Subject: [PATCH 14/17] Update trl/trainer/utils.py

Co-authored-by: Benjamin Bossan

---
 trl/trainer/utils.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index cdd09edf78..a022a960ef 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -102,10 +102,8 @@ def on_step_end(self, args, state, control, **kwargs):
 
         if self.ref_model is not None and state.global_step % self.ref_model_sync_steps == 0:
             if self.accelerator:
-                unwrapped_model = self.accelerator.unwrap_model(model)
-                self.sync_target_model(unwrapped_model, self.ref_model, self.ref_model_mixup_alpha)
-            else:
-                self.sync_target_model(model, self.ref_model, self.ref_model_mixup_alpha)
+                model = self.accelerator.unwrap_model(model)
+            self.sync_target_model(model, self.ref_model, self.ref_model_mixup_alpha)
 
 
 class FixedKLController:
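Patch 13's switch from copy_ to mul_().add_() computes the same convex combination while avoiding the temporary tensors that alpha * copy_param.data and the subsequent addition would allocate. A quick equivalence check on toy tensors (illustrative only):

import torch

alpha = 0.6
policy_w = torch.randn(3, 3)
ref_old = torch.randn(3, 3)
ref_new = ref_old.clone()

ref_old.copy_(alpha * policy_w + (1.0 - alpha) * ref_old)  # former formulation: builds temporaries
ref_new.mul_(1.0 - alpha).add_(policy_w, alpha=alpha)      # current formulation: fully in place

assert torch.allclose(ref_old, ref_new)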
From b3335114200efbd606b9c7e6b5183e7eb4e23751 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 15 May 2024 13:08:34 +0200
Subject: [PATCH 15/17] checking that reference model weights have changed

---
 tests/test_dpo_trainer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index a87af54943..977d70343c 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -329,6 +329,7 @@ def test_tr_dpo_trainer(self):
                 eval_dataset=dummy_dataset,
             )
 
+            # params of the ref model as it's the same as the model
             previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
 
             trainer.train()
@@ -337,8 +338,8 @@ def test_tr_dpo_trainer(self):
             # check the params have changed
             for n, param in previous_trainable_params.items():
-                new_param = trainer.model.get_parameter(n)
-                # check the params have changed - ignore 0 biases
+                new_param = trainer.ref_model.get_parameter(n)
+                # check the ref model's params have changed - ignore 0 biases
                 if param.sum() != 0:
                     assert not torch.equal(param, new_param)

From 2f161b6a78e1f212650b1dddb63c669217315679 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 22 May 2024 12:58:51 +0200
Subject: [PATCH 16/17] sync_target_model as staticmethod

---
 tests/test_dpo_trainer.py  |  2 ++
 trl/trainer/dpo_config.py  |  4 ++--
 trl/trainer/dpo_trainer.py |  9 +--------
 trl/trainer/utils.py       | 16 ++++++----------
 4 files changed, 11 insertions(+), 20 deletions(-)

diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 58f004ea67..8738ef1283 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -320,6 +320,8 @@ def test_tr_dpo_trainer(self):
                 evaluation_strategy="steps",
                 precompute_ref_log_probs=False,
                 sync_ref_model=True,
+                ref_model_mixup_alpha=0.5,
+                ref_model_sync_steps=1,
             )
 
             dummy_dataset = self._init_dummy_dataset()

diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index 85dc630968..701e8dba70 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -94,5 +94,5 @@ class DPOConfig(TrainingArguments):
     reference_free: bool = False
     force_use_ref_model: bool = False
     sync_ref_model: bool = False
-    ref_model_mixup_alpha: float = 1.0
-    ref_model_sync_steps: int = 2
+    ref_model_mixup_alpha: float = 0.9
+    ref_model_sync_steps: int = 64

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index eb75cbd777..5e900acf6e 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -545,14 +545,7 @@ def make_inputs_require_grad(module, input, output):
                     "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
                 )
 
-            self.add_callback(
-                SyncRefModelCallback(
-                    ref_model=self.ref_model,
-                    accelerator=self.accelerator,
-                    ref_model_mixup_alpha=args.ref_model_mixup_alpha,
-                    ref_model_sync_steps=args.ref_model_sync_steps,
-                )
-            )
+            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
         if self.loss_type == "bco_pair":
             self.running = RunningMoments(self.accelerator)

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index a022a960ef..88d42a908e 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -74,11 +74,7 @@ def __init__(
         self,
         ref_model: Union[PreTrainedModel, torch.nn.Module],
         accelerator: Optional[Accelerator],
-        ref_model_mixup_alpha: float,
-        ref_model_sync_steps: int,
     ):
-        self.ref_model_mixup_alpha = ref_model_mixup_alpha
-        self.ref_model_sync_steps = ref_model_sync_steps
         self.accelerator = accelerator
         self.ref_model = ref_model
 
@@ -87,23 +83,23 @@ def _sync_target_model(model, target_model, alpha):
         for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
             target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)
 
-    @classmethod
-    def sync_target_model(cls, model, target_model, alpha):
+    @staticmethod
+    def sync_target_model(model, target_model, alpha):
         deepspeed_plugin = AcceleratorState().deepspeed_plugin
         if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
             with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
                 if deepspeed.comm.get_rank() == 0:
-                    cls._sync_target_model(model, target_model, alpha)
+                    SyncRefModelCallback._sync_target_model(model, target_model, alpha)
         else:
-            cls._sync_target_model(model, target_model, alpha)
+            SyncRefModelCallback._sync_target_model(model, target_model, alpha)
 
     def on_step_end(self, args, state, control, **kwargs):
         model: PreTrainedModel = kwargs["model"]
 
-        if self.ref_model is not None and state.global_step % self.ref_model_sync_steps == 0:
+        if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0:
             if self.accelerator:
                 model = self.accelerator.unwrap_model(model)
-            self.sync_target_model(model, self.ref_model, self.ref_model_mixup_alpha)
+            self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)
 
 
 class FixedKLController:

From 9f50f7d48b927471c85112312d2ce41efba4141f Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 22 May 2024 15:34:44 +0200
Subject: [PATCH 17/17] set reference model

---
 tests/test_dpo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 8738ef1283..227e9b838c 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -328,7 +328,7 @@ def test_tr_dpo_trainer(self):
 
             trainer = DPOTrainer(
                 model=self.model,
-                ref_model=None,
+                ref_model=self.model,
                 beta=0.1,
                 args=training_args,
                 tokenizer=self.tokenizer,
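After the full series, a minimal end-to-end sketch of TR-DPO with the final API and the final defaults from patch 16 (ref_model_mixup_alpha=0.9, ref_model_sync_steps=64); the model name and the tiny dataset are placeholders:

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer, create_reference_model

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
ref_model = create_reference_model(model)  # frozen copy, soft-updated by the callback

train_dataset = Dataset.from_dict(
    {"prompt": ["Hello"], "chosen": [" world!"], "rejected": [" nope."]}
)

args = DPOConfig(
    output_dir="tr_dpo_out",
    per_device_train_batch_size=1,
    max_steps=128,
    remove_unused_columns=False,
    sync_ref_model=True,
    ref_model_mixup_alpha=0.9,  # default after patch 16
    ref_model_sync_steps=64,    # default after patch 16
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()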