diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index f57bdd27fee6d9..08ada424ea77b4 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1256,6 +1256,10 @@ def make_inputs_require_grads(module, input, output): make_inputs_require_grads ) + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + def get_input_embeddings(self): return self.text_model.get_input_embeddings() @@ -1466,6 +1470,10 @@ def make_inputs_require_grads(module, input, output): make_inputs_require_grads ) + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + def get_input_embeddings(self): return self.model.text_model.get_input_embeddings() diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 7a8aab83272bc8..e13cf8dd56c3ef 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -239,6 +239,12 @@ def test_torchscript_output_hidden_state(self): def test_torchscript_simple(self): pass + @unittest.skip( + reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527" + ) + def test_peft_gradient_checkpointing_enable_disable(self): + pass + @require_torch class SpeechT5ForSpeechToTextTester: @@ -1743,6 +1749,12 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass + @unittest.skip( + reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527" + ) + def test_peft_gradient_checkpointing_enable_disable(self): + pass + # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6a0afc60f85567..a65dbad1828bb3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -403,6 +403,44 @@ def test_gradient_checkpointing_enable_disable(self): m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" ) + def test_peft_gradient_checkpointing_enable_disable(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class.supports_gradient_checkpointing: + continue + + # at init model should have gradient checkpointing disabled + model = model_class(config) + self.assertFalse(model.is_gradient_checkpointing) + + # check enable works + model._hf_peft_config_loaded = True + try: + model.gradient_checkpointing_enable() + except NotImplementedError: + continue + + self.assertTrue(model.is_gradient_checkpointing) + + # Loop over all modules and check that relevant modules have gradient_checkpointing set to True + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertTrue( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True" + ) + + # check disable works + model.gradient_checkpointing_disable() + self.assertFalse(model.is_gradient_checkpointing) + + # Loop over all modules and check that relevant modules have gradient_checkpointing set to False + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertFalse( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" + ) + @is_flaky(description="low likelihood of failure, reason not yet discovered") def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()