Skip to content

Commit

Permalink
specaug speedup (#6347)
Browse files Browse the repository at this point in the history
* [Core] return_config=True now extracts just config, not full tarfile (#6346)

Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: shane carroll <shane.carroll@utsa.edu>

* specaug speedup

Signed-off-by: shane carroll <shane.carroll@utsa.edu>

* comments

Signed-off-by: shane carroll <shane.carroll@utsa.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: shane carroll <shane.carroll@utsa.edu>

---------

Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: shane carroll <shane.carroll@utsa.edu>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 6, 2023
1 parent 515d36b commit 8aec729
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions nemo/collections/asr/parts/submodules/spectr_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import random

import numpy as np
import torch
import torch.nn as nn

Expand Down Expand Up @@ -80,29 +81,36 @@ def __init__(
@typecheck()
@torch.no_grad()
def forward(self, input_spec, length):
sh = input_spec.shape

for idx in range(sh[0]):
for i in range(self.freq_masks):
x_left = self._rng.randint(0, sh[1] - self.freq_width)

w = self._rng.randint(0, self.freq_width)

input_spec[idx, x_left : x_left + w, :] = self.mask_value

for i in range(self.time_masks):
if self.adaptive_temporal_width:
time_width = max(1, int(length[idx] * self.time_width))
else:
time_width = self.time_width

y_left = self._rng.randint(0, max(1, length[idx] - time_width))

w = self._rng.randint(0, time_width)

input_spec[idx, :, y_left : y_left + w] = self.mask_value

return input_spec
batch_size, num_freq_bins, _ = input_spec.shape
# Move lengths to CPU before repeated indexing
lengths_cpu = length.cpu().numpy()
# Generate a numpy boolean mask. `True` elements represent where the input spec will be augmented.
fill_mask: np.array = np.full(shape=input_spec.shape, fill_value=False)
freq_start_upper_bound = num_freq_bins - self.freq_width
# Choose different mask ranges for each element of the batch
for idx in range(batch_size):
# Set freq masking
for _ in range(self.freq_masks):
start = self._rng.randint(0, freq_start_upper_bound)
width = self._rng.randint(0, self.freq_width)
fill_mask[idx, start : start + width, :] = True

# Derive time width, sometimes based percentage of input length.
if self.adaptive_temporal_width:
time_max_width = max(1, int(lengths_cpu[idx] * self.time_width))
else:
time_max_width = self.time_width
time_start_upper_bound = max(1, lengths_cpu[idx] - time_max_width)

# Set time masking
for _ in range(self.time_masks):
start = self._rng.randint(0, time_start_upper_bound)
width = self._rng.randint(0, time_max_width)
fill_mask[idx, :, start : start + width] = True
# Bring the mask to device and fill spec
fill_mask = torch.from_numpy(fill_mask).to(input_spec.device)
masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value)
return masked_spec


class SpecCutout(nn.Module, Typing):
Expand Down

0 comments on commit 8aec729

Please sign in to comment.