Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Wav2Vec2] Improve SpecAugment function by converting numpy based fun… #10494

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 18 additions & 37 deletions src/transformers/models/wav2vec2/modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _compute_mask_indices(
mask_length: int,
attention_mask: Optional[torch.Tensor] = None,
min_masks: int = 0,
) -> np.ndarray:
) -> torch.tensor:
"""
Computes random mask spans for a given shape

Expand All @@ -68,12 +68,12 @@ def _compute_mask_indices(
<https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376>`__.
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
mask = torch.Tensor(bsz, all_sz).fill_(False)

all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
+ torch.rand()
)

all_num_mask = max(min_masks, all_num_mask)
Expand All @@ -86,14 +86,14 @@ def _compute_mask_indices(
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
+ torch.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask

lengths = np.full(num_mask, mask_length)
lengths = torch.Tensor(num_mask).fill_(mask_length)

if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
Expand All @@ -102,14 +102,15 @@ def _compute_mask_indices(
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1

mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
mask_idc = torch.randperm(sz - min_len)[:num_mask]
mask_idc = torch.tensor([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])

min_len = min([len(m) for m in mask_idcs])
mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))

min_len = torch.min(mask_idcs)
for i, mask_idc in enumerate(mask_idcs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to get rid of the for-loop and do tensor operations only

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure how to do this. Something like following but not getting how should I put it ? Can you help here.

mask[i, mask_idc] = [True, torch.randperm(mask_idc)[:min_len] if torch.tensor(mask_idcs).size() > min_len]

if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask_idc = torch.randperm(mask_idc)[:min_len]
mask[i, mask_idc] = True

return mask
Expand Down Expand Up @@ -274,12 +275,7 @@ class Wav2Vec2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
Expand Down Expand Up @@ -563,9 +559,7 @@ def custom_forward(*inputs):
return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
attention_mask,
create_custom_forward(layer), hidden_states, attention_mask,
)
else:
layer_outputs = layer(
Expand All @@ -582,9 +576,7 @@ def custom_forward(*inputs):
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions,
)


Expand Down Expand Up @@ -642,9 +634,7 @@ def custom_forward(*inputs):
return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
attention_mask,
create_custom_forward(layer), hidden_states, attention_mask,
)
else:
layer_outputs = layer(
Expand All @@ -663,9 +653,7 @@ def custom_forward(*inputs):
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions,
)


Expand Down Expand Up @@ -788,12 +776,7 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
self, input_values, attention_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None,
):
"""

Expand Down Expand Up @@ -862,9 +845,7 @@ def forward(
# apply SpecAugment along feature axis
if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
self.config.mask_feature_prob,
self.config.mask_feature_length,
(batch_size, hidden_size), self.config.mask_feature_prob, self.config.mask_feature_length,
)
mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
Expand Down