From fed5e541ac7755a11a06ec649ed9fe35b006fe85 Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Wed, 1 Nov 2023 13:32:29 +0100
Subject: [PATCH] Fix disk offload tests + weight sharing issues

---
 src/transformers/modeling_utils.py            |  4 ++-
 src/transformers/models/bart/modeling_bart.py |  5 +++
 .../modeling_bigbird_pegasus.py               |  5 +++
 .../models/longt5/modeling_longt5.py          | 14 ++++++++
 .../models/m2m_100/modeling_m2m_100.py        |  5 +++
 .../models/nllb_moe/modeling_nllb_moe.py      |  5 +++
 .../models/plbart/modeling_plbart.py          |  5 +++
 .../seamless_m4t/modeling_seamless_m4t.py     |  5 +++
 .../modeling_switch_transformers.py           | 14 ++++++++
 src/transformers/models/t5/modeling_t5.py     | 19 ++++++++++
 src/transformers/models/umt5/modeling_umt5.py | 23 ++++++++++++
 tests/models/vitdet/test_modeling_vitdet.py   |  6 +++-
 tests/models/whisper/test_modeling_whisper.py |  6 +++-
 tests/test_modeling_common.py                 | 36 +++++++++++++++++--
 14 files changed, 147 insertions(+), 5 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ccb9073aef1230..1a4769ae105816 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4520,7 +4520,9 @@ def expand_device_map(device_map, param_names):
     """
     new_device_map = {}
     for module, device in device_map.items():
-        new_device_map.update({p: device for p in param_names if p == module or p.startswith(f"{module}.")})
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
     return new_device_map
 
 
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index c271aabcb4d111..60ec557eba7fdc 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1125,6 +1125,11 @@ def __init__(self, config: BartConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_input_embeddings(self):
         return self.shared
 
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index cc69e31e5a7fdb..b9a84a869dac4f 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2312,6 +2312,11 @@ def set_input_embeddings(self, value):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py
index 91e584d80d3fb3..0ae7cedea00b91 100644
--- a/src/transformers/models/longt5/modeling_longt5.py
+++ b/src/transformers/models/longt5/modeling_longt5.py
@@ -1783,6 +1783,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
@@ -1937,6 +1942,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
@@ -2170,6 +2180,10 @@ def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index ed17796d27dc8e..4e5004fc98ffd6 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -1103,6 +1103,11 @@ def set_input_embeddings(self, value):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
index 1c08a70875e1d3..22708f1125224d 100644
--- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py
+++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -1471,6 +1471,11 @@ def set_input_embeddings(self, value):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 79b5c09cba2fef..4d8fe161f806e5 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -1084,6 +1084,11 @@ def set_input_embeddings(self, value):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
index 3fe519d2d2592c..62d1f3e21f9ad0 100755
--- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
+++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -4125,6 +4125,11 @@ def set_input_embeddings(self, value):
         self.text_decoder.embed_tokens = value
         self.shared = value
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+
     @add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING)
     def forward(
         self,
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index e00a0147e420f8..07c96a5aa82860 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -1329,6 +1329,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
@@ -1505,6 +1510,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
@@ -1807,6 +1817,10 @@ def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 3748e5af778f47..ff8e6609b94df3 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1414,6 +1414,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
@@ -1620,6 +1625,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
@@ -1920,6 +1930,10 @@ def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
@@ -2152,6 +2166,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index bfcbfb13eb4518..220aff273bc65a 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -973,6 +973,12 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5Model._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
     def get_encoder(self):
         return self.encoder
 
@@ -1142,6 +1148,12 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
@@ -1380,6 +1392,11 @@ def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder
     def get_encoder(self):
         return self.encoder
 
@@ -1615,6 +1632,12 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
     def get_encoder(self):
         return self.encoder
 
diff --git a/tests/models/vitdet/test_modeling_vitdet.py b/tests/models/vitdet/test_modeling_vitdet.py
index d6ffd03cbd7f79..361e563d58d448 100644
--- a/tests/models/vitdet/test_modeling_vitdet.py
+++ b/tests/models/vitdet/test_modeling_vitdet.py
@@ -182,7 +182,11 @@ def test_cpu_offload(self):
 
     # TODO: Fix me (once this model gets more usage)
     @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
+        super().test_disk_offload()
+
+    @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
+    def test_disk_offload_safetensors(self):
         super().test_disk_offload()
 
     # TODO: Fix me (once this model gets more usage)
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 9bb835360887c3..7c012b977e107a 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -1787,7 +1787,11 @@ def test_cpu_offload(self):
         pass
 
     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
+        pass
+
+    @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
+    def test_disk_offload_safetensors(self):
         pass
 
     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 3c481007472882..595c72cda6fd2f 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2578,7 +2578,7 @@ def check_device_map_is_respected(self, model, device_map):
    @require_accelerate
    @mark.accelerate_tests
    @require_torch_gpu
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
        for model_class in self.all_model_classes:
@@ -2593,7 +2593,7 @@ def test_disk_offload(self):
 
            model_size = compute_module_sizes(model)[""]
            with tempfile.TemporaryDirectory() as tmp_dir:
-                model.cpu().save_pretrained(tmp_dir)
+                model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
 
                with self.assertRaises(ValueError):
                    max_size = int(self.model_split_percents[0] * model_size)
@@ -2613,6 +2613,38 @@ def test_disk_offload(self):
 
            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    @require_accelerate
+    @mark.accelerate_tests
+    @require_torch_gpu
+    def test_disk_offload_safetensors(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            torch.manual_seed(0)
+            base_output = model(**inputs_dict_class)
+
+            model_size = compute_module_sizes(model)[""]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+
+                max_size = int(self.model_split_percents[1] * model_size)
+                max_memory = {0: max_size, "cpu": max_size}
+
+                # This doesn't error out as it's in safetensors and doesn't need an offload folder
+                new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                torch.manual_seed(0)
+                new_output = new_model(**inputs_dict_class)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
    @require_accelerate
    @mark.accelerate_tests
    @require_torch_gpu
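
Note on the `expand_device_map` change: the helper expands a module-level `device_map` into one entry per parameter name, and previously a whole-model map such as {"": "cpu"} expanded to nothing, because no parameter is literally named "" or prefixed with ".". A minimal, self-contained check of the patched behavior (the function body is taken from the diff above; the parameter names and driver code are illustrative):

    def expand_device_map(device_map, param_names):
        # Standalone copy of the patched helper from modeling_utils.py.
        new_device_map = {}
        for module, device in device_map.items():
            # `module == ""` is the fix: the empty string names the whole
            # model and must therefore match every parameter.
            new_device_map.update(
                {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
            )
        return new_device_map


    params = ["shared.weight", "encoder.embed_tokens.weight", "lm_head.weight"]

    print(expand_device_map({"": "cpu"}, params))
    # Patched: every parameter maps to "cpu". Unpatched: an empty dict.

    print(expand_device_map({"encoder": 0}, params))
    # {'encoder.embed_tokens.weight': 0} in both versions.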
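The repeated `_tie_weights` overrides all follow one pattern: when `config.tie_word_embeddings` is set, the encoder (and decoder, where present) token embeddings are tied back to `self.shared`, so the model exposes one weight tensor rather than two or three independent copies, and the sharing survives reloading and dispatching the checkpoint across devices. A minimal sketch of the tying idea, using direct parameter assignment in place of the library's `_tie_or_clone_weights` helper; the `TinySeq2Seq` class is illustrative, not from the patch:

    import torch.nn as nn


    class TinySeq2Seq(nn.Module):
        # Illustrative stand-in for the seq2seq models patched above.
        def __init__(self, vocab_size=100, d_model=8, tie_word_embeddings=True):
            super().__init__()
            self.tie_word_embeddings = tie_word_embeddings
            self.shared = nn.Embedding(vocab_size, d_model)
            self.encoder_embed_tokens = nn.Embedding(vocab_size, d_model)
            self.decoder_embed_tokens = nn.Embedding(vocab_size, d_model)
            self._tie_weights()

        def _tie_weights(self):
            # Same pattern as the patch: when tying is enabled, the encoder
            # and decoder embeddings reuse the shared weight tensor, so a
            # checkpoint (or an offload pass) materializes only one copy.
            if self.tie_word_embeddings:
                self.encoder_embed_tokens.weight = self.shared.weight
                self.decoder_embed_tokens.weight = self.shared.weight


    model = TinySeq2Seq()
    assert model.encoder_embed_tokens.weight is model.shared.weight
    assert model.decoder_embed_tokens.weight is model.shared.weight
    # nn.Module deduplicates shared parameters, so only one tensor is reported.
    assert len(list(model.parameters())) == 1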
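Finally, the split of `test_disk_offload` into `_bin` and `_safetensors` variants encodes the user-facing contract: a safetensors checkpoint can be disk-offloaded straight from its save directory, while a legacy `.bin` checkpoint raises a `ValueError` unless an `offload_folder` is provided. A hedged usage sketch of both paths (the model id, the memory limits, and the presence of CUDA device 0 are illustrative assumptions, mirroring the `@require_torch_gpu` tests):

    import tempfile

    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Safetensors (the default format): weights can be read back lazily
        # from the checkpoint itself, so no offload folder is needed.
        model.save_pretrained(tmp_dir)
        offloaded = AutoModelForSeq2SeqLM.from_pretrained(
            tmp_dir,
            device_map="auto",
            # Tight limits force part of the model onto disk.
            max_memory={0: "30MiB", "cpu": "30MiB"},
        )

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Legacy .bin serialization: disk offload needs an explicit folder,
        # otherwise from_pretrained raises a ValueError.
        model.save_pretrained(tmp_dir, safe_serialization=False)
        offloaded = AutoModelForSeq2SeqLM.from_pretrained(
            tmp_dir,
            device_map="auto",
            max_memory={0: "30MiB", "cpu": "30MiB"},
            offload_folder=f"{tmp_dir}/offload",
        )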