Wav2vec2 output is affected by zero-padding #2242
Comments
Just to check that the
Which gives
Hi @JackRAHealth Thanks for the report. Let me look into it.
The given analysis seems to be correct, and the proper solution would be to implement normalization that is aware of masking. There are tests for batch consistency, but they only use samples with similar lengths, so this effect was not caught. (audio/test/torchaudio_unittest/models/wav2vec2/model_test.py, lines 139 to 147 in cbf1b83)
We need to update two modules (audio/torchaudio/models/wav2vec2/components.py, lines 520 to 530 in cbf1b83, …).
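A rough sketch of the kind of batch-consistency check that would catch this, with deliberately different lengths in one batch (this is a hypothetical test, not the existing one in model_test.py; the bundle and the lengths are assumptions):

```python
import torch
import torchaudio

# Hypothetical mixed-length batch-consistency check: the padded batch result
# should match the per-sample (unpadded) results over the real frames.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().eval()

lengths = [16000, 9600]  # assumed, deliberately different lengths
waveforms = [torch.randn(n) for n in lengths]

# Per-sample (unpadded) outputs.
singles = [model(w[None])[0] for w in waveforms]

# Batched (zero-padded) outputs.
batch = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
batched, _ = model(batch, lengths=torch.tensor(lengths))

for i, single in enumerate(singles):
    n_frames = single.size(1)
    # With the current GroupNorm behaviour this assertion fails, because the
    # batched (padded) output differs from the unpadded one.
    torch.testing.assert_close(batched[i : i + 1, :n_frames], single, atol=1e-4, rtol=1e-4)
```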
I am un-assigning myself, as this turned out to require more resources than I have at the moment.
good pickup @JackRAHealth, keeping it 55th street
This issue seems very serious, as the underlying problem comes from PyTorch's nn.GroupNorm. As far as I know, Hugging Face's Wav2Vec 2.0 model is also implemented with nn.GroupNorm.
The original fairseq implementation also uses nn.GroupNorm at its core, so this has been an issue from the very beginning. cc @cpuhrsch
cc @jbschlosser, who is the TL for NestedTensor
This problem also exists in the HF implementation:
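A sketch of the same check against Hugging Face transformers (the checkpoint name and the padding length are assumptions, not the commenter's original snippet):

```python
import torch
from transformers import Wav2Vec2Model

# Sketch: compare HF Wav2Vec2 hidden states for the same audio with and
# without trailing zero-padding.
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").eval()

n_samples, n_pad = 16000, 8000  # assumed lengths
waveform = torch.randn(1, n_samples)
padded = torch.nn.functional.pad(waveform, (0, n_pad))
attention_mask = torch.zeros_like(padded, dtype=torch.long)
attention_mask[:, :n_samples] = 1

with torch.no_grad():
    out_unpadded = model(waveform).last_hidden_state
    out_padded = model(padded, attention_mask=attention_mask).last_hidden_state

# Compare only the frames that correspond to the real audio; the difference
# is nonzero because the feature extractor's GroupNorm sees the padding.
n_frames = out_unpadded.size(1)
print((out_unpadded - out_padded[:, :n_frames]).abs().max())
```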
The output will depend on the size of the zero-padding. It follows that the batch size also affects the logits and the loss, and from this it follows that the validation loss will depend on sample order and batch size, which is inconsistent.
Is there any update on this issue? I just spent 6 hours pinpointing this bug in a large codebase.
Does anyone know any workarounds? I am using the MMS_FA bundle for alignment.
🐛 Describe the bug
I've found that the output of the wav2vec2 pipeline model is bugged and changes depending on the zero-padding used in batch preprocessing. A simple example is as follows:
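A rough sketch of the comparison described (the pipeline bundle, the lengths N1 and N2, and the random input are assumptions; the value quoted below comes from the original report, not from this sketch):

```python
import torch
import torchaudio

# Sketch: run the same waveform through the pretrained model with and without
# trailing zero-padding, then sum the absolute difference of the logits.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()

N1 = 16000  # length of the real waveform (assumed: 1 s at 16 kHz)
N2 = 8000   # amount of zero-padding appended (assumed)
waveform = torch.randn(1, N1)

out_unpadded, _ = model(waveform)
out_padded, _ = model(torch.nn.functional.pad(waveform, (0, N2)),
                      lengths=torch.tensor([N1]))

# Compare only the frames that correspond to the real audio.
n_frames = out_unpadded.size(1)
print((out_unpadded - out_padded[:, :n_frames]).abs().sum())
```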
Which gives the output of
tensor(68.1875, grad_fn=<CopyBackwards>)
. Changing the value of N2 will change this value further. I've found the source to be the group-norm layer after the first convolution in the feature extractor, as it applies group norm across the whole sequence irrespective of whether it is padding. To amend this, I've created a masked group-norm function that only applies normalisation across the actual sequence. This can be added to the model by overloading the pre-existing group-norm layer, whilst copying over the group-norm parameters from the pretrained model. This also requires a new forward call for the model.
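A minimal sketch of such a masked group norm (a hypothetical re-implementation of the idea, not the reporter's original code; it assumes input of shape (batch, channels, time) and per-sample valid lengths in frames):

```python
import torch
import torch.nn as nn

class MaskedGroupNorm(nn.Module):
    """GroupNorm that computes statistics only over unpadded time steps.

    Hypothetical sketch of the masked normalisation described above.
    """

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        batch, channels, time = x.shape
        cpg = channels // self.num_groups  # channels per group

        # (batch, 1, 1, time) mask: 1 for real frames, 0 for padding.
        mask = torch.arange(time, device=x.device)[None, :] < lengths[:, None]
        mask = mask.view(batch, 1, 1, time).to(x.dtype)

        # Per-group mean/variance computed over unmasked positions only.
        xg = x.view(batch, self.num_groups, cpg, time)
        count = mask.sum(dim=(2, 3), keepdim=True) * cpg
        mean = (xg * mask).sum(dim=(2, 3), keepdim=True) / count
        var = ((xg - mean) ** 2 * mask).sum(dim=(2, 3), keepdim=True) / count
        xg = (xg - mean) / torch.sqrt(var + self.eps)

        out = xg.view(batch, channels, time)
        out = out * self.weight[None, :, None] + self.bias[None, :, None]
        # Zero the padded positions so downstream layers see the same values
        # as they would for an unpadded input.
        return out * mask.view(batch, 1, time)
```

The pretrained parameters can then be copied over, e.g. masked_norm.weight.data.copy_(orig_norm.weight.data) and likewise for the bias, before swapping the module into the feature extractor and threading the per-sample lengths through a modified forward.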
Which gives the output of
tensor(5.6603e-05, grad_fn=<CopyBackwards>)
Versions
Collecting environment information...
PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (GCC) 10.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.9.7 (default, Sep 16 2021, 13:09:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-97-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2070
Nvidia driver version: 510.47.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.10.0
[pip3] torchaudio==0.10.0
[pip3] torchvision==0.11.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.19.5 pypi_0 pypi
[conda] numpy-base 1.20.3 py39h74d4b33_0
[conda] pytorch 1.10.0 py3.9_cuda11.3_cudnn8.2.0_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 0.10.0 py39_cu113 pytorch
[conda] torchvision 0.11.1 py39_cu113 pytorch