
[Wav2Vec2 - MMS] Correct directly loading adapters weights #24335

Merged: 7 commits from correct_direct_lang_loading into main, Jun 20, 2023

Conversation

patrickvonplaten (Contributor):

What does this PR do?

This PR corrects incorrect behavior when loading MMS with non-default adapter weights via from_pretrained(...). The issue is explained well here.

In a nutshell, we cannot load specific weights in __init__ because these weights are later overwritten again in from_pretrained. To solve this, I propose adding a new generic

load_adaptive_weights()

call to from_pretrained that can be overridden by models that inherit from PreTrainedModel. This both solves issue #24223
and is also cleaner IMO, since weights shouldn't be loaded when calling the __init__ method of a model anyway. It was weird before that:

model = Wav2Vec2ForCTC(config, target_lang="fra")

would try to load weights into the model.

cc @sgugger @sanchit-gandhi @amyeroberts wdyt about the design? Happy to add some more tests if ok for you
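To make the proposed flow concrete, here is a minimal sketch of the idea, not the actual transformers code: the hook name load_adaptive_weights and the target_lang kwarg come from this PR, while the _load_state_dict helper, the print stand-in, and the checkpoint path are assumptions for illustration only.

```python
# Minimal sketch of the proposed design (stand-in names marked below; not the real implementation).
class PreTrainedModel:
    def load_adaptive_weights(self):
        """Called by from_pretrained() after the checkpoint state dict has been loaded.
        No-op by default; models with adapter weights can override it."""
        pass

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        model = cls(**kwargs)          # build the architecture, no weight loading here
        model._load_state_dict(path)   # hypothetical stand-in for loading the main checkpoint
        model.load_adaptive_weights()  # runs *after* the state dict, so nothing overwrites it
        return model


class Wav2Vec2ForCTC(PreTrainedModel):
    def __init__(self, target_lang=None, **kwargs):
        self.target_lang = target_lang  # only remember the request; do NOT load weights in __init__

    def _load_state_dict(self, path):
        pass  # stand-in for the real state-dict loading logic

    def load_adaptive_weights(self):
        if self.target_lang is not None:
            # stand-in for fetching and loading the adapter weights for this language
            print(f"loading adapter weights for {self.target_lang}")


# Usage matching the PR description (checkpoint path is a placeholder):
model = Wav2Vec2ForCTC.from_pretrained("path/to/mms-checkpoint", target_lang="fra")
```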

HuggingFaceDocBuilderDev commented Jun 17, 2023:

The documentation is not available anymore as the PR was closed or merged.

(Review thread on the load_adaptive_weights stub added to modeling_utils.py:)

Load adaptive weights after the state dict has been loaded. If required, this method should be overridden by the derived class.
"""
pass
Collaborator:

I'd raise an exception here if it's called; otherwise it fails silently.

Suggested change:
- pass
+ raise NotImplementedError

patrickvonplaten (author):

Yeah, not 100% sure about the design here, but it seems much more in line with:

I think something like if hasattr(self, "load_adaptive_weights") is also not great.

Also, it is a bit questionable whether load_adaptive_weights is general enough to warrant being in modeling_utils.py, but there is really no other way to support the from_pretrained(..., target_lang="...") functionality.

@LysandreJik @sgugger wdyt?

Collaborator:

For now I would override the _tie_encoder_decoder_weights method in Wav2Vec2 only and not add those changes in modeling_utils. If things change and we get lots of models with adapters, we can revisit the decision and do something like this, but I'd wait for it to be necessary.

patrickvonplaten (author):

That's pretty hacky, but ok for me. I'll add a big comment explaining that we do this to avoid introducing a new API.

amyeroberts (Collaborator):

Sorry, I accidentally submitted the review without a saved comment. I realised from the from_pretrained call why you were using pass. I still think raising an exception would be good, as otherwise we can get silent behaviour. Would it be possible to reliably check whether load_adaptive_weights should be implemented for a model?

P.S. Ignoring the wandb diffs, as they're just from being out of date with main.
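One possible answer to that question, continuing the sketch above (PreTrainedModel here means the sketch class, not the real transformers one, and the helper name is hypothetical): call the hook only when a subclass has actually overridden it, by comparing it against the base-class implementation. This is not what the PR ended up doing.

```python
def _maybe_load_adaptive_weights(model, target_lang=None):
    # Call the hook only if the model class overrides the base-class no-op.
    overrides_hook = (
        type(model).load_adaptive_weights is not PreTrainedModel.load_adaptive_weights
    )
    if overrides_hook:
        model.load_adaptive_weights()
    elif target_lang is not None:
        # Fail loudly instead of silently ignoring the request.
        raise ValueError(f"{type(model).__name__} does not support adapter weights.")
```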


src/transformers/models/wav2vec2/modeling_wav2vec2.py (outdated, resolved):
While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so it is ok to repurpose this function here.

This method is **not** supposed to be called by the user and is prone to be changed in the future.
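For illustration, a rough sketch of what the repurposed hook could look like; the exact body is an assumption, and load_adapter stands in for whatever helper actually fetches the per-language adapter weights.

```python
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel


class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
    def tie_weights(self):
        """Repurposed hook: from_pretrained() calls tie_weights() after the state dict has
        been loaded, so adapter weights loaded here are never overwritten. Wav2Vec2 never
        ties input and output embeddings, so overriding this is safe. Not meant to be
        called by users directly."""
        target_lang = getattr(self.config, "target_lang", None)
        if target_lang is not None:
            # assumed helper that loads the adapter weights for the requested language
            self.load_adapter(target_lang)
```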
Collaborator:

Mmmm, tie_weights is a public API of PreTrainedModel and we do recommend that users call it with Accelerate (if they load the model manually instead of using from_pretrained).
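For context on that recommendation, the manual big-model loading flow in Accelerate looks roughly like this (checkpoint names are placeholders and the model class is just an example, not specific to this PR):

```python
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("some/checkpoint")  # placeholder checkpoint id
with init_empty_weights():
    model = AutoModel.from_config(config)                # architecture only, no weights allocated

model.tie_weights()  # public API: make tied parameters share storage before dispatching
model = load_checkpoint_and_dispatch(model, "path/to/weights", device_map="auto")
```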

patrickvonplaten (author):

Overriding _tie_encoder_decoder_weights doesn't really work though, as Wav2Vec2 is not an encoder-decoder model.

patrickvonplaten (author) commented Jun 20, 2023:

I don't want to force config.is_encoder_decoder=True for MMS models - that is really not correct

Collaborator:

That is not what I'm saying. I'm just commenting on this line that will appear in the docstring.

patrickvonplaten (author):

As per the comment above, I think it's fine though, because Wav2Vec2 will never tie input & output weights. So even if the user calls it, it's ok.

Contributor:

Perhaps moving the comment to the docstring would help clarify this?

sgugger (Collaborator) commented Jun 20, 2023:

Note that with the docs set up as they are, this function won't be present in the documentation, so we are actually debating for nothing. 😅 (only forward is set in the methods to document)

sanchit-gandhi (Contributor) left a comment:

Thanks for the fixes @patrickvonplaten. The test is convincing for me, and I don't mind the design of overriding .tie_weights since it won't be used anyway for W2V2. I'll let you finalise this with Sylvain! Overall, I much prefer this design to loading weights in the __init__.

amyeroberts (Collaborator) left a comment:

LGTM - thanks for fixing and iterating!


(Review thread on the new equivalence test in the diff:)

logits_2 = get_logits(model_2, input_features)

self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
Collaborator:

tbh, I'm surprised the tolerance is so high given we're loading the same weights into the same model 👀
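For context, this is the kind of equivalence the test is checking, sketched with a hypothetical get_logits helper, dummy inputs, and a placeholder checkpoint; the actual test in the PR may differ in details such as the adapter-loading call.

```python
import torch
from transformers import Wav2Vec2ForCTC

def get_logits(model, input_features):
    # hypothetical helper: a forward pass without gradients
    model.eval()
    with torch.no_grad():
        return model(input_features).logits

# Path 1: let from_pretrained load the non-default adapter weights directly.
model = Wav2Vec2ForCTC.from_pretrained("path/to/mms-checkpoint", target_lang="fra")

# Path 2: load the default checkpoint, then swap in the same adapter afterwards
# (load_adapter is assumed here as the explicit adapter-loading entry point).
model_2 = Wav2Vec2ForCTC.from_pretrained("path/to/mms-checkpoint")
model_2.load_adapter("fra")

input_features = torch.randn(1, 16000)  # dummy one-second waveform at 16 kHz
logits = get_logits(model, input_features)
logits_2 = get_logits(model_2, input_features)

# Both paths should end up with identical weights, hence (near-)identical logits.
assert torch.allclose(logits, logits_2, atol=1e-3)
```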

patrickvonplaten merged commit b0513b0 into main on Jun 20, 2023.
patrickvonplaten deleted the correct_direct_lang_loading branch on June 20, 2023 at 17:39.