Hard error when ignoring tensors. (#27484) #29906
Conversation
* [WIP] Hard error when ignoring tensors.
* Better selection/error when saving a checkpoint.
  - Find all names we should normally drop (those are in the transformers config).
  - Find all disjoint tensors (for those we can safely trigger a copy to get rid of the sharing before saving).
  - Clone those disjoint tensors, getting rid of the issue.
  - Find all identical names (those should be declared in the config, but we try to find them all anyway).
  - For all identical names:
    - If they are in the config, just ignore them; everything is fine.
    - If they are not, warn about them.
  - For all remaining tensors which are shared yet neither identical nor disjoint, raise a hard error.
* Adding a failing test on `main` that passes here.
* We don't need to keep the subfolder logic in this test.
* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
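To make the disjoint/identical distinction concrete, here is a minimal sketch (not the actual transformers implementation; `classify_shared` and the example names are hypothetical, and it assumes a recent PyTorch with `untyped_storage()`) of how shared tensors can be grouped by storage and classified before saving:

```python
import torch

def classify_shared(tensors: dict):
    # Group tensors by the underlying storage they live on.
    groups = {}
    for name, t in tensors.items():
        key = (t.device, t.untyped_storage().data_ptr())
        groups.setdefault(key, []).append(name)

    for names in groups.values():
        if len(names) < 2:
            continue  # not shared, nothing to do
        first = tensors[names[0]]
        if all(
            tensors[n].data_ptr() == first.data_ptr() and tensors[n].shape == first.shape
            for n in names
        ):
            # Identical views: classic tied weights, expected in the config.
            print("identical (tied):", names)
        else:
            # Same storage but different views; disjoint slices can be
            # cloned apart safely, overlapping ones are the hard-error case.
            print("same storage, different views:", names)

emb = torch.nn.Embedding(10, 4)
lm_head = torch.nn.Linear(4, 10, bias=False)
lm_head.weight = emb.weight  # tie the weights
classify_shared({"embed.weight": emb.weight, "lm_head.weight": lm_head.weight})
```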
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
@@ -1128,7 +1127,7 @@ def forward(
     """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
 )
 class BertLMHeadModel(BertPreTrainedModel):
-    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
```
Seems like this was a bug: `predictions` does not exist on this model, only `cls.predictions`.
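As a side note, a quick sanity check along these lines (hypothetical helper, not the PR's actual test) would have caught the typo, since every tied-weights key should resolve to an entry in the model's `state_dict`:

```python
import torch.nn as nn

def check_tied_weight_keys(model: nn.Module, keys):
    # Every key in _tied_weights_keys should name a real state_dict entry;
    # the old "predictions.decoder.bias" would show up as MISSING.
    state = model.state_dict()
    for key in keys:
        print(f"{key}: {'ok' if key in state else 'MISSING'}")

# Usage sketch: check_tied_weight_keys(model, model._tied_weights_keys or [])
```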
```diff
@@ -1667,15 +1742,19 @@ def tie_encoder_to_decoder_recursively(
     module_name: str,
     uninitialized_encoder_weights: List[str],
     depth=0,
+    total_decoder_name="",
```
This is important since `module_name` is a generic name, and `encoder_name` and `decoder_name` can differ (when there's an ignored `cross_attn` layer in the tying).
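For illustration, a minimal sketch of the accumulator pattern (hypothetical function, not the PR's code): each recursion level only sees its local child name, so the full dotted path has to be threaded through explicitly:

```python
import torch.nn as nn

def collect_full_names(module: nn.Module, total_name: str = "", out=None):
    # total_name plays the role of total_decoder_name: it accumulates the
    # dotted path that no single recursion level can reconstruct on its own.
    if out is None:
        out = []
    for child_name, child in module.named_children():
        full = f"{total_name}.{child_name}" if total_name else child_name
        out.append(full)
        collect_full_names(child, total_name=full, out=out)
    return out

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
print(collect_full_names(model))  # ['0', '1', '1.0']
```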
As discussed offline, now that the name of the encoder is passed, LGTM.
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Very nice handling of something quite tricky - thanks for adding tests! ❤️
Only concern is that we are still vulnerable to `_tied_weights` being modified after instance creation, but I don't really see an easy way to prevent this other than giving warnings here.
```python
total_encoder_name=f"{total_encoder_name}.{encoder_name}",
total_decoder_name=f"{total_decoder_name}.{decoder_name}",
```
Here - do we want to account for when the string is empty?
Suggested change:

```diff
-total_encoder_name=f"{total_encoder_name}.{encoder_name}",
-total_decoder_name=f"{total_decoder_name}.{decoder_name}",
+total_encoder_name=f"{total_encoder_name}.{encoder_name}" if total_encoder_name else encoder_name,
+total_decoder_name=f"{total_decoder_name}.{decoder_name}" if total_decoder_name else decoder_name,
```
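For context, a quick standalone demonstration (example values, not the PR's code) of what the guard changes when the accumulator starts out empty:

```python
total_encoder_name, encoder_name = "", "encoder"

# Without the guard, the very first level picks up a leading dot:
print(f"{total_encoder_name}.{encoder_name}")  # -> ".encoder"

# With the suggested guard, the name stays clean:
print(f"{total_encoder_name}.{encoder_name}" if total_encoder_name else encoder_name)  # -> "encoder"
```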
```python
):
    assert isinstance(decoder_pointer, nn.Module) and isinstance(
        encoder_pointer, nn.Module
    ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
    if hasattr(decoder_pointer, "weight"):
        assert hasattr(encoder_pointer, "weight")
        encoder_pointer.weight = decoder_pointer.weight
        tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
```
(not sure at all) but should there be a dot here between the names?
Suggested change:

```diff
-tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
+tied_weights.append(f"{base_encoder_name}.{total_encoder_name}.weight")
```
No, the encoder name already has the leading dot from the way the recursive calls are made. Forcing it here would mean adding extra logic in the recursive descent. I can do it to make the code more readable (but in general, in such complex code, I don't like adding too many ifs, especially on dependent variables in recursive calls).
Agreed - I'd rather have no if statements.
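To make the convention concrete, a small standalone illustration (hypothetical values) of why the extra dot would double up:

```python
base_encoder_name = "encoder"
total_encoder_name = ".block.0.layer.1"  # leading dot comes from the recursion

# Current code: concatenation without an extra dot.
print(f"{base_encoder_name}{total_encoder_name}.weight")
# -> encoder.block.0.layer.1.weight

# With the suggested extra dot, the name would be malformed:
print(f"{base_encoder_name}.{total_encoder_name}.weight")
# -> encoder..block.0.layer.1.weight
```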
```python
if hasattr(decoder_pointer, "bias"):
    assert hasattr(encoder_pointer, "bias")
    tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
```
and possibly here?
Suggested change:

```diff
-tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
+tied_weights.append(f"{base_encoder_name}.{total_encoder_name}.bias")
```
Those are private, therefore it should be OK. You could make them immutable through `@property`, but that seems a bit too much at this point.
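For reference, a minimal sketch of the `@property` idea (hypothetical class, not the transformers base class), which would make reassignment after instance creation fail loudly:

```python
class ModelSketch:
    _tied_keys = ("cls.predictions.decoder.bias", "cls.predictions.decoder.weight")

    @property
    def _tied_weights_keys(self):
        # Read-only view: a property with no setter blocks reassignment,
        # and returning a copy prevents in-place mutation of the list.
        return list(self._tied_keys)

m = ModelSketch()
print(m._tied_weights_keys)
try:
    m._tied_weights_keys = []  # a property without a setter rejects this
except AttributeError:
    print("reassignment blocked")
```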
* Hard error when ignoring tensors. (#27484)
  * [WIP] Hard error when ignoring tensors.
  * Better selection/error when saving a checkpoint.
    - Find all names we should normally drop (those are in the transformers config).
    - Find all disjoint tensors (for those we can safely trigger a copy to get rid of the sharing before saving).
    - Clone those disjoint tensors, getting rid of the issue.
    - Find all identical names (those should be declared in the config, but we try to find them all anyway).
    - For all identical names:
      - If they are in the config, just ignore them; everything is fine.
      - If they are not, warn about them.
    - For all remaining tensors which are shared yet neither identical nor disjoint, raise a hard error.
  * Adding a failing test on `main` that passes here.
  * We don't need to keep the subfolder logic in this test.
  * Apply suggestions from code review
* Add small tests.
* Dead variable.
* Fixup.
* Fixing tied_weights_keys on generic models.
* Fixup + T5 encoder/decoder tying (with different layers)
* Code quality.
* Dynamic member.
* trigger
* Fixing encoder name for other types of encoder/decoder combos.
* Fix scoping.
* Update .github/workflows/self-scheduled.yml
* Fixing the tied_weights after the call.

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Should fix #29903, fixes #28293