Reverting Deta cloning mechanism. (huggingface#22656)
* Fixed the revert by making sure that even the regexp covers all duplicates.

* Code simplification using hash.

* Fixing the `ident`.

* Fixing the ignoring of patterned duplicate names (see the sketch after the commit message).

* Using `accelerate@find_tied_parameters` for `from_pretrained` (see the sketch after the diff).

This is more correct there, since it handles the meta device seamlessly and we
don't need to handle "non-duplicate" tensors (slices of each other).

* Protecting the `accelerate` import.

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
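
Not part of the commit message above: a minimal sketch, with made-up parameter names and a made-up ignore pattern, of the behaviour the first and fourth bullets describe. When every alias of a shared tensor matches an ignore pattern, the last name (in sorted order) is still kept, so the tensor is written exactly once instead of being dropped entirely.

import re

# Hypothetical aliases of one shared tensor and a hypothetical ignore pattern.
names = ["decoder.layers.0.weight", "model.decoder.layers.0.weight"]
ignore_patterns = [r"decoder\.layers\.\d+\.weight"]
state_dict = {name: "tensor" for name in names}

found = 0
for name in sorted(names):
    matches_pattern = any(re.search(pat, name) for pat in ignore_patterns)
    if matches_pattern and name in state_dict:
        found += 1
        if found < len(names):  # never delete the final matching alias
            del state_dict[name]

print(sorted(state_dict))  # ['model.decoder.layers.0.weight']
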
2 people authored and novice03 committed Jun 23, 2023
1 parent a6b5154 commit 10c258a
37 changes: 28 additions & 9 deletions src/transformers/modeling_utils.py
@@ -83,6 +83,7 @@
     from accelerate import __version__ as accelerate_version
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
     from accelerate.utils import (
+        find_tied_parameters,
         load_offloaded_weights,
         offload_weight,
         save_offload_index,
@@ -93,6 +94,8 @@
         from accelerate.utils import get_balanced_memory
     else:
         get_balanced_memory = None
+else:
+    find_tied_parameters = None
 
 if is_safetensors_available():
     from safetensors import safe_open
@@ -1776,7 +1779,8 @@ def save_pretrained(
             # We're going to remove aliases before saving
             ptrs = collections.defaultdict(list)
             for name, tensor in state_dict.items():
-                ptrs[tensor.data_ptr()].append(name)
+                ident = (tensor.data_ptr(), tensor.device, tensor.shape, tensor.stride())
+                ptrs[ident].append(name)
 
             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
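
Not part of the diff: a quick illustration of why the `ident` tuple above treats two state-dict entries as duplicates only when they agree on storage pointer, device, shape and stride, so tensors that merely share storage (for example slices of each other) are not de-aliased.

import torch

def ident(t):
    return (t.data_ptr(), t.device, t.shape, t.stride())

base = torch.arange(6.0)
alias = base      # a true alias: same storage, same metadata
view = base[:3]   # shares storage with `base` but is not a duplicate

print(ident(base) == ident(alias))  # True  -> grouped together and de-aliased on save
print(ident(base) == ident(view))   # False -> kept as a separate entry
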
@@ -1785,10 +1789,13 @@
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
                 if self._keys_to_ignore_on_load_missing is not None:
-                    for name in names:
+                    found = 0
+                    for name in sorted(names):
                         matches_pattern = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
                         if matches_pattern and name in state_dict:
-                            del state_dict[name]
+                            found += 1
+                            if found < len(names):
+                                del state_dict[name]
 
                 # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
                 # If the link between tensors was done at runtime then `from_pretrained` will not get
@@ -2934,12 +2941,24 @@ def _fix_key(key):
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
 
-        # Some tensors maybe have been already filled by another key (tied weights).
-        # TODO: Sylvain -> make this work even on meta device.
-        # existing_ptrs = {model_state_dict[k].data_ptr() for k in loaded_keys if k in model_state_dict}
-        # missing_keys = [
-        #     k for k in missing_keys if k in model_state_dict and model_state_dict[k].data_ptr() not in existing_ptrs
-        # ]
+        if find_tied_parameters is not None:
+            tied_params = find_tied_parameters(model)
+        else:
+            tied_params = []
+        _missing = []
+        for k in missing_keys:
+            found = False
+            for group in tied_params:
+                if k in group:
+                    found = True
+                    if len(group) > 2:
+                        group.remove(k)
+                    else:
+                        _missing.append(k)
+            if not found:
+                _missing.append(k)
+        missing_keys = _missing
+
         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
         if cls._keys_to_ignore_on_load_missing is not None:
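
Not part of the diff: a minimal sketch, using a made-up toy module, of what `accelerate.utils.find_tied_parameters` reports and why the new loop above consults it. Tied parameters come back as groups of names (the exact return type depends on the accelerate version), and a key that belongs to such a group does not necessarily need its own entry in the checkpoint, since loading any other member fills the same tensor. The last two lines only illustrate that basic idea; the diff itself is more careful and also tracks group sizes.

import torch.nn as nn
from accelerate.utils import find_tied_parameters

class TinyLM(nn.Module):  # hypothetical model with tied input/output embeddings
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the two parameters

tied_params = find_tied_parameters(TinyLM())
print(tied_params)  # e.g. [['embed.weight', 'lm_head.weight']]

missing_keys = ["lm_head.weight", "classifier.bias"]
print([k for k in missing_keys if not any(k in group for group in tied_params)])
# ['classifier.bias'] (assuming the grouping shown above)
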
