diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 384876fb6de239..9a6c29c27bdf63 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1736,6 +1736,41 @@ def save_pretrained(
             for ignore_key in self._keys_to_ignore_on_save:
                 if ignore_key in state_dict.keys():
                     del state_dict[ignore_key]
+        if safe_serialization:
+            # Safetensors does not allow tensor aliasing.
+            # We're going to remove aliases before saving
+            ptrs = collections.defaultdict(list)
+            for name, tensor in state_dict.items():
+                ptrs[tensor.data_ptr()].append(name)
+
+            # These are all the pointers of shared tensors.
+            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+            warn_names = set()
+            for names in shared_ptrs.values():
+                # Removing the keys which are declared as known duplicates on
+                # load. This allows to make sure the name which is kept is consistent.
+                if self._keys_to_ignore_on_load_missing is not None:
+                    for name in names:
+                        matches_pattern = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
+                        if matches_pattern and name in state_dict:
+                            del state_dict[name]
+
+                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+                # If the link between tensors was done at runtime then `from_pretrained` will not get
+                # the key back leading to random tensor. A proper warning will be shown
+                # during reload (if applicable), but since the file is not necessarily compatible with
+                # the config, better show a proper warning.
+                found = 0
+                for name in names:
+                    if name in state_dict:
+                        found += 1
+                        if found > 1:
+                            del state_dict[name]
+                            warn_names.add(name)
+            if len(warn_names) > 0:
+                logger.warning_once(
+                    f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
+                )
 
         # Shard the model if it is too big.
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
@@ -2813,6 +2848,11 @@ def _fix_key(key):
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
 
+        # Some tensors maybe have been already filled by another key (tied weights).
+        existing_ptrs = {model_state_dict[k].data_ptr() for k in loaded_keys if k in model_state_dict}
+        missing_keys = [
+            k for k in missing_keys if k in model_state_dict and model_state_dict[k].data_ptr() not in existing_ptrs
+        ]
         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
         if cls._keys_to_ignore_on_load_missing is not None:
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index 46f0c9b11ce498..9b00274a0b14ca 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -1238,8 +1238,28 @@ def __init__(self, config: Blip2Config):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def set_output_embeddings(self, new_embeddings):
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def get_output_embeddings(self) -> nn.Module:
+        return self.language_model.get_output_embeddings()
+
+    def get_encoder(self):
+        return self.language_model.get_encoder()
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
+    def _tie_weights(self):
+        if not self.config.use_decoder_only_language_model:
+            self.language_model.encoder.embed_tokens = self.language_model.shared
+            self.language_model.decoder.embed_tokens = self.language_model.shared
 
     @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py
index 6fd2e8fdd18412..eabc6e5e690d34 100644
--- a/src/transformers/models/deta/modeling_deta.py
+++ b/src/transformers/models/deta/modeling_deta.py
@@ -244,7 +244,7 @@ class DetaObjectDetectionOutput(ModelOutput):
 
 
 def _get_clones(module, N):
-    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+    return nn.ModuleList([module for i in range(N)])
 
 
 def inverse_sigmoid(x, eps=1e-5):
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 0c1a189b7a3ee6..c3f5285441bc60 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -609,8 +609,6 @@ def custom_forward(*inputs):
 
 
 class LlamaForCausalLM(LlamaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
     def __init__(self, config):
         super().__init__(config)
         self.model = LlamaModel(config)
diff --git a/src/transformers/models/pix2struct/configuration_pix2struct.py b/src/transformers/models/pix2struct/configuration_pix2struct.py
index 8642602cf97db5..dead3d8a042413 100644
--- a/src/transformers/models/pix2struct/configuration_pix2struct.py
+++ b/src/transformers/models/pix2struct/configuration_pix2struct.py
@@ -357,9 +357,10 @@ def __init__(
         initializer_factor=1.0,
         initializer_range=0.02,
         is_vqa=False,
+        tie_word_embeddings=False,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
         if text_config is None:
             text_config = {}
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index f71366d2183829..030555aece7365 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -27,6 +27,7 @@
 import unittest
 import unittest.mock as mock
 import warnings
+from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
 
@@ -1626,6 +1627,41 @@ def check_same_values(layer_1, layer_2):
             # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
             # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
 
+    @require_safetensors
+    def test_can_use_safetensors(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model_tied = model_class(config)
+            with tempfile.TemporaryDirectory() as d:
+                try:
+                    model_tied.save_pretrained(d, safe_serialization=True)
+                except Exception as e:
+                    raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}")
+
+                model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
+                # Checking the state dicts are correct
+                reloaded_state = model_reloaded.state_dict()
+                for k, v in model_tied.state_dict().items():
+                    self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
+                    torch.testing.assert_close(
+                        v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
+                    )
+
+                # Checking the tensor sharing are correct
+                ptrs = defaultdict(list)
+                for k, v in model_tied.state_dict().items():
+                    ptrs[v.data_ptr()].append(k)
+
+                shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
+
+                for _, shared_names in shared_ptrs.items():
+                    reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
+                    self.assertEqual(
+                        len(reloaded_ptrs),
+                        1,
+                        f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
+                    )
+
     def test_tied_model_weights_key_ignore(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
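Not part of the patch: a minimal, self-contained sketch of the constraint the `save_pretrained` change above works around. `safetensors.torch.save_file` rejects a state dict in which two keys alias the same storage (as tied weights do), so the aliases have to be dropped before saving and re-tied on load; the key names below are invented for illustration.

# Sketch only, assuming `torch` and `safetensors` are installed; key names are made up.
import os
import tempfile

import torch
from safetensors.torch import save_file

shared = torch.zeros(10, 4)
# Two keys backed by the same storage, as happens with tied embeddings / lm_head weights.
state_dict = {"embed_tokens.weight": shared, "lm_head.weight": shared}

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "model.safetensors")
    try:
        save_file(state_dict, path)
    except RuntimeError as err:
        # safetensors refuses aliased tensors outright.
        print(f"save rejected: {err}")

    # Keeping only one alias (what the patched save_pretrained does) lets the save succeed;
    # on reload, the model's tie_weights()/_tie_weights restores the sharing from that key.
    del state_dict["lm_head.weight"]
    save_file(state_dict, path)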