Making sure we can use safetensors to serialize all the time. (huggingface#22437)

* Making sure we can use safetensors to serialize all the time.

* Expanding the tests for increased coverage.

* Update the test.

* Getting current state of affairs.

* Tentative fix.

* Fixing black version.

* Fixing the worst offenders.

* Try to modify less files.

* Fixing blip_2 (Weird solution right now).

* Fixing deta.

* Fix blip ?

* Missing extra newline.

* No deta modification.

* Adding some comments.

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Addressing comments.

* Addressing comments.

* creating warn_once.

* Warning_once !

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2 people authored and raghavanone committed Apr 5, 2023
1 parent b37668d commit 3c815cb
Showing 6 changed files with 101 additions and 6 deletions.
40 changes: 40 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1736,6 +1736,41 @@ def save_pretrained(
             for ignore_key in self._keys_to_ignore_on_save:
                 if ignore_key in state_dict.keys():
                     del state_dict[ignore_key]
+        if safe_serialization:
+            # Safetensors does not allow tensor aliasing.
+            # We're going to remove aliases before saving
+            ptrs = collections.defaultdict(list)
+            for name, tensor in state_dict.items():
+                ptrs[tensor.data_ptr()].append(name)
+
+            # These are all the pointers of shared tensors.
+            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+            warn_names = set()
+            for names in shared_ptrs.values():
+                # Remove the keys that are declared as known duplicates on
+                # load. This ensures that the name which is kept is consistent.
+                if self._keys_to_ignore_on_load_missing is not None:
+                    for name in names:
+                        matches_pattern = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
+                        if matches_pattern and name in state_dict:
+                            del state_dict[name]
+
+                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+                # If the link between tensors was done at runtime then `from_pretrained` will not get
+                # the key back, leading to a random tensor. A proper warning will be shown
+                # during reload (if applicable), but since the file is not necessarily compatible with
+                # the config, it is better to show a proper warning here.
+                found = 0
+                for name in names:
+                    if name in state_dict:
+                        found += 1
+                        if found > 1:
+                            del state_dict[name]
+                            warn_names.add(name)
+            if len(warn_names) > 0:
+                logger.warning_once(
+                    f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
+                )
 
         # Shard the model if it is too big.
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
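
For context on the aliasing problem: tied parameters share storage, so their `data_ptr()` values collide, and safetensors refuses to serialize two entries backed by the same memory. A minimal standalone sketch of the pointer grouping used above (the `TiedModel` class is made up for illustration):

import collections

from torch import nn


class TiedModel(nn.Module):
    """Toy model whose decoder weight is tied to the embedding matrix."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.decoder = nn.Linear(4, 10, bias=False)
        self.decoder.weight = self.embed.weight  # weight tying -> aliased storage


model = TiedModel()

# Group state_dict keys by underlying storage pointer, as save_pretrained now does.
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    ptrs[tensor.data_ptr()].append(name)

shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared_ptrs)  # {<pointer>: ['embed.weight', 'decoder.weight']}
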
@@ -2813,6 +2848,11 @@ def _fix_key(key):
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
 
+        # Some tensors may have already been filled by another key (tied weights).
+        existing_ptrs = {model_state_dict[k].data_ptr() for k in loaded_keys if k in model_state_dict}
+        missing_keys = [
+            k for k in missing_keys if k in model_state_dict and model_state_dict[k].data_ptr() not in existing_ptrs
+        ]
         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
         if cls._keys_to_ignore_on_load_missing is not None:
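
The loading-side counterpart, shown with plain tensors rather than the real `_load_pretrained_model` internals (a hypothetical sketch; the key names are made up): a key dropped at save time is not reported as missing when its storage is already covered by a loaded tied key.

import torch

# Two model keys aliasing the same storage (tied weights); the checkpoint kept only one of them.
shared = torch.zeros(10, 4)
model_state_dict = {"embed.weight": shared, "lm_head.weight": shared}
loaded_keys = ["embed.weight"]
expected_keys = ["embed.weight", "lm_head.weight"]

missing_keys = list(set(expected_keys) - set(loaded_keys))  # ['lm_head.weight']

# Pointers already covered by a loaded key.
existing_ptrs = {model_state_dict[k].data_ptr() for k in loaded_keys if k in model_state_dict}
missing_keys = [
    k for k in missing_keys if k in model_state_dict and model_state_dict[k].data_ptr() not in existing_ptrs
]
print(missing_keys)  # [] -- the tied key is no longer reported as missing
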
24 changes: 22 additions & 2 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1238,8 +1238,28 @@ def __init__(self, config: Blip2Config):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def set_output_embeddings(self, new_embeddings):
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def get_output_embeddings(self) -> nn.Module:
+        return self.language_model.get_output_embeddings()
+
+    def get_encoder(self):
+        return self.language_model.get_encoder()
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
+    def _tie_weights(self):
+        if not self.config.use_decoder_only_language_model:
+            self.language_model.encoder.embed_tokens = self.language_model.shared
+            self.language_model.decoder.embed_tokens = self.language_model.shared
 
     @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
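
A rough sketch of the pattern the BLIP-2 change follows: the wrapper delegates embedding access to its inner language model, and `_tie_weights` re-links the shared modules so the tie is recreated when a checkpoint that stored the shared tensor only once is reloaded (toy classes below, not the real BLIP-2 code):

from torch import nn


class ToySeq2SeqLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Embedding(100, 16)
        self.encoder = nn.Module()
        self.decoder = nn.Module()
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_input_embeddings(self):
        return self.shared


class ToyWrapper(nn.Module):
    """Delegates the embedding accessors to the inner language model."""

    def __init__(self):
        super().__init__()
        self.language_model = ToySeq2SeqLM()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def _tie_weights(self):
        # Re-link the modules so encoder/decoder embeddings alias `shared` again after loading.
        self.language_model.encoder.embed_tokens = self.language_model.shared
        self.language_model.decoder.embed_tokens = self.language_model.shared


wrapper = ToyWrapper()
wrapper._tie_weights()
assert (
    wrapper.get_input_embeddings().weight.data_ptr()
    == wrapper.language_model.decoder.embed_tokens.weight.data_ptr()
)
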
2 changes: 1 addition & 1 deletion src/transformers/models/deta/modeling_deta.py
@@ -244,7 +244,7 @@ class DetaObjectDetectionOutput(ModelOutput):
 
 
 def _get_clones(module, N):
-    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+    return nn.ModuleList([module for i in range(N)])
 
 
 def inverse_sigmoid(x, eps=1e-5):
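
The DETA change above makes the "clones" share a single module instance, so their parameters alias one another (and collapse to one safetensors entry) instead of being independent deep copies. A small illustrative comparison (not DETA code):

import copy

from torch import nn


def clones_deepcopy(module, n):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


def clones_shared(module, n):
    return nn.ModuleList([module for _ in range(n)])


layer = nn.Linear(8, 8)

# Deep copies own distinct storage; shared "clones" alias the same weight tensor.
print(len({m.weight.data_ptr() for m in clones_deepcopy(layer, 3)}))  # 3
print(len({m.weight.data_ptr() for m in clones_shared(layer, 3)}))    # 1
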
2 changes: 0 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -609,8 +609,6 @@ def custom_forward(*inputs):
 
 
 class LlamaForCausalLM(LlamaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
     def __init__(self, config):
         super().__init__(config)
         self.model = LlamaModel(config)
3 changes: 2 additions & 1 deletion src/transformers/models/pix2struct/configuration_pix2struct.py
@@ -357,9 +357,10 @@ def __init__(
         initializer_factor=1.0,
         initializer_range=0.02,
         is_vqa=False,
+        tie_word_embeddings=False,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
         if text_config is None:
             text_config = {}
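
For context, `tie_word_embeddings` is a standard `PretrainedConfig` argument that defaults to `True` and is consulted by `PreTrainedModel.tie_weights()`; a composite model whose input and output embeddings are not actually shared has to forward an explicit `False`, as done above. A toy sketch of the pass-through pattern (both classes below are made up):

class ToyBaseConfig:
    """Stand-in for PretrainedConfig: defaults to tying word embeddings."""

    def __init__(self, **kwargs):
        self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)


class ToyCompositeConfig(ToyBaseConfig):
    def __init__(self, is_vqa=False, tie_word_embeddings=False, **kwargs):
        # Forward the explicit default so the base class does not fall back to True.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.is_vqa = is_vqa


print(ToyBaseConfig().tie_word_embeddings)       # True
print(ToyCompositeConfig().tie_word_embeddings)  # False
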
36 changes: 36 additions & 0 deletions tests/test_modeling_common.py
@@ -27,6 +27,7 @@
 import unittest
 import unittest.mock as mock
 import warnings
+from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
 
@@ -1626,6 +1627,41 @@ def check_same_values(layer_1, layer_2):
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
 
+    @require_safetensors
+    def test_can_use_safetensors(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model_tied = model_class(config)
+            with tempfile.TemporaryDirectory() as d:
+                try:
+                    model_tied.save_pretrained(d, safe_serialization=True)
+                except Exception as e:
+                    raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}")
+
+                model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
+                # Checking that the state dicts match
+                reloaded_state = model_reloaded.state_dict()
+                for k, v in model_tied.state_dict().items():
+                    self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
+                    torch.testing.assert_close(
+                        v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
+                    )
+
+                # Checking that the tensor sharing is correct
+                ptrs = defaultdict(list)
+                for k, v in model_tied.state_dict().items():
+                    ptrs[v.data_ptr()].append(k)
+
+                shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
+
+                for _, shared_names in shared_ptrs.items():
+                    reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
+                    self.assertEqual(
+                        len(reloaded_ptrs),
+                        1,
+                        f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
+                    )
+
     def test_tied_model_weights_key_ignore(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
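
Finally, a hedged end-to-end sketch of the behaviour the new test automates for every model class, using one small public checkpoint for illustration (the model choice is arbitrary, and `safetensors` must be installed):

import tempfile

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp, safe_serialization=True)  # writes model.safetensors
    reloaded = AutoModelForCausalLM.from_pretrained(tmp)

# Every tensor round-trips unchanged ...
reloaded_state = reloaded.state_dict()
for name, tensor in model.state_dict().items():
    torch.testing.assert_close(tensor, reloaded_state[name])

# ... and tied weights still share storage after the reload.
assert (
    reloaded.get_input_embeddings().weight.data_ptr()
    == reloaded.get_output_embeddings().weight.data_ptr()
)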
