Reverting Deta cloning mechanism. #22656

Merged (7 commits into huggingface:main on Apr 24, 2023)

Conversation

@Narsil Narsil (Contributor) commented Apr 7, 2023

What does this PR do?

This one is quite odd.
With this revert, the slow test will work (which I guess is what we care most about):

from transformers import AutoImageProcessor, DetaForObjectDetection
from PIL import Image
import requests
import torch

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("jozhang97/deta-swin-large")
model = DetaForObjectDetection.from_pretrained("jozhang97/deta-swin-large")

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
print(results)

However, if I incorporate this:

model = DetaForObjectDetection.from_pretrained("jozhang97/deta-swin-large")
model.save_pretrained("./tmp")
model = DetaForObjectDetection.from_pretrained("./tmp")

Then the output is garbage again (this isn't using safetensors and is not linked to the original change).
I even tried reverting the PR that introduced the bug.

The change of output is due to safetensors. I need to check this thoroughly.

This revert will fix the slow tests anyway.

I think something is not properly set up in this model, because the
uploaded model seems to have those layers NOT linked (hence the
copy.deepcopy), but the rest of the configuration seems to assume
they are, which may be the issue.

Fixes #22437 (comment)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Narsil Narsil requested review from ydshieh and sgugger April 7, 2023 16:02
@HuggingFaceDocBuilderDev commented Apr 7, 2023

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh (Collaborator) left a comment

I confirmed that the example, as well as the 2 currently failing DETA tests, now works with this revert. It's indeed better to have a working version while we still need more time to dig into the root of the issue.

Thank you @Narsil !

@ydshieh ydshieh (Collaborator) commented Apr 11, 2023

However, test_can_use_safetensors is failing after this PR. Is this test still relevant (at least while we keep the changes in this PR)?

@Narsil Narsil (Contributor, Author) commented Apr 11, 2023

However, test_can_use_safetensors is failing after this PR. Is this test still relevant (at least while we keep the changes in this PR)?

The new code should fix everything.

@sgugger for a new review since the change has evolved quite a bit and is not a simple revert anymore.
Added inline comments in the PR to explain what's going on.

@@ -1768,7 +1768,8 @@ def save_pretrained(
         # We're going to remove aliases before saving
         ptrs = collections.defaultdict(list)
         for name, tensor in state_dict.items():
-            ptrs[tensor.data_ptr()].append(name)
+            ident = (tensor.data_ptr(), tensor.device, tensor.shape, tensor.stride())
Narsil (PR author):

This change exists because of multi-GPU setups and potentially peculiar sharing of tensors.

Tensors are considered shared, and droppable, if and only if they are the exact same tensor:
same ptr, same device, same shape, same stride.

I don't think we need to handle the meta device here, since trying to save a model on the meta device should already be a bug.
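A minimal sketch of the grouping this `ident` enables (the state_dict contents here are made up; this is not the actual save_pretrained code):

import collections
import torch

# Hypothetical state_dict in which one tensor appears under two names.
weight = torch.randn(4, 4)
state_dict = {"a.weight": weight, "b.weight": weight, "c.bias": torch.zeros(4)}

ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
    # Two entries are aliases only if they are the exact same tensor:
    # same data pointer, same device, same shape and same stride.
    ident = (tensor.data_ptr(), tensor.device, tensor.shape, tensor.stride())
    ptrs[ident].append(name)

shared_groups = [names for names in ptrs.values() if len(names) > 1]
print(shared_groups)  # -> [['a.weight', 'b.weight']]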

# This makes sure even if the pattern covers all names
# that we keep at least 1 copy of the name.
for name in sorted(del_names)[: len(names) - 1]:
    del state_dict[name]
Narsil (PR author):

This is the fix for Deta.

Deta's _keys_to_ignore_on_load regexps are a bit too generous and cover ALL duplicates for some layers.
This code ensures that we keep at least 1 key in the dict in that case.
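A rough sketch of that safeguard, with made-up names mirroring the Deta case (illustrative only, not the actual modeling_utils code):

import re

import torch

# Two aliases of a single shared tensor, and an over-broad ignore pattern
# that matches every one of them.
shared = torch.zeros(4)
state_dict = {"model.decoder.class_embed.0.bias": shared, "class_embed.0.bias": shared}
names = set(state_dict)
patterns = [r"class_embed"]

del_names = {name for name in names if any(re.search(p, name) for p in patterns)}

# Delete at most len(names) - 1 entries so that at least one copy always survives.
for name in sorted(del_names)[: len(names) - 1]:
    del state_dict[name]

assert len(state_dict) == 1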

Narsil (PR author):

If we did, the can_use_safetensors test should crash (because safetensors outright refuses shared tensors).

In terms of logic, we can delete at most n-1 names here, and if we deleted fewer it would mean the names are the same (because of the use of a set). I could stick to lists if you'd prefer.

Collaborator:

In terms of logic, I think we should keep the first one and not the last? Usually tensor sharing is written as tensor_2 = tensor_1, not the opposite.

Collaborator:

(I guess you somehow saw my deleted message?) I think the current change is fine, so I deleted my previous question.

src/transformers/modeling_utils.py (outdated review thread, resolved)
@Narsil Narsil requested review from sgugger and ydshieh April 11, 2023 09:37
@ydshieh ydshieh (Collaborator) left a comment

Confirmed again that the new changes work for the relevant tests. LGTM with the explanations.

(Except for the 2 changes in _load_pretrained_model, but that's because I am not familiar with the codebase here. I think @sgugger would know much better than me.)

@sgugger sgugger (Collaborator) left a comment

So we tried it your way and it doesn't work. Can we try to use Accelerate to detect the tied weights instead as suggested initially?

# This makes sure even if the pattern covers all names
# that we keep at least 1 copy of the name.
for name in sorted(del_names)[: len(names) - 1]:
    del state_dict[name]
Collaborator:

In terms of logic, I think we should keep the first one and not the last? Usually tensor sharing is written as tensor_2 = tensor_1, not the opposite.

src/transformers/modeling_utils.py (outdated review thread, resolved)
@Narsil Narsil (Contributor, Author) commented Apr 11, 2023

So we tried it your way and it doesn't work. Can we try to use Accelerate to detect the tied weights instead as suggested initially?

Because find_tied_weights looks at the model, whereas here we look at the state_dict, which can be passed directly to the function. In both functions the state_dict is the source of truth, not the model, isn't it?

We could definitely use find_tied_weights and it would most likely pass the tests, but it wouldn't be looking at exactly the same thing. The state_dict is what is coming in, while find_tied_weights looks at where it is being put (in from_pretrained; the opposite in save_pretrained). In general they should be the same, but not necessarily always.

For instance, I wonder what happens for buffers.

This will ignore the whole state dict as soon as device_map="auto" or low_cpu_mem_usage=True.

Why? It seems you're using the hash (via is) in accelerate; I will switch to that since we want entirely shared tensors, like in accelerate.

@Narsil Narsil (Contributor, Author) commented Apr 11, 2023

Why? It seems you're using the hash (via is) in accelerate; I will switch to that since we want entirely shared tensors, like in accelerate.

So actually hash doesn't seem to work either: you can have a shared buffer and still get different hashes.
I'll try to come up with a simple example, but in Deta, model_decoder.class_embed.n.bias and class_embed.n.bias do share the buffer, and yet don't have the same hash.

This shows the difference between find_tied_weights and the state_dict. Here the tensors from the state_dict don't share the hash (while the parameters on the model do), yet the tensors in the state_dict do share memory.
In this particular case, using find_tied_weights would work, but that also means the opposite is possible.
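A small self-contained illustration of that difference (the tying here is done by registering the same submodule twice, which is just one way to end up in this situation):

import torch
from torch import nn

class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = self.a  # same module registered under two names

m = Tied()
assert m.a.weight is m.b.weight  # on the model, the tied parameters are the same object

sd = m.state_dict()
# state_dict() detaches each entry, producing distinct tensor objects...
assert sd["a.weight"] is not sd["b.weight"]
assert hash(sd["a.weight"]) != hash(sd["b.weight"])
# ...that still share the same underlying memory.
assert sd["a.weight"].data_ptr() == sd["b.weight"].data_ptr()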

@sgugger sgugger (Collaborator) commented Apr 11, 2023

In both situations, you have access to the model, and find_tied_weights will give you a list of names that are compatible with the state_dict of the model.

In this particular case, using find_tied_weights would work, but that also means the opposite is possible.

If this situation (the opposite) does not appear in Transformers, let's just use find_tied_weights.

I also would like to drive the point home that safetensors not dealing with shared weights makes it unusable in practice in other libs: see what we have to do here... and we really want to use safetensors. How are we going to convince other users?

@Narsil Narsil (Contributor, Author) commented Apr 11, 2023

makes it unusable in practice

Why do we even care about _keys_to_ignore and tie_weights if it's so inconvenient?
Why are we even trying to find tied weights in accelerate?
How do we expect to use safetensors for the TF models, since sharing doesn't exist over there?

@Narsil Narsil (Contributor, Author) commented Apr 21, 2023

In order to help with the ease of use of safetensors by itself, I created this PR:

huggingface/safetensors#236

which sort of mimics what is done here.

However, I still think this PR and the mechanism in transformers should be kept, since _keys_to_ignore are very good at hinting which keys we should keep and which to drop, information which is not available in safetensors directly.
Also, the modifications are shallower here since they don't touch state_dict and load_state_dict, which the proposed method would have to change.
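For context, this is roughly what plain safetensors does with shared tensors today (the file name is arbitrary and the exact error message may differ between versions):

import torch
from safetensors.torch import save_file

weight = torch.zeros(4, 4)
tensors = {"a.weight": weight, "b.weight": weight}  # two names, one storage

try:
    save_file(tensors, "model.safetensors")
except RuntimeError as err:
    # safetensors refuses shared tensors outright, so the caller
    # (here, save_pretrained) has to deduplicate the aliases first.
    print(err)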

@sgugger sgugger (Collaborator) left a comment

Thanks for considering shared weights in safetensors directly. I agree it would still be cleaner to have the same kind of mechanism in Transformers. Could you please explain to me once again why the hash check does not work for the first changes in the PR (dropping weights in the checkpoint before passing it to safetensors)? I don't think we ever tie weights in Transformers other than just setting the same tensors.

Apart from that, just rebasing on main should be all that's necessary here.

Note that I will rework the constants in future work to have one distinct key for the tied weights (as sometimes they are not tied and we are currently not warning the user if they are missing), but it's orthogonal to this PR.

@Narsil Narsil (Contributor, Author) commented Apr 21, 2023

Thanks for considering shared weights in safetensors directly. I agree it would still be cleaner to have the same kind of mechanism in Transformers. Could you please explain to me once again why the hash check does not work for the first changes in the PR (dropping weights in the checkpoint before passing it to safetensors). I don't think we ever tie weights in Transformers other than just setting the same tensors.

Mostly this:
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2146

 state_dict = kwargs.pop("state_dict", None)

Users can send a state_dict that is not linked to self, so this PR tried to look only at the state_dict instead of self.
This is indeed a bit of an edge case.
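Concretely, the edge case is the public state_dict argument, roughly like this (hypothetical usage, just to show why the state_dict rather than self has to be the source of truth):

from transformers import DetaForObjectDetection

model = DetaForObjectDetection.from_pretrained("jozhang97/deta-swin-large")

# The caller may hand over a state_dict that is not derived from `model` at all
# (averaged weights, externally modified weights, ...), so any alias detection
# in save_pretrained has to look at this dict rather than at the model itself.
custom_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
model.save_pretrained("./tmp-custom", state_dict=custom_state_dict)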

Then there are even further edge cases:

import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(100, 100)
        self.b = self.a


model = Model()
assert model.a is model.b  # OK!

A = torch.zeros((1000, 100))
a = A[:100]
model.a.weight = nn.Parameter(a)
model.b.weight = model.a.weight
# Still the same parameter, but it is now shared with respect to a larger tensor.
assert model.a is model.b


class NoSharedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(100, 100)
        self.b = nn.Linear(100, 100)


model = NoSharedModel()
A = torch.zeros((100, 100))
model.a.weight = nn.Parameter(A)
model.b.weight = nn.Parameter(A[:10])

# The parameters are distinct objects, yet the underlying tensors share memory.
assert model.a.weight is not model.b.weight

I haven't looked deeply at what happens during fine-tuning to see whether autograd starts copying the tensors.
state_dict() will give back a and b as shared tensors, yet the parameters don't have the same hash.

If you want, I could take a look at accelerate's shared-params function and see if this applies. There are a lot of weird things when you dig deeply into this; I discovered a lot of behavior in Deta from this PR.

But the biggest reason really is the optional state_dict, whereas accelerate looks directly at the model. Within from_pretrained, looking at the model is better in this case, since what matters is the user's model rather than the state_dict coming from the file (be it pytorch or safetensors).

Apart from that, just rebasing on main should be all that's necessary here.

Note that I will rework the constants in future work to have one distinct key for the tied weights (as sometimes they are not tied and we are currently not warning the user if they are missing), but it's orthogonal to this PR.

Great!

@Narsil Narsil (Contributor, Author) commented Apr 21, 2023

Seeing the rebase: hash unfortunately doesn't work on tensors:

import torch

A = torch.zeros((10, 10))
B = A[1]

# B is a view of A, so the underlying storage is shared...
assert A.untyped_storage().data_ptr() == B.untyped_storage().data_ptr()
# ...but the two tensors still hash as distinct objects.
assert hash(A) != hash(B)

@sgugger sgugger (Collaborator) left a comment

Thanks for the explanation. Good for me for save_pretrained, but in from_pretrained I think it's better to rely on the hash, which catches all the situations that happen in Transformers (we do not use slices) while still working when the model is on the meta device (which will ultimately become the default), instead of relying on the data pointers and doing nothing when the model is on the meta device.

Comment on lines 2941 to 2962
def _tensor_hash(tensor):
    # This is better than `tensor.data_ptr()`
    # Since A = torch.zeros((10, 10))
    # B = A[2, :]
    # Then A.data_ptr() != B.data_ptr()
    # But actually the storage is still shared
    try:
        ptr = tensor.untyped_storage().data_ptr()
    except AttributeError:
        # Fallback for torch==1.10
        try:
            ptr = tensor.storage().data_ptr()
        except NotImplementedError:
            # Fallback for meta storage like in 2.0
            ptr = 0
    return (ptr, tensor.device)


existing_ptrs = {
    _tensor_hash(model_state_dict[k])
    for k in loaded_keys
    if k in model_state_dict and model_state_dict[k].device != torch.device("meta")
}
Collaborator:

Here the check is not right, as it does not find the tied weights when the model is on the meta device (which is ultimately going to be the default, to load without using RAM). The goal is to detect tied parameters in the model in any case, so we can rely on the hash of the model weights (there are no shared slices in Transformers models) for this test.
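For illustration, a sketch of the accelerate-based detection the PR eventually switched to (see "I swapped for accelerate here" below); the tiny model is made up, and the exact return format of find_tied_parameters depends on the accelerate version:

import torch
from torch import nn
from accelerate.utils.modeling import find_tied_parameters

class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.lm_head = nn.Linear(16, 100, bias=False)
        self.lm_head.weight = self.embed.weight  # tie input and output embeddings

model = TinyLM()
# Reports the tied group containing 'embed.weight' and 'lm_head.weight',
# based on the module structure rather than on data pointers.
print(find_tied_parameters(model))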

Narsil (PR author):

I swapped for accelerate here.

@Narsil Narsil (Contributor, Author) commented Apr 21, 2023

(which will ultimately become the default)

Hurray!!!

@sgugger sgugger (Collaborator) left a comment

Thanks for bearing with me!

@@ -28,6 +28,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+from accelerate.utils.modeling import find_tied_parameters
Collaborator:

Can we just protect this with an is_accelerate_available check? If users installed transformers and PyTorch separately, they won't have it (they'd need to do pip install transformers["torch"]), and in this case we would just skip the check for missing tied parameters (so there might be an extra warning in that case).
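A sketch of the kind of guard being requested; detect_tied_parameters is a hypothetical wrapper to show the pattern, not the actual code in modeling_utils:

from transformers.utils import is_accelerate_available

if is_accelerate_available():
    from accelerate.utils.modeling import find_tied_parameters


def detect_tied_parameters(model):
    # Without accelerate we skip the detection entirely; the only downside is
    # a possible extra warning about "missing" tied weights.
    if not is_accelerate_available():
        return []
    return find_tied_parameters(model)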

Narsil (PR author):

Thanks for the explanation!

@Narsil Narsil (Contributor, Author) commented Apr 24, 2023

The failing tests seem to be linked to the newly released huggingface_hub==0.14.0.

@sgugger Merge if you think it's OK; I'm not going to merge it myself, given that this PR affects core modeling.

@sgugger sgugger (Collaborator) left a comment

Just one last nit (always use explicit tests for bool values and no Python conversion magic, as usual ;-) ) and it should be good to merge.

src/transformers/modeling_utils.py (outdated review thread, resolved)
Narsil and others added 7 commits April 24, 2023 17:06
@sgugger sgugger merged commit 6e32959 into huggingface:main Apr 24, 2023
@Narsil Narsil deleted the revert_deta branch April 24, 2023 15:40
gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
* Fixed the revert by making sure that even the regexp can cover all
duplicates.

* Code simplification using hash.

* Fixing the `ident`.

* Fixing ignoring patterned duplicate names.

* Using `accelerate@find_tied_parameters` for from_pretrained

This is more correct there, since it handles the meta device seamlessly
and we don't need to handle "non-duplicate" tensors (slices of each
other).

* Protecting accelerate.

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023