[Wav2Vec2] Fix padding mask for new architectures #3228
Conversation
Thanks, looks good!
Overall it probably shouldn't make much difference to the transcriptions that end up being output, but it's still a good fix.
@alexeib has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Pull Request resolved: facebookresearch/fairseq#3228
Reviewed By: aconneau
Differential Revision: D26373721
Pulled By: alexeib
fbshipit-source-id: 3d5aca2f8136d1a8c4b5b4bc9c03cd05a69a3b52
Will this fix resolve issue #3278? Will the code snippet there be affected by this fix? I'm not sure if fairseq is required to run the code:
I just uninstalled fairseq and the code above still runs. Any ideas how I can apply the fix to the issue above, if the fix can resolve it?
As far as I understand it, this fix only worked for the …
Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?
What does this PR do?
Fixes #3227
All models that do **not** make use of group norm, such as

- Wav2Vec 2.0 Large (LV-60)*
- Wav2Vec 2.0 Large (LV-60) + Self Training*

do need this fix IMO to be able to correctly run batches through the model. Before this PR, the following code snippet failed:
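```python
import fairseq
import torch

# load the model checkpoint
wav2vec_path = "data/wav2vec2_vox_960h_new.pt"
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [wav2vec_path], arg_overrides={"data": "./data"}
)
model = model[0]
model.eval()

# create single inputs
input_wav_0 = torch.randn((1, 2000))
input_wav_1 = torch.randn((1, 3000))

# create batched input by zero-padding the shorter example
batch_input_wav = torch.zeros((2, 3000))
batch_input_wav[0, :input_wav_0.shape[-1]] = input_wav_0
batch_input_wav[1, :input_wav_1.shape[-1]] = input_wav_1

# create padding mask (True marks padded positions)
padding_mask = torch.zeros((2, 3000), dtype=torch.bool)
padding_mask[0, input_wav_0.shape[-1]:] = True

# run batched & single forward passes
output = model(source=input_wav_0, padding_mask=None)["encoder_out"]
batch_output = model(source=batch_input_wav, padding_mask=padding_mask)["encoder_out"]

# is equal?
print(
    "Is batched forward and simple forward equal?",
    torch.allclose(output[:, 0], batch_output[:output.shape[0], 0], atol=1e-3),
)
```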
Note: It is assumed that both https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt and https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt were downloaded and stored in the folder `data`.
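If the files are not in place yet, here is a minimal sketch for fetching them into a local `data` folder (plain Python; the folder layout simply matches what the snippet above expects):

```python
import urllib.request
from pathlib import Path

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

for url in (
    "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt",
    "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt",
):
    target = data_dir / url.rsplit("/", 1)[-1]
    if not target.exists():  # skip files that are already downloaded
        urllib.request.urlretrieve(url, target)
```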
Also, see [this notebook](https://colab.research.google.com/drive/1ASZ4lVZbKkj-dvRHDl1lo0mCcsaOERlG?usp=sharing) for reproducibility.
This PR should fix the behavior and make the above code snippet / notebook run successfully.
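For readers who want the gist of the change: the sample-level padding mask cannot be reused as-is after the convolutional feature extractor; it has to be downsampled to the frame resolution of the extractor's output. The sketch below only illustrates that idea; the layer configuration (mirroring the default wav2vec 2.0 extractor) and the helper names are my assumptions, not the PR's actual diff:

```python
import torch

# (dim, kernel, stride) per conv layer -- assumed default wav2vec 2.0 extractor
CONV_LAYERS = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

def conv_out_length(length: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    # standard 1d conv output length (no padding, dilation 1)
    return torch.div(length - kernel_size, stride, rounding_mode="floor") + 1

def downsample_padding_mask(padding_mask: torch.Tensor, num_frames: int) -> torch.Tensor:
    # padding_mask: (batch, num_samples), True marks padded samples
    lengths = (~padding_mask).sum(-1)  # real (non-padded) samples per example
    for _, kernel, stride in CONV_LAYERS:
        lengths = conv_out_length(lengths, kernel, stride)
    frame_mask = torch.zeros(padding_mask.size(0), num_frames, dtype=torch.bool)
    for i, n in enumerate(lengths.tolist()):
        frame_mask[i, n:] = True  # everything past the real frames is padding
    return frame_mask

# for the batch above: example 0 has 2000 real samples out of 3000
mask = torch.zeros(2, 3000, dtype=torch.bool)
mask[0, 2000:] = True
print(downsample_padding_mask(mask, num_frames=9))
# example 0 is padded from frame 6 onwards; example 1 has no padded frames
```

With this configuration, 3000 samples map to 9 frames and 2000 samples to 6, so a frame-level mask derived by naively slicing the sample-level mask would not line up with the features, which is roughly why the batched and single-example forward passes disagreed.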
PR review
Gently pinging @alexeib for Wav2Vec2
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃