From e6b853f936814b3bb2dd3cb5daaa37a02acb7988 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 29 Aug 2024 15:02:27 -0400
Subject: [PATCH 1/4] idefics2: enable_input_require_grads is not aligned with
 disable_input_require_grads, making gradient checkpointing disable fail for
 peft+idefics2

Signed-off-by: Wang, Yi
---
 .../models/idefics2/modeling_idefics2.py |  8 ++++
 tests/test_modeling_common.py            | 38 ++++++++++---------
 2 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index 6acabad0635b3f..77f5c7755790b8 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -1507,6 +1507,10 @@ def make_inputs_require_grads(module, input, output):
             make_inputs_require_grads
         )
 
+    def disable_input_require_grads(self):
+        self._text_require_grads_hook.remove()
+        self._vision_require_grads_hook.remove()
+
     def get_input_embeddings(self):
         return self.text_model.get_input_embeddings()
 
@@ -1716,6 +1720,10 @@ def make_inputs_require_grads(module, input, output):
             make_inputs_require_grads
         )
 
+    def disable_input_require_grads(self):
+        self._text_require_grads_hook.remove()
+        self._vision_require_grads_hook.remove()
+
     def get_input_embeddings(self):
         return self.model.text_model.get_input_embeddings()
 
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 30010cde9116dc..9d4dc9f3bb5c9e 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -367,26 +367,28 @@ def test_gradient_checkpointing_enable_disable(self):
             self.assertFalse(model.is_gradient_checkpointing)
 
             # check enable works
-            model.gradient_checkpointing_enable()
-            self.assertTrue(model.is_gradient_checkpointing)
-
-            # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
-            for n, m in model.named_modules():
-                if hasattr(m, "gradient_checkpointing"):
-                    self.assertTrue(
-                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
-                    )
+            for _hf_peft_config_loaded in [True, False]:
+                model._hf_peft_config_loaded = _hf_peft_config_loaded
+                model.gradient_checkpointing_enable()
+                self.assertTrue(model.is_gradient_checkpointing)
+
+                # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
+                for n, m in model.named_modules():
+                    if hasattr(m, "gradient_checkpointing"):
+                        self.assertTrue(
+                            m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
+                        )
 
-            # check disable works
-            model.gradient_checkpointing_disable()
-            self.assertFalse(model.is_gradient_checkpointing)
+                # check disable works
+                model.gradient_checkpointing_disable()
+                self.assertFalse(model.is_gradient_checkpointing)
 
-            # Loop over all modules and check that relevant modules have gradient_checkpointing set to False
-            for n, m in model.named_modules():
-                if hasattr(m, "gradient_checkpointing"):
-                    self.assertFalse(
-                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
-                    )
+                # Loop over all modules and check that relevant modules have gradient_checkpointing set to False
+                for n, m in model.named_modules():
+                    if hasattr(m, "gradient_checkpointing"):
+                        self.assertFalse(
+                            m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
+                        )
 
     @is_flaky(description="low likelihood of failure, reason not yet discovered")
     def test_save_load_fast_init_from_base(self):
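For context, a minimal self-contained sketch (not part of the patch series): the base class pairs enable_input_require_grads/disable_input_require_grads around a single hook attribute, while the idefics2 enable path registers two hooks under different attribute names, so the inherited disable raises AttributeError unless a matching override, as added above, removes both. The tiny module below only mirrors that shape; it is not the real idefics2 model.

import torch.nn as nn


class TinyIdefics2Like(nn.Module):
    # Toy stand-in mirroring the hook bookkeeping shown in the diff above.
    def __init__(self):
        super().__init__()
        self.text_embed = nn.Embedding(10, 4)  # stands in for the text embeddings
        self.vision_patch = nn.Linear(4, 4)    # stands in for the lowest vision module

    def enable_input_require_grads(self):
        def make_inputs_require_grads(module, input, output):
            output.requires_grad_(True)

        # Two hooks under two attribute names, as in Idefics2Model.enable_input_require_grads.
        self._text_require_grads_hook = self.text_embed.register_forward_hook(make_inputs_require_grads)
        self._vision_require_grads_hook = self.vision_patch.register_forward_hook(make_inputs_require_grads)

    def disable_input_require_grads(self):
        # Without this override, the inherited method would look for `_require_grads_hook`,
        # which was never set here, and fail with AttributeError.
        self._text_require_grads_hook.remove()
        self._vision_require_grads_hook.remove()


model = TinyIdefics2Like()
model.enable_input_require_grads()
model.disable_input_require_grads()  # works only because disable mirrors enable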
From d5a36f3e4836a6103a7127c6df8a5871ed6ddb5c Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 5 Sep 2024 07:25:16 -0400
Subject: [PATCH 2/4] split test case

Signed-off-by: Wang, Yi
---
 tests/test_modeling_common.py | 72 +++++++++++++++++++++++++----------
 1 file changed, 52 insertions(+), 20 deletions(-)

diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 185280703e3a1d..4cdea68dc607a3 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -382,28 +382,60 @@ def test_gradient_checkpointing_enable_disable(self):
             self.assertFalse(model.is_gradient_checkpointing)
 
             # check enable works
-            for _hf_peft_config_loaded in [True, False]:
-                model._hf_peft_config_loaded = _hf_peft_config_loaded
-                model.gradient_checkpointing_enable()
-                self.assertTrue(model.is_gradient_checkpointing)
-
-                # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
-                for n, m in model.named_modules():
-                    if hasattr(m, "gradient_checkpointing"):
-                        self.assertTrue(
-                            m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
-                        )
+            model.gradient_checkpointing_enable()
+            self.assertTrue(model.is_gradient_checkpointing)
+
+            # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
+            for n, m in model.named_modules():
+                if hasattr(m, "gradient_checkpointing"):
+                    self.assertTrue(
+                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
+                    )
 
-                # check disable works
-                model.gradient_checkpointing_disable()
-                self.assertFalse(model.is_gradient_checkpointing)
+            # check disable works
+            model.gradient_checkpointing_disable()
+            self.assertFalse(model.is_gradient_checkpointing)
 
-                # Loop over all modules and check that relevant modules have gradient_checkpointing set to False
-                for n, m in model.named_modules():
-                    if hasattr(m, "gradient_checkpointing"):
-                        self.assertFalse(
-                            m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
-                        )
+            # Loop over all modules and check that relevant modules have gradient_checkpointing set to False
+            for n, m in model.named_modules():
+                if hasattr(m, "gradient_checkpointing"):
+                    self.assertFalse(
+                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
+                    )
+
+    def test_peft_gradient_checkpointing_enable_disable(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if not model_class.supports_gradient_checkpointing:
+                continue
+
+            # at init model should have gradient checkpointing disabled
+            model = model_class(config)
+            self.assertFalse(model.is_gradient_checkpointing)
+
+            # check enable works
+            model._hf_peft_config_loaded = True
+            model.gradient_checkpointing_enable()
+            self.assertTrue(model.is_gradient_checkpointing)
+
+            # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
+            for n, m in model.named_modules():
+                if hasattr(m, "gradient_checkpointing"):
+                    self.assertTrue(
+                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
+                    )
+
+            # check disable works
+            model.gradient_checkpointing_disable()
+            self.assertFalse(model.is_gradient_checkpointing)
+
+            # Loop over all modules and check that relevant modules have gradient_checkpointing set to False
+            for n, m in model.named_modules():
+                if hasattr(m, "gradient_checkpointing"):
+                    self.assertFalse(
+                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
+                    )
 
     @is_flaky(description="low likelihood of failure, reason not yet discovered")
     def test_save_load_fast_init_from_base(self):
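Why the dedicated test forces model._hf_peft_config_loaded = True: when a PEFT config is loaded, the gradient-checkpointing toggles also toggle the input-require-grads hooks, which is exactly the path patch 1 fixes for idefics2. The snippet below is a rough paraphrase of that control flow, not the exact library code; the helper name _set_gradient_checkpointing_flags is a placeholder for the real bookkeeping in modeling_utils.py.

def gradient_checkpointing_enable(model):
    model._set_gradient_checkpointing_flags(True)   # placeholder for the real enable logic
    if getattr(model, "_hf_peft_config_loaded", False):
        # PEFT training freezes the base weights, so inputs must carry gradients instead.
        model.enable_input_require_grads()


def gradient_checkpointing_disable(model):
    model._set_gradient_checkpointing_flags(False)  # placeholder for the real disable logic
    if getattr(model, "_hf_peft_config_loaded", False):
        model.disable_input_require_grads()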
From 0c9f9d29c1fdc029ffb59310619fd7cb5fb52cc3 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Wed, 11 Sep 2024 07:04:35 -0400
Subject: [PATCH 3/4] fix CI failure

Signed-off-by: Wang, Yi
---
 src/transformers/modeling_utils.py | 5 ++++-
 tests/test_modeling_common.py      | 6 +++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2faa60210ed4d7..dad0744d428409 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1765,7 +1765,10 @@ def enable_input_require_grads(self):
         def make_inputs_require_grads(module, input, output):
             output.requires_grad_(True)
 
-        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+        if self.get_input_embeddings() is not None:
+            self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+        else:
+            raise NotImplementedError
 
     def disable_input_require_grads(self):
         """
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 4cdea68dc607a3..a65dbad1828bb3 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -416,7 +416,11 @@ def test_peft_gradient_checkpointing_enable_disable(self):
 
             # check enable works
             model._hf_peft_config_loaded = True
-            model.gradient_checkpointing_enable()
+            try:
+                model.gradient_checkpointing_enable()
+            except NotImplementedError:
+                continue
+
             self.assertTrue(model.is_gradient_checkpointing)
 
             # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
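The guard above targets models whose get_input_embeddings() returns None (for example speech-only encoders), where hook registration would otherwise fail inside gradient_checkpointing_enable. Below is an illustrative sketch of that failure mode and of the catchable signal the patch introduces; the class name is made up.

class NoTokenEmbeddingsModel:
    def get_input_embeddings(self):
        return None  # e.g. an audio encoder without a token embedding table

    def enable_input_require_grads(self):
        embeddings = self.get_input_embeddings()
        if embeddings is None:
            # Explicit, catchable signal, as in the modeling_utils.py change above;
            # without it the next line would raise AttributeError on None.
            raise NotImplementedError
        self._require_grads_hook = embeddings.register_forward_hook(
            lambda module, input, output: output.requires_grad_(True)
        )


try:
    NoTokenEmbeddingsModel().enable_input_require_grads()
except NotImplementedError:
    pass  # the common test `continue`s here instead of failing CI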
From 064179bdaefb8cf5f8bf6bfed09a25a1fff3c679 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 17 Sep 2024 11:16:09 -0400
Subject: [PATCH 4/4] refine test

Signed-off-by: Wang, Yi
---
 src/transformers/modeling_utils.py              |  5 +----
 tests/models/speecht5/test_modeling_speecht5.py | 12 ++++++++++++
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index dad0744d428409..2faa60210ed4d7 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1765,10 +1765,7 @@ def enable_input_require_grads(self):
         def make_inputs_require_grads(module, input, output):
             output.requires_grad_(True)
 
-        if self.get_input_embeddings() is not None:
-            self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
-        else:
-            raise NotImplementedError
+        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
 
     def disable_input_require_grads(self):
         """
diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py
index 7a8aab83272bc8..e13cf8dd56c3ef 100644
--- a/tests/models/speecht5/test_modeling_speecht5.py
+++ b/tests/models/speecht5/test_modeling_speecht5.py
@@ -239,6 +239,12 @@ def test_torchscript_output_hidden_state(self):
     def test_torchscript_simple(self):
         pass
 
+    @unittest.skip(
+        reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
+    )
+    def test_peft_gradient_checkpointing_enable_disable(self):
+        pass
+
 
 @require_torch
 class SpeechT5ForSpeechToTextTester:
@@ -1743,6 +1749,12 @@ def test_training_gradient_checkpointing_use_reentrant(self):
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
+    @unittest.skip(
+        reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
+    )
+    def test_peft_gradient_checkpointing_enable_disable(self):
+        pass
+
     # overwrite from test_modeling_common
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
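With the final patch, the modeling_utils.py guard is reverted and SpeechT5 simply skips the new common test, so the shared code path stays unchanged for models that do expose input embeddings. For those models, the end-to-end flow this series fixes looks roughly like the sketch below; the checkpoint name and LoRA settings are illustrative, and running it needs peft plus a transformers build that contains these patches.

from peft import LoraConfig
from transformers import AutoModelForVision2Seq

model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b")
model.add_adapter(LoraConfig(target_modules=["q_proj", "v_proj"]))  # marks the model as PEFT-loaded

model.gradient_checkpointing_enable()   # also registers the text/vision input-require-grads hooks
model.gradient_checkpointing_disable()  # previously raised AttributeError for peft+idefics2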