
[Wav2Vec2] Fix padding mask for new architectures #3228

Conversation

@patrickvonplaten (Contributor) commented on Feb 10, 2021

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #3227

All models that do not make use of group norm, such as

  • Wav2Vec 2.0 Large (LV-60)*
  • Wav2Vec 2.0 Large (LV-60) + Self Training *

do need this fix IMO to be able to correctly run batches through the model. Before this PR, the
following code snippet failed:

```python
import fairseq
import torch

# get model
wav2vec_path = "data/wav2vec2_vox_960h_new.pt"
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [wav2vec_path], arg_overrides={"data": "./data"}
)
model = model[0]
model.eval()

# create single input
input_wav_0 = torch.randn((1, 2000))
input_wav_1 = torch.randn((1, 3000))

# create batched input
batch_input_wav = torch.zeros((2, 3000))
batch_input_wav[0, :input_wav_0.shape[-1]] = input_wav_0
batch_input_wav[1, :input_wav_1.shape[-1]] = input_wav_1

# create padding mask
padding_mask = torch.zeros((2, 3000), dtype=torch.bool)
padding_mask[0, input_wav_0.shape[-1]:] = True

# run batch & single
output = model(source=input_wav_0, padding_mask=None)["encoder_out"]
batch_output = model(source=batch_input_wav, padding_mask=padding_mask)["encoder_out"]

# is equal?
print("Is batched forward and simple forward equal?", torch.allclose(output[:, 0], batch_output[:output.shape[0], 0], atol=1e-3))
```

Note: It is assumed that both https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt and https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt were downloaded and stored in the folder data.
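
For convenience, a minimal download sketch for those two files (an illustrative helper, assuming the URLs above are still reachable):

```python
import os
import urllib.request

# fetch the dictionary and checkpoint into ./data if they are not there yet
os.makedirs("data", exist_ok=True)
for url in [
    "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt",
    "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt",
]:
    dest = os.path.join("data", os.path.basename(url))
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)
```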

Also, see [this notebook](https://colab.research.google.com/drive/1ASZ4lVZbKkj-dvRHDl1lo0mCcsaOERlG?usp=sharing) for reproducibility.

This PR should fix the behavior and make the above code snippet / notebook run successfully.
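
For context, the essence of the fix (a rough sketch, not the exact diff; the helper names and the conv configuration below are assumptions based on the default wav2vec2 feature extractor) is that the sample-level padding mask has to be downsampled to the feature extractor's output resolution by applying the 1D-convolution length formula once per conv layer, rather than being sliced with a fixed stride:

```python
# Rough sketch of the idea behind the fix (names and config are illustrative, not the actual patch).

def _conv_out_length(input_length, kernel_size, stride):
    # standard 1D conv output length with no padding and dilation 1
    return (input_length - kernel_size) // stride + 1

# default wav2vec2 feature extractor: (dim, kernel, stride) per conv layer
CONV_LAYERS = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

def feat_extract_output_length(num_samples):
    length = num_samples
    for _, kernel, stride in CONV_LAYERS:
        length = _conv_out_length(length, kernel, stride)
    return length

# e.g. for the 2000-sample input padded to 3000 samples, everything past
# feat_extract_output_length(2000) frames should be marked as padding.
```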

PR review

Gently pinging @alexeib for Wav2Vec2

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@patrickvonplaten changed the title from "fix padding mask for wav2vec2" to "[Wav2Vec2] Fix padding mask for new architectures" on Feb 10, 2021
@alexeib (Contributor) left a comment:
thanks, looks good!
overall it probably shouldn't make much difference to the transcriptions that end up being output, but still a good fix.

(Review thread on fairseq/models/wav2vec/wav2vec2.py, resolved)
@facebook-github-bot (Contributor) left a comment:

@alexeib has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@alexeib merged this pull request in 4fed0be.

harkash pushed a commit to harkash/fairseq that referenced this pull request Feb 23, 2021
Summary: same as the PR description above.

Pull Request resolved: facebookresearch#3228

Reviewed By: aconneau

Differential Revision: D26373721

Pulled By: alexeib

fbshipit-source-id: 3d5aca2f8136d1a8c4b5b4bc9c03cd05a69a3b52
jinyiyang-jhu pushed a commit to jinyiyang-jhu/fairseq-jyang that referenced this pull request Feb 26, 2021
Summary: same as the PR description above.
@jeffxtang commented:

Will this fix resolve the issue #3278? Will the code snippet there be affected by this fix? I'm not sure if fairseq is required to run the code:

```python
import soundfile as sf
import torch
from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Tokenizer

tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()
audio_input, _ = sf.read("ionlywishtobealone.wav")
input_values = tokenizer(audio_input, return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids)[0]
print(transcription)
print(input_values.shape)

# pad the 1.42s audio to about 3,5,7,9s:
for n in [3, 5, 7, 9]:
    input_values = tokenizer(audio_input, return_tensors="pt", padding="max_length", max_length=n*10000).input_values
    logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)[0]
    print(input_values.shape)
    print("padding to {}s: {}\n".format(n, transcription))
```

@jeffxtang commented:

I just uninstalled fairseq and the code above still runs. Any ideas how I can apply the fix to the issue above, if the fix can resolve it?

@patrickvonplaten (Contributor, Author) commented:

As far as I understand it, this fix only worked for the lv60 models and not for the base models. The base models cannot really give the same results for padding vs. non-padding, I think.
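
A quick way to see which category a given checkpoint falls into is to inspect its feature-extractor normalization. This is a small sketch; the `feat_extract_norm` attribute and the model names are taken from the transformers Wav2Vec2 config as I understand it ("group" for base-style extractors, "layer" for lv60-style ones):

```python
from transformers import Wav2Vec2Config

# print the feature-extractor normalization used by each checkpoint
for name in ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60-self"]:
    cfg = Wav2Vec2Config.from_pretrained(name)
    print(name, "feat_extract_norm =", cfg.feat_extract_norm)
```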

Successfully merging this pull request may close these issues:

[Wav2Vec2] Possible bug with batched input (#3227)