Commit
Fix padding mask for new architectures (#3228)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?

Fixes #3227

All models that do **not** make use of group norm, such as
- Wav2Vec 2.0 Large (LV-60)*
- Wav2Vec 2.0 Large (LV-60) + Self Training *

need this fix, in my opinion, to be able to correctly run batches through the model. Before this PR, the following code snippet failed:

```python
import fairseq
import torch

# get model
wav2vec_path = "data/wav2vec2_vox_960h_new.pt"
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [wav2vec_path], arg_overrides={"data": "./data"}
)
model = model[0]
model.eval()

# create single inputs
input_wav_0 = torch.randn((1, 2000))
input_wav_1 = torch.randn((1, 3000))

# create batched input (the shorter waveform is zero-padded on the right)
batch_input_wav = torch.zeros((2, 3000))
batch_input_wav[0, :input_wav_0.shape[-1]] = input_wav_0
batch_input_wav[1, :input_wav_1.shape[-1]] = input_wav_1

# create padding mask (True marks padded samples)
padding_mask = torch.zeros((2, 3000), dtype=torch.bool)
padding_mask[0, input_wav_0.shape[-1]:] = True

# run batch & single
output = model(source=input_wav_0, padding_mask=None)["encoder_out"]
batch_output = model(source=batch_input_wav, padding_mask=padding_mask)["encoder_out"]

# is equal?
print("Is batched forward and simple forward equal?", torch.allclose(output[:, 0], batch_output[:output.shape[0], 0], atol=1e-3))
```

Note: It is assumed that both https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt and https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt were downloaded and stored in the folder `data`.

Also, see [this](https://colab.research.google.com/drive/1ASZ4lVZbKkj-dvRHDl1lo0mCcsaOERlG?usp=sharing) notebook for reproducibility.
This PR should fix the behavior and make the above code snippet / notebook run successfully.

## PR review

Gently pinging alexeib for Wav2Vec2.
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding

Pull Request resolved: #3228

Reviewed By: aconneau

Differential Revision: D26373721

Pulled By: alexeib

fbshipit-source-id: 3d5aca2f8136d1a8c4b5b4bc9c03cd05a69a3b52
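
For readers unfamiliar with why a sample-level padding mask cannot be handed to the transformer unchanged, the following is a minimal sketch of how such a mask could be mapped to frame level. It is not the code from this PR. The `(kernel, stride)` list is an assumption that mirrors the commonly used wav2vec 2.0 feature-extractor configuration, and `frame_length` and `frame_padding_mask` are hypothetical helper names. Plain Python lists are used instead of tensors so the sketch runs without fairseq or torch installed.

```python
# Assumed (kernel, stride) pairs of the convolutional feature extractor.
# This mirrors the commonly used wav2vec 2.0 setup and is an assumption here.
CONV_LAYERS = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]


def frame_length(num_samples, conv_layers=CONV_LAYERS):
    """Number of output frames of an unpadded 1-D convolution stack."""
    length = num_samples
    for kernel, stride in conv_layers:
        # Standard output-length formula for a convolution without padding.
        length = (length - kernel) // stride + 1
    return length


def frame_padding_mask(sample_lengths, padded_samples, conv_layers=CONV_LAYERS):
    """Build a frame-level mask (True = padding) from per-item sample lengths."""
    total_frames = frame_length(padded_samples, conv_layers)
    mask = []
    for num_samples in sample_lengths:
        valid_frames = frame_length(num_samples, conv_layers)
        # Frames at or beyond the valid length correspond to padded input.
        mask.append([t >= valid_frames for t in range(total_frames)])
    return mask


# Same shapes as the snippet in the PR description: 2000 and 3000 samples,
# both padded to 3000 samples.
mask = frame_padding_mask([2000, 3000], padded_samples=3000)
print(len(mask[0]), sum(mask[0]), sum(mask[1]))
```

Under these assumed layer settings the 3000-sample batch yields far fewer frames than samples, so a mask of shape `(2, 3000)` has to be reduced to the frame dimension before it can mark which encoder positions are padding.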