Wav2vec2 output is affected by zero-padding #2242

Open
JackPfizer opened this issue Feb 16, 2022 · 11 comments

Comments

JackPfizer commented Feb 16, 2022

🐛 Describe the bug

I've found that the output of the wav2vec2 pipeline model is buggy: it changes depending on the zero-padding used in batch preprocessing. A simple example is as follows:

import torchaudio as ta, torch
from torch import Tensor, nn
import types
from typing import Optional, Tuple
torch.manual_seed(1)
model = ta.pipelines.WAV2VEC2_BASE.get_model()

N1=11000
dummy_data1 = torch.randn([1,N1])

output1 = model(dummy_data1,lengths=torch.tensor([N1]))

N2=22000
dummy_data2 = torch.randn([1,N2])
dummy_data = torch.zeros([2,N2])
dummy_data[0,:N1] = dummy_data1
dummy_data[1,:N2] = dummy_data2

output2 = model(dummy_data,lengths=torch.tensor([N1,N2]))

frames1 = output1[1][0]
print(torch.norm(output1[0][0,:frames1]-output2[0][0,:frames1]))

This gives tensor(68.1875, grad_fn=<CopyBackwards>), and changing the value of N2 changes this value further. I've found the source to be the group norm layer after the first convolution in the feature extractor: it normalises across the whole sequence, regardless of whether a frame is real audio or padding. To fix this, I've created a masked group norm module that only normalises across the actual sequence.
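A minimal sketch of the mechanism, using a toy feature map rather than the real model (the sizes and the +3.0 offset are arbitrary assumptions): with num_groups equal to num_channels, nn.GroupNorm takes each channel's mean and variance over the whole time axis, so zero-padding changes the output on the real frames.

```python
import torch
from torch import nn

torch.manual_seed(0)
channels, valid, padded = 4, 50, 200

# Per-channel normalization, as configured for the first conv layer
norm = nn.GroupNorm(num_groups=channels, num_channels=channels)

x = torch.randn(1, channels, valid) + 3.0  # nonzero mean, so padding zeros shift the statistics
x_padded = torch.zeros(1, channels, padded)
x_padded[..., :valid] = x                   # same real frames, followed by zero padding

with torch.no_grad():
    out_unpadded = norm(x)
    out_padded = norm(x_padded)[..., :valid]  # compare only the real frames

gap = (out_unpadded - out_padded).abs().max().item()
print(f"max difference on real frames: {gap:.4f}")
```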

def lengths_to_mask(lengths, max_len=None, dtype=None):
    """
    Converts a "lengths" tensor to its binary mask representation.
    
    Based on: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397
     
    :lengths: N-dimensional tensor
    :returns: N*max_len dimensional tensor. If max_len==None, max_len=max(lengths)
    """
    assert len(lengths.shape) == 1, 'Length shape should be 1 dimensional.'
    max_len = max_len or lengths.max().item()
    mask = torch.arange(
        max_len,
        device=lengths.device,
        dtype=lengths.dtype,
    ).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    if dtype is not None:
        mask = torch.as_tensor(mask, dtype=dtype, device=lengths.device)
    return mask
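The same mask can also be built with broadcasting instead of expand. A short sketch; lengths_to_mask_broadcast is a hypothetical name, not a torch API:

```python
import torch

def lengths_to_mask_broadcast(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Boolean mask of shape [batch, max_len]; True marks real (non-padded) frames."""
    positions = torch.arange(max_len, device=lengths.device)  # [max_len]
    return positions[None, :] < lengths[:, None]              # broadcasts to [batch, max_len]

mask = lengths_to_mask_broadcast(torch.tensor([2, 4]), max_len=4)
print(mask)
```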

class MaskedGroupNorm(nn.GroupNorm):
    """
    Masked version of Group normalization.
    
    Based on: https://github.com/ptrblck/pytorch_misc/blob/20e8ea93bd458b88f921a87e2d4001a4eb753a02/batch_norm_manual.py
    
    Receives a N-dim tensor of sequence lengths per batch element
    along with the regular input for masking.
    
    Check pytorch's GroupNorm implementation for argument details.
    """
    def __init__(self, num_groups, num_features, eps=1e-5,affine=True):
        super(MaskedGroupNorm, self).__init__(
            num_groups,
            num_features,
            eps,
            affine
        )

    def forward(self, inp, lengths):
        
        # We transform the mask into a sort of P(inp) with equal probabilities
        # for all unmasked elements of the tensor, and 0 probability for masked
        # ones.
        
        assert inp.shape[1]%self.num_groups == 0, 'Feature size not divisible by groups'

        mask = lengths_to_mask(lengths, max_len=inp.shape[-1], dtype=inp.dtype)
        ave_mask = mask / lengths[:, None] / (inp.shape[-2] / self.num_groups)  # average over time and over the channels in each group
        ave_mask = ave_mask.unsqueeze(1)

        # Here lies the trick. Using Var(X) = E[X^2] - E[X]^2 as the biased
        # variance, we do not need to make any tensor shape manipulation.
        # mean = E[X] is simply the sum-product of our "probability" mask with the input...
        inp = inp*mask.unsqueeze(1) #mask out any extra bits of data - such as those left from conv bleeding
        inp_r = inp.reshape([inp.shape[0],self.num_groups,-1,inp.shape[-1]])
        ave_mask = ave_mask.unsqueeze(2)
        mean = (ave_mask * inp_r).sum([2, 3])
        # ...whereas Var(X) is directly derived from the above formulae
        # This should be numerically equivalent to the biased sample variance
        var = (ave_mask * inp_r ** 2).sum([2, 3]) - mean ** 2

        inp_r = (inp_r - mean[:,:,None,None]) / (torch.sqrt(var[:, :, None, None] + self.eps))
        out = inp_r.reshape(inp.shape)
        if self.affine:
            out = out * self.weight[None, :, None] + self.bias[None, :, None]
        return out * mask.unsqueeze(1)
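The masked statistics can be cross-checked against PyTorch's own group norm. The functional sketch below (masked_group_norm is a hypothetical name, not a torch or torchaudio API) applies the same masking idea, but computes the variance in two passes (mean first, then squared deviations) instead of E[X^2] - E[X]^2, and compares each sample with F.group_norm run on that sample alone, without padding:

```python
import torch
import torch.nn.functional as F

def masked_group_norm(x, lengths, num_groups, weight=None, bias=None, eps=1e-5):
    """Group norm over x: [batch, channels, time] whose statistics ignore frames t >= lengths[b]."""
    batch, channels, time = x.shape
    per_group = channels // num_groups
    mask = (torch.arange(time, device=x.device)[None, :] < lengths[:, None]).to(x.dtype)  # [B, T]
    m = mask[:, None, None, :]                                # [B, 1, 1, T]
    xg = x.reshape(batch, num_groups, per_group, time) * m    # zero whatever sits in padded frames
    count = (lengths.to(x.dtype) * per_group)[:, None]        # real elements per group, [B, 1]
    mean = xg.sum(dim=(2, 3)) / count                         # [B, G]
    centered = (xg - mean[:, :, None, None]) * m              # deviations, padded frames zeroed
    var = (centered ** 2).sum(dim=(2, 3)) / count             # biased variance (second pass)
    out = (centered / torch.sqrt(var[:, :, None, None] + eps)).reshape(batch, channels, time)
    if weight is not None:
        out = out * weight[None, :, None]
    if bias is not None:
        out = out + bias[None, :, None]
    return out * mask[:, None, :]                             # keep padded frames at zero

torch.manual_seed(0)
x = torch.randn(2, 8, 30)        # sample 0: 12 real frames, then arbitrary "padding" values
lengths = torch.tensor([12, 30])
out = masked_group_norm(x, lengths, num_groups=4)
ref0 = F.group_norm(x[:1, :, :12], num_groups=4)  # sample 0 alone, without padding
ref1 = F.group_norm(x[1:], num_groups=4)          # sample 1 spans the full length
print(torch.allclose(out[:1, :, :12], ref0, atol=1e-5), torch.allclose(out[1:], ref1, atol=1e-5))
```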

def masked_conv_forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        A method to override the forward function of wav2vec2's feature_extractor.conv_layers[0],
        as the original behaves differently when there is extra zero padding.
        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
            length (Tensor or None, optional): Shape ``[batch, ]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
            Optional[Tensor]: Shape ``[batch, ]``.
        """
        x = self.conv(x)

        if length is not None:
            length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
            # When input length is 0, the resulting length can be negative. So fix it here.
            length = torch.max(torch.zeros_like(length), length)

        if self.layer_norm is not None:
            if isinstance(self.layer_norm, MaskedGroupNorm):
                x = self.layer_norm(x,length)
            else:
                x = self.layer_norm(x)
        x = nn.functional.gelu(x)

        return x, length
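The length update in masked_conv_forward is floor((length - kernel_size) / stride) + 1, clamped at zero. Chaining it through all conv layers gives the number of real output frames. The sketch assumes the wav2vec 2.0 base layout of (kernel_size, stride) = (10, 5), then (3, 2) four times, then (2, 2) twice; the CONV_LAYERS constant below encodes that assumption.

```python
# Assumed (kernel_size, stride) for the seven conv layers of the wav2vec 2.0 base feature extractor
CONV_LAYERS = [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2

def output_frames(num_samples: int) -> int:
    """Apply floor((length - kernel) / stride) + 1 layer by layer, clamping negatives to zero."""
    length = num_samples
    for kernel, stride in CONV_LAYERS:
        length = max(0, (length - kernel) // stride + 1)
    return length

for n in (11000, 22000):
    print(n, output_frames(n))
```

Under that assumption, N1 = 11000 samples map to 34 frames and N2 = 22000 samples to 68 frames. Because each convolution is unpadded, every frame below the computed length depends only on real input samples.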

This can be added to the model by replacing the existing group norm layer with MaskedGroupNorm, copying over the group norm parameters from the pretrained model. The first conv layer also needs a new forward method so that the lengths reach the norm.

conv0 = model.feature_extractor.conv_layers[0]
old_norm = conv0.layer_norm
new_norm = MaskedGroupNorm(old_norm.num_groups, old_norm.num_channels, old_norm.eps, old_norm.affine).to(old_norm.weight.device)
new_norm.load_state_dict(old_norm.state_dict())  # copy the pretrained weight and bias
conv0.layer_norm = new_norm
conv0.forward = types.MethodType(masked_conv_forward, conv0)  # route lengths into the masked norm
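The forward replacement relies on an instance attribute shadowing the class's forward method, which nn.Module.__call__ then dispatches to. A minimal sketch on a hypothetical toy module (Doubler and tripling_forward are illustrative names only):

```python
import types
import torch
from torch import nn

class Doubler(nn.Module):
    def forward(self, x):
        return 2 * x

def tripling_forward(self, x):
    # Replacement forward, bound to a single instance below
    return 3 * x

a, b = Doubler(), Doubler()
a.forward = types.MethodType(tripling_forward, a)  # only `a` is patched; the class is unchanged

x = torch.ones(2)
print(a(x), b(x))  # calling the module uses the instance-level forward on `a` only
```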



output1 = model(dummy_data1,lengths=torch.tensor([N1]))

output2 = model(dummy_data,lengths=torch.tensor([N1,N2]))

print(torch.norm(output1[0][0,:frames1]-output2[0][0,:frames1]))

This now gives tensor(5.6603e-05, grad_fn=<CopyBackwards>).

Versions

Collecting environment information...
PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (GCC) 10.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.7 (default, Sep 16 2021, 13:09:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-97-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2070

Nvidia driver version: 510.47.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.10.0
[pip3] torchaudio==0.10.0
[pip3] torchvision==0.11.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.19.5 pypi_0 pypi
[conda] numpy-base 1.20.3 py39h74d4b33_0
[conda] pytorch 1.10.0 py3.9_cuda11.3_cudnn8.2.0_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 0.10.0 py39_cu113 pytorch
[conda] torchvision 0.11.1 py39_cu113 pytorch

JackPfizer (Author) commented Feb 16, 2022

To check that the MaskedGroupNorm module is working, I compared the original output1 against the patched model's output on the same input:

#output1 is based on the existing ta pipeline, output1_1 is based on my amendments

output1_1 = model(dummy_data1,lengths=torch.tensor([N1]))

print(torch.norm(output1[0][0,:frames1]-output1_1[0][0,:frames1]))

This gives tensor(6.2804e-05, grad_fn=<CopyBackwards>). So it's not exact, but it is reasonably close.
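One possible contributor to the remaining gap (not verified here) is that MaskedGroupNorm computes the variance as E[X^2] - E[X]^2, which loses precision in float32 when values are large relative to their spread. A sketch contrasting it with the two-pass formula on synthetic data; the 1000.0 offset and 0.01 scale are arbitrary choices that make the effect obvious:

```python
import torch

torch.manual_seed(0)
x = 1000.0 + 0.01 * torch.randn(10_000)  # float32 values with a large offset and a small spread

xd = x.double()
true_var = ((xd - xd.mean()) ** 2).mean().item()  # float64 reference on the same data

mean = x.mean()
one_pass = ((x ** 2).mean() - mean ** 2).item()   # E[X^2] - E[X]^2, as in MaskedGroupNorm
two_pass = ((x - mean) ** 2).mean().item()        # E[(X - E[X])^2]

print(f"reference {true_var:.3e}  one-pass {one_pass:.3e}  two-pass {two_pass:.3e}")
```

For zero-mean features like the random inputs in this issue the two formulas are much closer, so this only suggests where float32 error can enter; it does not explain the 6e-05 figure on its own.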

mthrok self-assigned this Feb 16, 2022
mthrok added the bug label Feb 16, 2022
mthrok added this to the v0.11 milestone Feb 16, 2022
mthrok (Collaborator) commented Feb 16, 2022

Hi @JackRAHealth

Thanks for the report. Let me look into it.

mthrok (Collaborator) commented Feb 21, 2022

The analysis seems correct, and the proper solution would be to implement normalization that is aware of masking.

There are tests for batch consistency, but they only use samples of similar lengths, so this effect was not caught.

@factory_funcs
def test_pretrain_batch_consistency(self, factory_func):
    """Results from single process and batched process should be reasonably close"""
    self._test_batch_consistency(factory_func())

@factory_funcs
def test_finetune_batch_consistency(self, factory_func):
    """Results from single process and batched process should be reasonably close"""
    self._test_batch_consistency(factory_func(aux_num_out=32))
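A consistency check that mixes very different lengths does expose the effect. The sketch below uses a hypothetical toy front end (ToyFrontEnd: one conv plus a per-channel GroupNorm, not torchaudio's model) and a hypothetical helper batch_consistency_gap; it illustrates the idea and is not torchaudio's _test_batch_consistency.

```python
import torch
from torch import nn

class ToyFrontEnd(nn.Module):
    """Hypothetical stand-in for a wav2vec2-style first layer: conv + per-channel GroupNorm."""
    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv = nn.Conv1d(1, channels, kernel_size=10, stride=5)
        self.norm = nn.GroupNorm(num_groups=channels, num_channels=channels)

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:  # waveforms: [batch, time]
        return nn.functional.gelu(self.norm(self.conv(waveforms.unsqueeze(1))))

def batch_consistency_gap(model: nn.Module, waveforms: list) -> float:
    """Max abs difference between per-sample outputs and the same samples run as a zero-padded batch."""
    max_len = max(w.shape[-1] for w in waveforms)
    batch = torch.zeros(len(waveforms), max_len)
    for i, w in enumerate(waveforms):
        batch[i, : w.shape[-1]] = w
    gap = 0.0
    with torch.no_grad():
        batched = model(batch)
        for i, w in enumerate(waveforms):
            single = model(w.unsqueeze(0))[0]   # [channels, frames] for this sample alone
            frames = single.shape[-1]
            gap = max(gap, (single - batched[i, :, :frames]).abs().max().item())
    return gap

torch.manual_seed(0)
model = ToyFrontEnd().eval()
gap_same = batch_consistency_gap(model, [torch.randn(1000), torch.randn(1000)])   # no padding needed
gap_mixed = batch_consistency_gap(model, [torch.randn(1000), torch.randn(4000)])  # 3000 padded samples
print(gap_same, gap_mixed)
```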

We need to update two modules (torch.nn.GroupNorm and LayerNorm) so that they support masking.

if norm_mode == "group_norm" and i == 0:
    normalization = nn.GroupNorm(
        num_groups=out_channels,
        num_channels=out_channels,
        affine=True,
    )
elif norm_mode == "layer_norm":
    normalization = LayerNorm(
        normalized_shape=out_channels,
        elementwise_affine=True,
    )
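In the group_norm branch, num_groups equals out_channels, so each channel is its own group and is normalized over the whole time axis. That is the same computation as instance normalization, which is why padded frames enter the statistics. A small check, assuming random input, freshly initialized affine parameters (weight 1, bias 0), and the default eps:

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
channels = 8
x = torch.randn(2, channels, 100)  # [batch, channels, time]

# The "group_norm" branch above: one group per channel
group_norm = nn.GroupNorm(num_groups=channels, num_channels=channels, affine=True)
with torch.no_grad():
    y_group = group_norm(x)
    # Per-(sample, channel) normalization over the time axis, i.e. instance normalization
    y_instance = F.instance_norm(x, eps=group_norm.eps)

print(torch.allclose(y_group, y_instance, atol=1e-5))
```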

mthrok removed this from the v0.11 milestone Feb 21, 2022
mthrok (Collaborator) commented Feb 21, 2022

I am unassigning myself, as this turned out to require more resources than I have at the moment.
If anyone is interested in resolving this, let us know. (Please discuss the details before making a PR.)

mthrok removed their assignment Feb 21, 2022
anicolson commented

good pickup @JackRAHealth, keeping it 55th street

kss2517 commented Jul 22, 2022

This issue seems very serious, as the underlying problem comes from PyTorch's nn.GroupNorm, so its scope goes beyond torchaudio.

As far as I know, Huggingface's Wav2vec 2.0 model is also implemented with nn.GroupNorm.
It means that recent works might have wrongly reported their performances by using Wav2vec 2.0 model of Torchaudio or Huggingface as the baseline.

mthrok (Collaborator) commented Jul 22, 2022

> As far as I know, Huggingface's Wav2vec 2.0 model is also implemented with nn.GroupNorm.
> It means that recent works might have wrongly reported their performances by using Wav2vec 2.0 model of Torchaudio or Huggingface as the baseline.

The original fairseq implementation also uses nn.GroupNorm at its core, so this has been an issue from the very beginning.
We think adopting NestedTensor could be a way to solve this.

cc @cpuhrsch

cpuhrsch (Contributor) commented

cc @jbschlosser who is the TL for NestedTensor

sedol1339 commented Aug 8, 2024

This problem also exists in the HF implementation:

from transformers import Wav2Vec2ForCTC
import torch
import numpy as np

device = 'cpu'
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h').to(device)

torch.manual_seed(0)
waveform = torch.rand(1, 100_000).to(device)
zero_pad = torch.zeros((1, 1_000)).to(device)

model.eval()
with torch.no_grad():
    outputs = model.wav2vec2(
        torch.cat([waveform, zero_pad], dim=1),
        attention_mask=torch.cat([torch.ones_like(waveform), zero_pad], dim=1),
    )[0].cpu().detach().numpy()
print(outputs[0, 0, 0])

The output depends on the size of the zero pad. It follows that batch size also affects the logits and loss, and therefore that the validation loss depends on sample order and batch size, which is inconsistent.

i-amgeek commented Nov 1, 2024

Is there any update on this issue?

I just spent 6 hours pinpointing this bug in a large codebase.

ex3ndr commented Nov 22, 2024

Does anyone know of any workarounds? I am using the MMS_FA bundle for alignment.
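One workaround that avoids padding entirely, at the cost of throughput, is to run each waveform on its own (batch size 1). A minimal sketch with a hypothetical helper run_unbatched and a toy stand-in model; with a torchaudio Wav2Vec2Model, whose forward returns an (output, lengths) tuple, each list entry would be that tuple instead of a single tensor. Whether this fits a particular MMS_FA alignment pipeline is untested here.

```python
import torch
from torch import nn

def run_unbatched(model: nn.Module, waveforms: list) -> list:
    """Hypothetical helper: run each 1-D waveform through `model` alone, so no zero-padding is added."""
    results = []
    with torch.no_grad():
        for waveform in waveforms:
            results.append(model(waveform.unsqueeze(0)))  # input shape [1, time]
    return results

class ToyModel(nn.Module):
    """Toy stand-in (conv + per-channel GroupNorm), not a real pipeline model."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 4, kernel_size=10, stride=5)
        self.norm = nn.GroupNorm(4, 4)

    def forward(self, x):  # x: [batch, time]
        return self.norm(self.conv(x.unsqueeze(1)))

outputs = run_unbatched(ToyModel().eval(), [torch.randn(1000), torch.randn(4000)])
print([tuple(o.shape) for o in outputs])  # [(1, 4, 199), (1, 4, 799)]
```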
