From 5333f46aa5c3d0b05111b675079a9950dacd5078 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 17 Jul 2024 10:06:15 +0200 Subject: [PATCH 1/6] fix neftune_noise_alpha --- trl/trainer/sft_trainer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index f59c0e787a..7816c8497d 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -304,15 +304,12 @@ def make_inputs_require_grad(module, input, output): args.dataset_batch_size = dataset_batch_size self.dataset_batch_size = args.dataset_batch_size - self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") - if neftune_noise_alpha is not None and self._trainer_supports_neftune: + if neftune_noise_alpha is not None: args.neftune_noise_alpha = neftune_noise_alpha warnings.warn( "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`." ) - # self.neftune_noise_alpha is done at Trainer level - elif not self._trainer_supports_neftune: - self.neftune_noise_alpha = neftune_noise_alpha + self.neftune_noise_alpha = args.neftune_noise_alpha if dataset_text_field is not None: warnings.warn( @@ -445,14 +442,14 @@ def make_inputs_require_grad(module, input, output): @wraps(Trainer.train) def train(self, *args, **kwargs): # Activate neftune right before training. - if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + if self.neftune_noise_alpha is not None: self.model = self._trl_activate_neftune(self.model) output = super().train(*args, **kwargs) # After training we make sure to retrieve back the original forward pass method # for the embedding layer by removing the forward post hook. - if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + if self.neftune_noise_alpha is not None: unwrapped_model = unwrap_model(self.model) if is_peft_available() and isinstance(unwrapped_model, PeftModel): embeddings = unwrapped_model.base_model.model.get_input_embeddings() From 09549fc57e5dd9675cfeb6aec2164e27e0c7bb61 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 17 Jul 2024 11:10:26 +0200 Subject: [PATCH 2/6] del neftune_noise_alpha first --- trl/trainer/sft_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 7816c8497d..8659880838 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -456,8 +456,8 @@ def train(self, *args, **kwargs): else: embeddings = unwrapped_model.get_input_embeddings() - self.neftune_hook_handle.remove() del embeddings.neftune_noise_alpha + self.neftune_hook_handle.remove() return output From 1ade054e04e3eb78e54ac66b9b98e38684ef3520 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 18 Jul 2024 00:15:01 +0200 Subject: [PATCH 3/6] check len after removing handle --- tests/test_sft_trainer.py | 4 ++-- trl/trainer/sft_trainer.py | 10 +--------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 8775eb0b90..642931223a 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -859,12 +859,12 @@ def test_sft_trainer_with_model_neftune(self): assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0 trainer.neftune_hook_handle.remove() + assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0 trainer.train() # Make sure forward pass works fine _ = trainer.model(torch.LongTensor([[1, 
0, 1]]).to(device)) - assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0 @require_peft def test_peft_sft_trainer_str(self): @@ -1018,6 +1018,7 @@ def test_peft_sft_trainer_neftune(self): assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0 trainer.neftune_hook_handle.remove() + assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0 trainer.train() @@ -1030,7 +1031,6 @@ def test_peft_sft_trainer_neftune(self): # Make sure forward pass works fine to check if embeddings forward is not broken. _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device)) - assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0 @require_peft def test_peft_sft_trainer_tag(self): diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 8659880838..9d23d46f5a 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -450,13 +450,6 @@ def train(self, *args, **kwargs): # After training we make sure to retrieve back the original forward pass method # for the embedding layer by removing the forward post hook. if self.neftune_noise_alpha is not None: - unwrapped_model = unwrap_model(self.model) - if is_peft_available() and isinstance(unwrapped_model, PeftModel): - embeddings = unwrapped_model.base_model.model.get_input_embeddings() - else: - embeddings = unwrapped_model.get_input_embeddings() - - del embeddings.neftune_noise_alpha self.neftune_hook_handle.remove() return output @@ -654,6 +647,5 @@ def _trl_activate_neftune(self, model): embeddings = unwrapped_model.get_input_embeddings() embeddings.neftune_noise_alpha = self.neftune_noise_alpha - hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) - self.neftune_hook_handle = hook_handle + self.neftune_hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) return model From 1420dcadda94d94dee04e9e9d9960b9ff03e4bbb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 17 Sep 2024 14:15:55 +0200 Subject: [PATCH 4/6] make sure we do not load twice --- tests/test_sft_trainer.py | 4 ++-- trl/trainer/sft_trainer.py | 32 +++++++++++++++++++++++++------- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 57dcdd8b06..79315e64f2 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -843,12 +843,12 @@ def test_sft_trainer_with_model_neftune(self): assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0 trainer.neftune_hook_handle.remove() - assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0 trainer.train() # Make sure forward pass works fine _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device)) + assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0 @require_peft def test_peft_sft_trainer_str(self): @@ -1009,7 +1009,6 @@ def test_peft_sft_trainer_neftune(self): assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0 trainer.neftune_hook_handle.remove() - assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0 trainer.train() @@ -1022,6 +1021,7 @@ def test_peft_sft_trainer_neftune(self): # Make sure forward pass works fine to check if embeddings forward is not broken. 
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device)) + assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0 @require_peft def test_peft_sft_trainer_tag(self): diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 8cbc552e78..cc68b8d5a1 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -307,12 +307,15 @@ def make_inputs_require_grad(module, input, output): args.dataset_batch_size = dataset_batch_size self.dataset_batch_size = args.dataset_batch_size - if neftune_noise_alpha is not None: + self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") + if neftune_noise_alpha is not None and self._trainer_supports_neftune: args.neftune_noise_alpha = neftune_noise_alpha warnings.warn( "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`." ) - self.neftune_noise_alpha = args.neftune_noise_alpha + # self.neftune_noise_alpha is done at Trainer level + elif not self._trainer_supports_neftune: + self.neftune_noise_alpha = neftune_noise_alpha if dataset_text_field is not None: warnings.warn( @@ -424,16 +427,27 @@ def make_inputs_require_grad(module, input, output): @wraps(Trainer.train) def train(self, *args, **kwargs): - # Activate neftune right before training. + # Activate neftune right before training if it's not already activated if self.neftune_noise_alpha is not None: - self.model = self._trl_activate_neftune(self.model) + if not self._trainer_supports_neftune or not hasattr( + self.model.get_input_embeddings(), "neftune_noise_alpha" + ): + self.model = self._trl_activate_neftune(self.model) output = super().train(*args, **kwargs) # After training we make sure to retrieve back the original forward pass method # for the embedding layer by removing the forward post hook. if self.neftune_noise_alpha is not None: - self.neftune_hook_handle.remove() + if not self._trainer_supports_neftune or hasattr(self.model.get_input_embeddings(), "neftune_noise_alpha"): + unwrapped_model = unwrap_model(self.model) + if is_peft_available() and isinstance(unwrapped_model, PeftModel): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + self.neftune_hook_handle.remove() + del embeddings.neftune_noise_alpha return output @@ -641,6 +655,10 @@ def _trl_activate_neftune(self, model): else: embeddings = unwrapped_model.get_input_embeddings() - embeddings.neftune_noise_alpha = self.neftune_noise_alpha - self.neftune_hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) + if not hasattr(embeddings, "neftune_noise_alpha"): + embeddings.neftune_noise_alpha = self.neftune_noise_alpha + hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) + self.neftune_hook_handle = hook_handle + else: + warnings.warn("NEFTune appears to be already activated. 
Skipping activation.") return model From 6bfb63d272dec28ef65259129dbcf3a029b339a4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 19 Sep 2024 09:41:57 +0200 Subject: [PATCH 5/6] Update trl/trainer/sft_trainer.py Co-authored-by: lewtun --- trl/trainer/sft_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index cc68b8d5a1..8d25f12a7d 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -660,5 +660,5 @@ def _trl_activate_neftune(self, model): hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) self.neftune_hook_handle = hook_handle else: - warnings.warn("NEFTune appears to be already activated. Skipping activation.") + warnings.warn("NEFTune is already activated. Skipping activation.") return model From 0e5a1affabb24c1dd9cbcbb578cf5857fbc37788 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 19 Sep 2024 11:27:07 +0200 Subject: [PATCH 6/6] remove neftune from SFTTrainer as the superclass has it now --- tests/test_sft_trainer.py | 4 +-- trl/trainer/sft_config.py | 5 --- trl/trainer/sft_trainer.py | 63 ++++---------------------------------- trl/trainer/utils.py | 29 ------------------ 4 files changed, 8 insertions(+), 93 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 79315e64f2..a8b7719db3 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -828,7 +828,7 @@ def test_sft_trainer_with_model_neftune(self): eval_dataset=self.eval_dataset, ) - trainer.model = trainer._trl_activate_neftune(trainer.model) + trainer.model = trainer._activate_neftune(trainer.model) device = trainer.model.get_input_embeddings().weight.device trainer.model.train() @@ -992,7 +992,7 @@ def test_peft_sft_trainer_neftune(self): peft_config=peft_config, ) - trainer.model = trainer._trl_activate_neftune(trainer.model) + trainer.model = trainer._activate_neftune(trainer.model) assert isinstance(trainer.model, PeftModel) diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py index 9b076d3c40..55b3ba9adf 100644 --- a/trl/trainer/sft_config.py +++ b/trl/trainer/sft_config.py @@ -40,10 +40,6 @@ class SFTConfig(TrainingArguments): dataset_batch_size (`Union[int, None]`, *optional*, defaults to `1000`): Number of examples to tokenize per batch. If `dataset_batch_size <= 0` or `dataset_batch_size is None`, tokenizes the full dataset as a single batch. - neftune_noise_alpha (`Optional[float]`, *optional*, defaults to `None`): - Scale of the noise for NEFTune embeddings. The [NEFTune paper](https://huggingface.co/papers/2310.05914) - suggests using values between `5` and `15`. If set to `None`, NEFTune is not activated. Activating NEFTune - can significantly improve model performance for instruction fine-tuning. model_init_kwargs (`Optional[Dict[str, Any]]`, *optional*, defaults to `None`): Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a string. 
@@ -65,7 +61,6 @@ class SFTConfig(TrainingArguments): max_seq_length: Optional[int] = None dataset_num_proc: Optional[int] = None dataset_batch_size: int = 1000 - neftune_noise_alpha: Optional[float] = None model_init_kwargs: Optional[Dict[str, Any]] = None dataset_kwargs: Optional[Dict[str, Any]] = None eval_packing: Optional[bool] = None diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 8d25f12a7d..c645734781 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -34,7 +34,6 @@ PreTrainedTokenizerBase, Trainer, ) -from transformers.modeling_utils import unwrap_model from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available @@ -45,7 +44,6 @@ from .utils import ( ConstantLengthDataset, DataCollatorForCompletionOnlyLM, - neftune_post_forward_hook, peft_module_casting_to_bf16, trl_sanitze_kwargs_for_tagging, ) @@ -154,6 +152,12 @@ def __init__( args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")}) args = SFTConfig(**args_as_dict) + if neftune_noise_alpha is not None: + warnings.warn( + "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`." + ) + args.neftune_noise_alpha = neftune_noise_alpha + if model_init_kwargs is not None: warnings.warn( "You passed `model_init_kwargs` to the SFTTrainer, the value you passed will override the one in the `SFTConfig`." @@ -307,16 +311,6 @@ def make_inputs_require_grad(module, input, output): args.dataset_batch_size = dataset_batch_size self.dataset_batch_size = args.dataset_batch_size - self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") - if neftune_noise_alpha is not None and self._trainer_supports_neftune: - args.neftune_noise_alpha = neftune_noise_alpha - warnings.warn( - "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`." - ) - # self.neftune_noise_alpha is done at Trainer level - elif not self._trainer_supports_neftune: - self.neftune_noise_alpha = neftune_noise_alpha - if dataset_text_field is not None: warnings.warn( "You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`." @@ -425,32 +419,6 @@ def make_inputs_require_grad(module, input, output): elif self.args.max_steps == -1 and args.packing: self.train_dataset.infinite = False - @wraps(Trainer.train) - def train(self, *args, **kwargs): - # Activate neftune right before training if it's not already activated - if self.neftune_noise_alpha is not None: - if not self._trainer_supports_neftune or not hasattr( - self.model.get_input_embeddings(), "neftune_noise_alpha" - ): - self.model = self._trl_activate_neftune(self.model) - - output = super().train(*args, **kwargs) - - # After training we make sure to retrieve back the original forward pass method - # for the embedding layer by removing the forward post hook. 
- if self.neftune_noise_alpha is not None: - if not self._trainer_supports_neftune or hasattr(self.model.get_input_embeddings(), "neftune_noise_alpha"): - unwrapped_model = unwrap_model(self.model) - if is_peft_available() and isinstance(unwrapped_model, PeftModel): - embeddings = unwrapped_model.base_model.model.get_input_embeddings() - else: - embeddings = unwrapped_model.get_input_embeddings() - - self.neftune_hook_handle.remove() - del embeddings.neftune_noise_alpha - - return output - @wraps(Trainer.push_to_hub) def push_to_hub( self, @@ -643,22 +611,3 @@ def data_generator(constant_length_iterator): raise ValueError( "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`." ) - - def _trl_activate_neftune(self, model): - r""" - Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://huggingface.co/papers/2310.05914 - Since in transformers Trainer we do have an `_activate_neftune` method, we need to rename this method to avoid conflicts. - """ - unwrapped_model = unwrap_model(model) - if is_peft_available() and isinstance(unwrapped_model, PeftModel): - embeddings = unwrapped_model.base_model.model.get_input_embeddings() - else: - embeddings = unwrapped_model.get_input_embeddings() - - if not hasattr(embeddings, "neftune_noise_alpha"): - embeddings.neftune_noise_alpha = self.neftune_noise_alpha - hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) - self.neftune_hook_handle = hook_handle - else: - warnings.warn("NEFTune is already activated. Skipping activation.") - return model diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index b6fe2cd426..2e2cca93d0 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -826,35 +826,6 @@ def get_stats(self): return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()} -def neftune_post_forward_hook(module, input, output): - """ - Implements the NEFTune forward pass for the model using forward hooks. Note this works only for - torch.nn.Embedding layers. This method is slightly adapted from the original source code - that can be found here: https://github.com/neelsjain/NEFTune - - Simply add it to your model as follows: - ```python - model = ... - model.embed_tokens.neftune_noise_alpha = 0.1 - model.embed_tokens.register_forward_hook(neftune_post_forward_hook) - ``` - - Args: - module (`torch.nn.Module`): - The embedding module where the hook is attached. Note that you need to set - `module.neftune_noise_alpha` to the desired noise alpha value. - input (`torch.Tensor`): - The input tensor to the model. - output (`torch.Tensor`): - The output tensor of the model (i.e. the embeddings). - """ - if module.training: - dims = torch.tensor(output.size(1) * output.size(2)) - mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) - output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) - return output - - def peft_module_casting_to_bf16(model): for name, module in model.named_modules(): if isinstance(module, torch.nn.LayerNorm) or "norm" in name:
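
For reference, below is a minimal, self-contained sketch of the NEFTune mechanism that the final patch deletes from `trl/trainer/utils.py` (the equivalent hook now lives upstream in `transformers`). The hook body mirrors the removed `neftune_post_forward_hook`; the embedding sizes, token ids, and alpha value are toy values chosen only for illustration.

```python
import torch
import torch.nn as nn


def neftune_post_forward_hook(module, input, output):
    # During training, add uniform noise with magnitude alpha / sqrt(seq_len * hidden_dim)
    # to the embedding output; in eval mode the output is returned unchanged.
    if module.training:
        dims = torch.tensor(output.size(1) * output.size(2))
        mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output


embeddings = nn.Embedding(num_embeddings=100, embedding_dim=16)
embeddings.neftune_noise_alpha = 5.0  # the NEFTune paper suggests values between 5 and 15
handle = embeddings.register_forward_hook(neftune_post_forward_hook)

tokens = torch.randint(0, 100, (1, 8))
embeddings.train()
noisy = embeddings(tokens)  # hook fires and perturbs the embeddings
embeddings.eval()
clean = embeddings(tokens)  # hook is a no-op outside training

# Mirrors the cleanup SFTTrainer.train() performed before patch 6 removed it:
handle.remove()
del embeddings.neftune_noise_alpha
```

The tests in `tests/test_sft_trainer.py` exercise the same lifecycle by counting `_forward_hooks` on the input embeddings: greater than zero after activation, zero after the handle is removed.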
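After the final patch, `neftune_noise_alpha` is no longer declared on `SFTConfig`; it is inherited from `transformers.TrainingArguments`, and the parent `Trainer` attaches and removes the forward hook around training. A usage sketch under those assumptions (a transformers version whose `TrainingArguments` defines `neftune_noise_alpha`; the model checkpoint, dataset, and output directory are placeholders):

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

train_dataset = load_dataset("stanfordnlp/imdb", split="train")  # placeholder dataset

config = SFTConfig(
    output_dir="./sft-neftune",   # placeholder
    dataset_text_field="text",
    max_seq_length=512,
    neftune_noise_alpha=5.0,      # now inherited from TrainingArguments
)

trainer = SFTTrainer(
    model="facebook/opt-350m",    # placeholder checkpoint
    args=config,
    train_dataset=train_dataset,
)
trainer.train()  # the parent Trainer manages the NEFTune hook for the whole run
```

Passing `neftune_noise_alpha` directly to `SFTTrainer` still works, but as the `__init__` hunk in patch 6 shows, it only overrides `args.neftune_noise_alpha` and emits a warning; the hook activation, removal, and attribute cleanup that patches 1-4 kept re-patching now have a single owner in the superclass.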