Hard error when ignoring tensors. (#27484) #29906
@@ -30,7 +30,7 @@
 from dataclasses import dataclass
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 from zipfile import is_zipfile

 import torch
@@ -573,6 +573,79 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules


+def _end_ptr(tensor: torch.Tensor) -> int:
+    # extract the end of the pointer if the tensor is a slice of a bigger tensor
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+    else:
+        stop = tensor.data_ptr()
+    return stop
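A small illustration of what `_end_ptr` returns (toy tensors, and it assumes the helper above is in scope): for a non-empty tensor the end pointer is taken from the last element of the flattened view, so two non-overlapping slices of one storage get non-overlapping byte ranges.

```python
import torch

# Toy example: one storage of 8 float32 values, split into two 4-element slices.
full = torch.arange(8, dtype=torch.float32)
first, second = full[:4], full[4:]

# End pointer = address of the last element + its size in bytes.
assert _end_ptr(first) == first.data_ptr() + 4 * first.element_size()
# The two slices cover disjoint byte ranges of the shared storage.
assert _end_ptr(first) <= second.data_ptr()
```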
+def _get_tied_weight_keys(module: nn.Module, prefix=""):
+    tied_weight_keys = []
+    if getattr(module, "_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys]
+        tied_weight_keys.extend(names)
+    if getattr(module, "_dynamic_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys]
+        tied_weight_keys.extend(names)
+    for name, submodule in module.named_children():
+        local_prefix = f"{prefix}.{name}" if prefix else name
+        tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
+    return tied_weight_keys
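A hedged sketch of the recursive collection (`Head` and `Toy` are toy modules, not real transformers classes, and `_get_tied_weight_keys` is assumed in scope): keys declared in `_tied_weights_keys` or `_dynamic_tied_weights_keys` anywhere in the module tree come back prefixed with the submodule path.

```python
import torch.nn as nn

class Head(nn.Module):
    # mimics how model classes declare statically tied weights
    _tied_weights_keys = ["decoder.weight"]

    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(2, 2)

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = Head()

# The key declared on the nested module comes back prefixed with its path.
assert _get_tied_weight_keys(Toy()) == ["head.decoder.weight"]
```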
+def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]:
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
+            last_stop = stop
+    disjoint_tensors = []
+    shared_tensors = []
+    for tensors in filtered_tensors:
+        if len(tensors) == 1:
+            disjoint_tensors.append(tensors.pop())
+        else:
+            shared_tensors.append(tensors)
+    return shared_tensors, disjoint_tensors
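A toy illustration of `_find_disjoint` (arbitrary tensors, helper assumed in scope): three names share one storage, but one of them occupies its own byte range, so it is split off as disjoint and can safely be cloned before saving, while the overlapping pair is passed on for further checks.

```python
import torch

full = torch.arange(8, dtype=torch.float32)
# "a" and "c" alias the first half of the storage, "b" the second half.
state_dict = {"a": full[:4], "b": full[4:], "c": full[:4]}

shared_names, disjoint_names = _find_disjoint([{"a", "b", "c"}], state_dict)
assert disjoint_names == ["b"]          # its byte range overlaps nothing else
assert shared_names == [{"a", "c"}]     # these still overlap and need further checks
```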
+def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
+    shared_tensors = []
+    identical = []
+    for shared in tensors:
+        if len(shared) < 2:
+            continue
+
+        areas = collections.defaultdict(set)
+        for name in shared:
+            tensor = state_dict[name]
+            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
+            areas[area].add(name)
+        if len(areas) == 1:
+            identical.append(shared)
+        else:
+            shared_tensors.append(shared)
+    return shared_tensors, identical
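In the same spirit, a toy check of `_find_identical` (helper assumed in scope): two names covering exactly the same bytes on the same device are reported as identical storage, which the `save_pretrained` change below either deletes (when the duplicate is a declared tied weight) or turns into a hard error.

```python
import torch

full = torch.arange(8, dtype=torch.float32)
state_dict = {"a": full[:4], "c": full[:4]}   # exactly the same bytes under two names

shared_names, identical_names = _find_identical([{"a", "c"}], state_dict)
assert identical_names == [{"a", "c"}]
assert shared_names == []
```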
 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
@@ -1646,15 +1719,24 @@ def tie_weights(self):
         if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
             if hasattr(self, self.base_model_prefix):
                 self = getattr(self, self.base_model_prefix)
-            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.encoder, self.decoder, self.base_model_prefix, "encoder"
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attributed not an instance member, therefore modifying it will modify the entire class
+            # Leading to issues on subsequent calls by different tests or subsequent calls.
+            self._dynamic_tied_weights_keys = tied_weights

         for module in self.modules():
             if hasattr(module, "_tie_weights"):
                 module._tie_weights()

     @staticmethod
-    def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
+    def _tie_encoder_decoder_weights(
+        encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
+    ):
         uninitialized_encoder_weights: List[str] = []
+        tied_weights: List[str] = []
         if decoder.__class__ != encoder.__class__:
             logger.info(
                 f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
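A minimal sketch (a hypothetical plain class, not a transformers model) of the comment above: `_tied_weights_keys` is a class attribute, so appending the runtime-discovered keys to it would leak into every other instance and into later test runs, whereas `_dynamic_tied_weights_keys` lives on the instance.

```python
class Model:
    _tied_weights_keys = ["lm_head.weight"]   # class attribute, shared by all instances

a, b = Model(), Model()
a._tied_weights_keys.append("encoder.embed.weight")      # mutates the class-level list
assert b._tied_weights_keys == ["lm_head.weight", "encoder.embed.weight"]  # leaked into b

a._dynamic_tied_weights_keys = ["encoder.embed.weight"]   # instance attribute, isolated
assert not hasattr(Model, "_dynamic_tied_weights_keys")
assert not hasattr(b, "_dynamic_tied_weights_keys")
```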
@@ -1665,17 +1747,22 @@ def tie_encoder_to_decoder_recursively(
             decoder_pointer: nn.Module,
             encoder_pointer: nn.Module,
             module_name: str,
+            base_encoder_name: str,
             uninitialized_encoder_weights: List[str],
             depth=0,
+            total_decoder_name="",
+            total_encoder_name="",
         ):
             assert isinstance(decoder_pointer, nn.Module) and isinstance(
                 encoder_pointer, nn.Module
             ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
             if hasattr(decoder_pointer, "weight"):
                 assert hasattr(encoder_pointer, "weight")
                 encoder_pointer.weight = decoder_pointer.weight
+                tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
Review comment: (Not sure at all) but should there be a dot here between the names?
Reply: No, the encoder name already has the leading dot from the way the recursive calls are made. Forcing it here means adding extra logic in the recursive descent.
Reply: Agreed - I'd rather have no if statements.
                 if hasattr(decoder_pointer, "bias"):
                     assert hasattr(encoder_pointer, "bias")
+                    tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
Review comment: And possibly here?
                     encoder_pointer.bias = decoder_pointer.bias
                 return

@@ -1713,19 +1800,26 @@ def tie_encoder_to_decoder_recursively(
                         decoder_modules[decoder_name],
                         encoder_modules[encoder_name],
                         module_name + "/" + name,
+                        base_encoder_name,
                         uninitialized_encoder_weights,
                         depth=depth + 1,
+                        total_encoder_name=f"{total_encoder_name}.{encoder_name}",
+                        total_decoder_name=f"{total_decoder_name}.{decoder_name}",
Review comment (on lines +1806 to +1807): Here - do we want to account for when the string is empty?
                     )
                     all_encoder_weights.remove(module_name + "/" + encoder_name)

                 uninitialized_encoder_weights += list(all_encoder_weights)

         # tie weights recursively
-        tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
+        tie_encoder_to_decoder_recursively(
+            decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights
+        )

         if len(uninitialized_encoder_weights) > 0:
             logger.warning(
                 f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
             )
+        return tied_weights

     def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """Tie or clone module weights depending of whether we are using TorchScript or not"""
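On the dot and empty-string questions raised above, a small sketch of how the returned keys are assembled (the layer path is illustrative): `total_encoder_name` is extended as `f"{total_encoder_name}.{encoder_name}"` at every recursion step, so it already carries a leading dot, and it is simply empty at the top level; that is why `base_encoder_name` is prepended without an extra separator.

```python
# Illustrative values, not taken from a real model.
base_encoder_name = "encoder"

# Depth 0: nothing accumulated yet -> the key is just "encoder.weight".
total_encoder_name = ""
assert f"{base_encoder_name}{total_encoder_name}.weight" == "encoder.weight"

# Deeper in the recursion the accumulated name already starts with a dot.
total_encoder_name = ".layer.0.attention.self.query"
assert (
    f"{base_encoder_name}{total_encoder_name}.weight"
    == "encoder.layer.0.attention.self.query.weight"
)
```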
@@ -2402,34 +2496,49 @@ def save_pretrained(

             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-            warn_names = set()
+            error_names = []
+            to_delete_names = set()
+            # Recursively descend to find tied weight keys
+            _tied_weights_keys = _get_tied_weight_keys(self)
             for names in shared_ptrs.values():
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
-                if self._tied_weights_keys is not None:
+                if _tied_weights_keys is not None:
                     found = 0
                     for name in sorted(names):
-                        matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
+                        matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
                         if matches_pattern and name in state_dict:
                             found += 1
                             if found < len(names):
-                                del state_dict[name]
-
-                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-                # If the link between tensors was done at runtime then `from_pretrained` will not get
-                # the key back leading to random tensor. A proper warning will be shown
-                # during reload (if applicable), but since the file is not necessarily compatible with
-                # the config, better show a proper warning.
-                found = 0
-                for name in names:
-                    if name in state_dict:
-                        found += 1
-                        if found > 1:
-                            del state_dict[name]
-                            warn_names.add(name)
-            if len(warn_names) > 0:
-                logger.warning_once(
-                    f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
-                )
+                                to_delete_names.add(name)
+            # We are entering a place where the weights and the transformers configuration do NOT match.
+            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+            for name in disjoint_names:
+                state_dict[name] = state_dict[name].clone()
+
+            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+            # If the link between tensors was done at runtime then `from_pretrained` will not get
+            # the key back leading to random tensor. A proper warning will be shown
+            # during reload (if applicable), but since the file is not necessarily compatible with
+            # the config, better show a proper warning.
+            shared_names, identical_names = _find_identical(shared_names, state_dict)
+            # delete tensors that have identical storage
+            for inames in identical_names:
+                known = inames.intersection(to_delete_names)
+                for name in known:
+                    del state_dict[name]
+                unknown = inames.difference(to_delete_names)
+                if len(unknown) > 1:
+                    error_names.append(unknown)
+
+            if shared_names:
+                error_names.append(set(shared_names))
+
+            if len(error_names) > 0:
+                raise RuntimeError(
+                    f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
+                )

         # Shard the model if it is too big.
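A rough, untested sketch of the user-facing change (`ToyConfig`, `ToyModel`, and the `"toy-out"` path are made up for illustration): tensor sharing that `_tied_weights_keys` does not declare used to be silently dropped with only a warning; with this PR, saving in the default safetensors format raises a `RuntimeError`, and the message points at `safe_serialization=False` as the escape hatch.

```python
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

class ToyConfig(PretrainedConfig):
    model_type = "toy"

class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.a = nn.Linear(4, 4, bias=False)
        self.b = nn.Linear(4, 4, bias=False)
        self.b.weight = self.a.weight   # runtime tying not declared in _tied_weights_keys

model = ToyModel(ToyConfig())
try:
    model.save_pretrained("toy-out")                              # safetensors by default
except RuntimeError:
    # Either fall back to pickle serialization, or break/declare the sharing.
    model.save_pretrained("toy-out", safe_serialization=False)
```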
@@ -15,7 +15,6 @@
 # limitations under the License.
 """PyTorch BERT model."""

-
 import math
 import os
 import warnings

@@ -1128,7 +1127,7 @@ def forward(
     """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
 )
 class BertLMHeadModel(BertPreTrainedModel):
-    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
Review comment: Seems like this was a bug.

     def __init__(self, config):
         super().__init__(config)
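A small aside on why the old value went unnoticed (the check below is mine, not part of the PR): `_tied_weights_keys` entries are applied as regex patterns with `re.search` in `save_pretrained` above, so the un-prefixed pattern still matched as a substring; the fix makes the bias entry consistent with the fully qualified `cls.predictions.decoder.weight` entry.

```python
import re

# The actual parameter name in the BERT LM head state dict.
name = "cls.predictions.decoder.bias"

old_pattern = "predictions.decoder.bias"        # matched only as a substring
new_pattern = "cls.predictions.decoder.bias"    # matches the fully qualified key

assert re.search(old_pattern, name) is not None
assert re.search(new_pattern, name) is not None
```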
Review comment (on the `total_encoder_name` / `total_decoder_name` tracking above): This is important since `module_name` is a generic name, and `encoder_name` and `decoder_name` can differ (when there's an ignored cross_attn layer in the tying).