Fix padding mask for new architectures (#3228)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes #3227

All models that do **not** make use of group norm, such as
- Wav2Vec 2.0 Large (LV-60)*
- Wav2Vec 2.0 Large (LV-60) + Self Training*

do need this fix IMO to be able to correctly run batches through the model. Before this PR, the
following code snippet failed:

```python
import fairseq
import torch

# get model
wav2vec_path = "data/wav2vec2_vox_960h_new.pt"
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [wav2vec_path], arg_overrides={"data": "./data"}
)
model = model[0]
model.eval()

# create two single inputs of different lengths
input_wav_0 = torch.randn((1, 2000))
input_wav_1 = torch.randn((1, 3000))

# create batched input
batch_input_wav = torch.zeros((2, 3000))
batch_input_wav[0, :input_wav_0.shape[-1]] = input_wav_0
batch_input_wav[1, :input_wav_1.shape[-1]] = input_wav_1

# create padding mask (only the shorter input needs padding)
padding_mask = torch.zeros((2, 3000), dtype=torch.bool)
padding_mask[0, input_wav_0.shape[-1]:] = True

# run batch & single
output = model(source=input_wav_0, padding_mask=None)["encoder_out"]
batch_output = model(source=batch_input_wav, padding_mask=padding_mask)["encoder_out"]

# is equal?
print("Is batched forward and simple forward equal?", torch.allclose(output[:,0], batch_output[:output.shape[0], 0], atol=1e-3))
```
Note: it is assumed that both https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt and https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt have been downloaded and stored in the `data` folder.
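
For convenience, a minimal fetch sketch (the file names and the `./data` target directory are just the assumptions stated above):

```python
import pathlib
import urllib.request

# download the dictionary and checkpoint referenced above into ./data
pathlib.Path("data").mkdir(exist_ok=True)
for url in [
    "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt",
    "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt",
]:
    urllib.request.urlretrieve(url, "data/" + url.rsplit("/", 1)[-1])
```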

Also, see [this](https://colab.research.google.com/drive/1ASZ4lVZbKkj-dvRHDl1lo0mCcsaOERlG?usp=sharing) notebook for reproducibility.

This PR should fix the behavior and make the above code snippet / notebook run successfully.

## PR review

Gently pinging alexeib for Wav2Vec2

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: #3228

Reviewed By: aconneau

Differential Revision: D26373721

Pulled By: alexeib

fbshipit-source-id: 3d5aca2f8136d1a8c4b5b4bc9c03cd05a69a3b52
patrickvonplaten authored and facebook-github-bot committed Feb 10, 2021
1 parent 3aeb8fe commit 4fed0be
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions fairseq/models/wav2vec/wav2vec2.py
```diff
@@ -443,6 +443,21 @@ def compute_preds(self, x, y, negatives):
 
         return logits
 
+    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            return torch.floor((input_length - kernel_size) / stride + 1)
+
+        conv_cfg_list = eval(self.cfg.conv_feature_layers)
+
+        for i in range(len(conv_cfg_list)):
+            input_lengths = _conv_out_length(input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2])
+
+        return input_lengths.to(torch.long)
+
     def forward(self, source, padding_mask=None, mask=True, features_only=False):
 
         if self.feature_grad_mult > 0:
```
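
The new helper applies the standard Conv1d length formula, `floor((length - kernel_size) / stride + 1)`, once per layer of the feature extractor. A standalone sketch of that computation, assuming the default wav2vec 2.0 `conv_feature_layers` string `[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2`:

```python
# minimal sketch of the conv output-length computation, assuming the
# default wav2vec 2.0 feature extractor: (dim, kernel_size, stride) per layer
conv_feature_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

def feat_extract_output_length(input_length: int) -> int:
    # floor((L - kernel_size) / stride + 1), applied layer by layer
    for _dim, kernel_size, stride in conv_feature_layers:
        input_length = (input_length - kernel_size) // stride + 1
    return input_length

print(feat_extract_output_length(2000))  # 6
print(feat_extract_output_length(3000))  # 9
```

Under these assumptions, the 2000- and 3000-sample inputs from the snippet above reduce to 6 and 9 feature frames respectively, which is the resolution at which the padding mask must be expressed.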

```diff
@@ -460,11 +475,18 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
         unmasked_features = features.clone()
 
         if padding_mask is not None:
-            extra = padding_mask.size(1) % features.size(1)
-            if extra > 0:
-                padding_mask = padding_mask[:, :-extra]
-            padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
-            padding_mask = padding_mask.all(-1)
+            input_lengths = (1 - padding_mask.long()).sum(-1)
+            # apply conv formula to get real output_lengths
+            output_lengths = self._get_feat_extract_output_lengths(input_lengths)
+
+            padding_mask = torch.zeros(
+                features.shape[:2], dtype=features.dtype, device=features.device
+            )
+
+            # these two operations make sure that all values
+            # before the output length indices are attended to
+            padding_mask[(torch.arange(padding_mask.shape[0], device=padding_mask.device), output_lengths - 1)] = 1
+            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
 
         if self.post_extract_proj is not None:
             features = self.post_extract_proj(features)
```
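
The last two added lines replace the old reshape-based downsampling of the mask. To see the trick in isolation: writing a 1 at index `output_length - 1` and taking a reversed cumulative sum marks every frame up to that index as attended and everything after it as padding. A toy demo, using the 6- and 9-frame lengths computed in the sketch above:

```python
import torch

# demo of the flip/cumsum mask trick on a (batch=2, frames=9) grid
output_lengths = torch.tensor([6, 9])
mask = torch.zeros((2, 9))
mask[(torch.arange(2), output_lengths - 1)] = 1            # mark last valid frame
padding_mask = (1 - mask.flip([-1]).cumsum(-1).flip([-1])).bool()
print(padding_mask)
# tensor([[False, False, False, False, False, False,  True,  True,  True],
#         [False, False, False, False, False, False, False, False, False]])
```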
