[Wav2Vec2 - MMS] Correct directly loading adapters weights #24335
Conversation
The documentation is not available anymore as the PR was closed or merged.
src/transformers/modeling_utils.py
Outdated
Load adaptive weights after state dict has been loaded. If required this method should be overridden by derived
class.
"""
pass
I'd raise an exception here if it's called, otherwise it fails silently
Suggested change:
-    pass
+    raise NotImplementedError
Yeah, not 100% sure about the design here, but it seems much more in line with:

transformers/src/transformers/modeling_utils.py (line 1257 in c5454eb):
def tie_weights(self):

and transformers/src/transformers/modeling_utils.py (line 1242 in c5454eb):
def _init_weights(self, module):

Think something like if hasattr(self, "load_adaptive_weights") is also not great.

Also it is a bit questionable whether load_adaptive_weights is general enough to warrant being in modeling_utils.py, but there is no other way really to get the from_pretrained(..., target_lang="...") functionality.

@LysandreJik @sgugger wdyt?
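For illustration, a minimal sketch of the pattern being compared to here (all names are illustrative, not a proposed final API):

```python
# Minimal sketch of the design question: make `load_adaptive_weights` an
# overridable no-op on the base class (like `tie_weights` / `_init_weights`),
# rather than guarding the call with `hasattr`. All names are illustrative.
class BaseModelSketch:
    def load_adaptive_weights(self):
        # No-op by default; adapter-capable models override it. Per the review
        # suggestion above, this could `raise NotImplementedError` instead so
        # a missing override fails loudly rather than silently.
        pass


class AdapterModelSketch(BaseModelSketch):
    def load_adaptive_weights(self):
        # Model-specific override, called by `from_pretrained` after the
        # state dict has been loaded.
        print("loading adapter weights for the configured target language")
```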
For now I would override the _tie_encoder_decoder_weights method in Wav2Vec2 only and not add those changes in modeling_utils. If things change and we get lots of models with adapters, we can revisit the decision and do something like this, but I'd wait for it to be necessary.
That's pretty hacky, but yeah ok for me. Will add a big comment that we do this to not introduce a new API
Sorry, I accidentally submitted the review without a saved comment, I realised. P.S. ignoring the wandb diffs, as they're just from being out-of-date with main.
examples/research_projects/jax-projects/big_bird/bigbird_flax.py
Outdated
src/transformers/modeling_utils.py
Outdated
Load adaptive weights after state dict has been loaded. If required this method should be overridden by derived
class.
"""
pass
For now I would override the _tie_encoder_decoder_weights method in Wav2Vec2 only and not add those changes in modeling_utils. If things change and we get lots of models with adapters, we can revisit the decision and do something like this, but I'd wait for it to be necessary.
While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so that it is ok to repurpose this
function here.

This method is **not** supposed to be called by the user and is prone to be changed in the future.
Mmmm, tie_weights is a public API of PreTrainedModel and we do recommend users call it in Accelerate (if they load the model manually instead of using from_pretrained).
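For context, the manual-loading flow being referred to looks roughly like this (a sketch of the Accelerate big-model pattern; the repo id and checkpoint path are placeholders):

```python
# Sketch of the Accelerate big-model-inference flow where `tie_weights` is
# called manually. The repo id and checkpoint path below are placeholders.
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("some-org/some-model")  # placeholder id
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Public API: tie input/output embeddings before loading the checkpoint,
# since `from_pretrained` is not doing it for us in this flow.
model.tie_weights()

model = load_checkpoint_and_dispatch(model, "path/to/checkpoint", device_map="auto")
```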
Overwriting _tie_encoder_decoder_weights doesn't really work though, as Wav2Vec2 is not an encoder-decoder model.
I don't want to force config.is_encoder_decoder=True for MMS models - that is really not correct.
That is not what I'm saying. I'm just commenting on this line that will appear in the docstring.
As per the comment, I think it's fine though because Wav2Vec2 will never tie input & output weights. So even if the user calls it, it's ok.
Perhaps moving the comment to the docstring would help clarify this?
Note that with the docs set up as they are, this function won't be present in the documentation, so we are actually debating for nothing. 😅 (only forward is set in the methods to document)
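To make the repurposed hook concrete, here is a rough sketch of the idea settled on in this thread (the attribute name and the adapter-loading call are illustrative, not the exact merged code):

```python
# Sketch only: repurpose `tie_weights` in Wav2Vec2 instead of adding a new
# hook to modeling_utils. `tie_weights` is called by `from_pretrained` after
# the state dict has been loaded, which is exactly when adapter weights need
# to be (re)loaded, and Wav2Vec2 never ties input and output embeddings.
from transformers import Wav2Vec2PreTrainedModel


class Wav2Vec2ForCTCSketch(Wav2Vec2PreTrainedModel):
    def tie_weights(self):
        # The `target_lang` attribute and `load_adapter` call below are
        # illustrative of the idea; see the PR diff for the real code.
        target_lang = getattr(self, "target_lang", None)
        if target_lang is not None:
            self.load_adapter(target_lang)
```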
Thanks for the fixes @patrickvonplaten. The test is convincing for me and I don't mind about the design of overriding .tie_weights since it won't be used anyway for W2V2 - will let you finalise this with Sylvain! Overall, I much prefer this design to loading weights in the __init__.
LGTM - thanks for fixing and iterating!
logits_2 = get_logits(model_2, input_features)

self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
tbh, I'm surprised the tolerance is so high given we're loading the same weights into the same model 👀
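For reference, the equivalence being tested is roughly of this shape (a standalone sketch modeled on the quoted diff; the helper, checkpoint id and language code are assumptions rather than the exact test):

```python
# Sketch of the equivalence check from the quoted test: load the same adapter
# two ways and compare logits. The checkpoint id, language code and helper
# below are modeled on the discussion, not copied from the actual test.
import torch
from transformers import Wav2Vec2ForCTC


def get_logits(model, input_features):
    model.eval()
    with torch.no_grad():
        return model(input_features).logits


# Path 1: load the adapter directly via `from_pretrained` (fixed by this PR).
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/mms-1b-all", target_lang="fra", ignore_mismatched_sizes=True
)
# Path 2: load the default checkpoint, then switch adapters.
model_2 = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")
model_2.load_adapter("fra")

input_features = torch.randn(1, 16000)  # dummy 1-second waveform at 16 kHz
logits = get_logits(model, input_features)
logits_2 = get_logits(model_2, input_features)

assert torch.allclose(logits, logits_2, atol=1e-3)
```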
What does this PR do?

This PR corrects incorrect behavior when loading MMS with non-default adapter weights via from_pretrained(...). The issue is explained well here.

In a nutshell, we cannot load specific weights in the init because these loaded weights are later overwritten again in from_pretrained. To solve this I propose to add a new generic load_adaptive_weights() call to from_pretrained that can be overridden by models that inherit from PreTrainedModel. This both solves issue #24223 and is also cleaner IMO, since weights shouldn't really be loaded when calling the __init__ method of a model anyway. It was weird before that:

... would try to load weights into the model.
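The loading-order problem described above can be sketched as follows (a plain-Python illustration; the function and names are not the actual transformers internals):

```python
# Illustration of why loading adapter weights in `__init__` cannot work:
# `from_pretrained` first instantiates the model (running `__init__`) and
# only afterwards copies the checkpoint's state dict into it, so anything
# loaded inside `__init__` gets overwritten. All names are illustrative.
def from_pretrained_flow(model_cls, config, state_dict, target_lang=None):
    model = model_cls(config)  # __init__ runs here: must NOT load weights yet
    model.load_state_dict(state_dict, strict=False)  # would clobber them anyway
    if target_lang is not None:
        # Load the adapter weights *after* the state dict, e.g. via an
        # overridable hook (the repurposed `tie_weights` in the final design).
        model.load_adapter(target_lang)
    return model
```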
cc @sgugger @sanchit-gandhi @amyeroberts wdyt about the design? Happy to add some more tests if ok for you