
Hard error when ignoring tensors. (#27484) #29906

Merged: 13 commits merged into main from hard_error_safetensors on Apr 2, 2024
Conversation

@Narsil (Contributor) commented Mar 27, 2024

  • [WIP] Hard error when ignoring tensors.

  • Better selection/error when saving a checkpoint.

  • Find all names we should normally drop (those are in the transformers
    config).
  • Find all disjoint tensors (for those we can safely trigger a copy to
    get rid of the sharing before saving).
  • Clone those disjoint tensors, getting rid of the issue.
  • Find all identical names (those should be declared in the config,
    but we try to find them all anyway).
  • For all identical names:
    • If they are in the config, just ignore them; everything is fine.
    • If they are not, warn about them.
  • For all remaining tensors that are shared yet neither identical nor
    disjoint, raise a hard error.
  • Adding a failing test on `main` that passes here.

  • We don't need to keep the subfolder logic in this test.

  • Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
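The classification the description outlines can be sketched in PyTorch: group tensor names by the storage they point into, treat same-offset/same-shape entries as true aliases (tied weights), treat non-overlapping byte ranges as disjoint views that are safe to clone, and hard-error on anything partially overlapping. This is an illustrative sketch, not the code from this PR; `classify_shared` is a hypothetical helper.

```python
import torch
from collections import defaultdict

def classify_shared(state_dict):
    # Group tensor names by the underlying storage they point into.
    groups = defaultdict(list)
    for name, t in state_dict.items():
        groups[t.untyped_storage().data_ptr()].append(name)

    identical, disjoint = [], []
    for names in groups.values():
        if len(names) < 2:
            continue
        tensors = [state_dict[n] for n in names]
        # Identical: same offset and shape -> true aliases (tied weights).
        if all(t.storage_offset() == tensors[0].storage_offset()
               and t.shape == tensors[0].shape for t in tensors):
            identical.append(names)
            continue
        # Disjoint: byte ranges never overlap -> safe to clone before saving.
        spans = sorted((t.storage_offset() * t.element_size(),
                        (t.storage_offset() + t.numel()) * t.element_size())
                       for t in tensors)
        if all(spans[i][1] <= spans[i + 1][0] for i in range(len(spans) - 1)):
            disjoint.append(names)
        else:
            # Shared, yet neither identical nor disjoint: the PR's hard error.
            raise RuntimeError(f"Partially overlapping tensors: {names}")
    return identical, disjoint
```

For example, two halves of one buffer (`base[:5]`, `base[5:]`) land in the disjoint bucket, while two references to the same tensor land in the identical bucket.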

Should fix #29903, fixes #28293


Narsil and others added 4 commits March 27, 2024 15:02
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

```diff
@@ -1128,7 +1127,7 @@ def forward(
     """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
 )
 class BertLMHeadModel(BertPreTrainedModel):
-    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
```
@Narsil (Contributor, author):

Seems like this was a bug; `predictions` does not exist on this model, only `cls.predictions`.

```diff
@@ -1667,15 +1742,19 @@ def tie_encoder_to_decoder_recursively(
     module_name: str,
     uninitialized_encoder_weights: List[str],
     depth=0,
+    total_decoder_name="",
```
@Narsil (Contributor, author):

This is important since `module_name` is a generic name, and `encoder_name` and `decoder_name` can differ (when there's an ignored cross_attn layer in the tying).
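The need for an accumulated name can be shown with a toy recursion over a nested mapping (plain Python; `collect_paths` is a hypothetical illustration, not the transformers helper): the local key at each level is generic, so the full dotted path has to be threaded through the calls as an extra argument.

```python
def collect_paths(tree, prefix=""):
    # Thread the accumulated dotted path through the recursion; the local
    # key alone ("block", "weight", ...) cannot identify the parameter.
    paths = []
    for name, child in tree.items():
        full = f"{prefix}.{name}" if prefix else name
        if isinstance(child, dict):
            paths.extend(collect_paths(child, full))
        else:
            paths.append(full)
    return paths

print(collect_paths({"encoder": {"block": {"weight": None}}}))
# prints ['encoder.block.weight']
```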

@Narsil Narsil requested a review from ydshieh March 27, 2024 18:10
@ArthurZucker (Collaborator) left a comment:

As talked offline, now that the name of the encoder is passed, LGTM.

.github/workflows/self-scheduled.yml (review thread: outdated, resolved)
Narsil and others added 3 commits March 29, 2024 22:44
@amyeroberts (Collaborator) left a comment:

Very nice handling of something quite tricky - thanks for adding tests! ❤️

Only concern is that we are still vulnerable to `_tied_weights` being modified after instance creation, but I don't really see an easy way to prevent this other than giving warnings here.

Comment on lines +1806 to +1807:

```python
total_encoder_name=f"{total_encoder_name}.{encoder_name}",
total_decoder_name=f"{total_decoder_name}.{decoder_name}",
```
Collaborator:

Here, do we want to account for when the string is empty?

Suggested change:

```diff
-total_encoder_name=f"{total_encoder_name}.{encoder_name}",
-total_decoder_name=f"{total_decoder_name}.{decoder_name}",
+total_encoder_name=f"{total_encoder_name}.{encoder_name}" if total_encoder_name else encoder_name,
+total_decoder_name=f"{total_decoder_name}.{decoder_name}" if total_decoder_name else decoder_name,
```
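The concern is easy to see in isolation: with an empty accumulator string, the unconditional f-string join yields a spurious leading dot, which the conditional form avoids. A minimal check:

```python
prefix, name = "", "encoder"

# Unconditional join: an empty prefix produces a leading dot.
assert f"{prefix}.{name}" == ".encoder"

# Conditional form: no leading dot when the prefix is empty.
assert (f"{prefix}.{name}" if prefix else name) == "encoder"
```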

```python
):
    assert isinstance(decoder_pointer, nn.Module) and isinstance(
        encoder_pointer, nn.Module
    ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
    if hasattr(decoder_pointer, "weight"):
        assert hasattr(encoder_pointer, "weight")
        encoder_pointer.weight = decoder_pointer.weight
        tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
```
Collaborator:

(Not sure at all.) But should there be a dot here between the names?

Suggested change:

```diff
-tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
+tied_weights.append(f"{base_encoder_name}.{total_encoder_name}.weight")
```

@Narsil (Contributor, author):

No, the encoder name already has the leading dot from the way the recursive calls are made.

Forcing it here would mean adding extra logic in the recursive descent. I can do it to make the code more readable, but in general, in such complex code, I don't like adding too many ifs, especially on dependent variables in recursive calls.

Collaborator:

Agreed; I'd rather have no if statements.

```python
if hasattr(decoder_pointer, "bias"):
    assert hasattr(encoder_pointer, "bias")
    tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
```
Collaborator:

And possibly here?

Suggested change:

```diff
-tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
+tied_weights.append(f"{base_encoder_name}.{total_encoder_name}.bias")
```

@Narsil (Contributor, author) commented Apr 2, 2024

> Only concern is that we are still vulnerable to _tied_weights being modified after instance creation, but I don't really see an easy way to prevent this other than giving warnings here.

Those are private, therefore it should be OK. You could make it immutable through `@property`, but that seems a bit too much at this point.
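A minimal sketch of the `@property` idea (hypothetical class, not the transformers API): expose the keys through a read-only property that returns a copy, so callers can neither rebind the attribute nor mutate the underlying list.

```python
class TiedModel:
    """Hypothetical illustration of guarding tied-weight keys."""

    def __init__(self):
        self._tied_keys_internal = ["lm_head.weight"]

    @property
    def _tied_weights_keys(self):
        # No setter, and a copy on read: the internal list stays intact.
        return list(self._tied_keys_internal)

model = TiedModel()
model._tied_weights_keys.append("oops")  # mutates only the returned copy
assert model._tied_weights_keys == ["lm_head.weight"]
```

Attempting `model._tied_weights_keys = []` would raise `AttributeError`, since the property defines no setter.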

@Narsil Narsil merged commit 9b0a8ea into main Apr 2, 2024
21 checks passed
@Narsil Narsil deleted the hard_error_safetensors branch April 2, 2024 14:59
ArthurZucker added a commit that referenced this pull request Apr 22, 2024
* Hard error when ignoring tensors. (#27484)


* Add small tests.

* Dead variable.

* Fixup.

* Fixing tied_Weights_keys on generic models.

* Fixup + T5 encoder/decoder tying (with different layers)

* Code quality.

* Dynamic member.

* trigger

* Fixing encoder name for other types of encoder/decoder combos.

* Fix scoping.

* Update .github/workflows/self-scheduled.yml

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fixing the tied_weights after the call.

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
itazap pushed a commit that referenced this pull request May 14, 2024