Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace unbiased with correction #10555

Merged
merged 2 commits into from
Sep 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions nemo/collections/asr/parts/submodules/tdnn_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class StatsPoolLayer(nn.Module):
pool_mode: Type of pool mode. Supported modes are 'xvector' (mean and standard deviation) and 'tap' (time
average pooling, i.e., mean)
eps: Epsilon, minimum value before taking the square root, when using 'xvector' mode.
biased: Whether to use the biased estimator for the standard deviation when using 'xvector' mode. The default
unbiased: Whether to use the biased estimator for the standard deviation when using 'xvector' mode. The default
for torch.Tensor.std() is True.

Returns:
Expand All @@ -42,15 +42,15 @@ class StatsPoolLayer(nn.Module):
ValueError if an unsupported pooling mode is specified.
"""

def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, biased: bool = True):
def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, unbiased: bool = True):
super().__init__()
supported_modes = {"xvector", "tap"}
if pool_mode not in supported_modes:
raise ValueError(f"Pool mode must be one of {supported_modes}; got '{pool_mode}'")
self.pool_mode = pool_mode
self.feat_in = feat_in
self.eps = eps
self.biased = biased
self.unbiased = unbiased
if self.pool_mode == 'xvector':
# Mean + std
self.feat_in *= 2
Expand All @@ -59,7 +59,8 @@ def forward(self, encoder_output, length=None):
if length is None:
mean = encoder_output.mean(dim=-1) # Time Axis
if self.pool_mode == 'xvector':
std = encoder_output.std(dim=-1)
correction = 1 if self.unbiased else 0
std = encoder_output.std(dim=-1, correction=correction).clamp(min=self.eps)
pooled = torch.cat([mean, std], dim=-1)
else:
pooled = mean
Expand All @@ -71,12 +72,13 @@ def forward(self, encoder_output, length=None):
# Re-scale to get padded means
means = means * (encoder_output.shape[-1] / length).unsqueeze(-1)
if self.pool_mode == "xvector":
correction = 1 if self.unbiased else 0
stds = (
encoder_output.sub(means.unsqueeze(-1))
.masked_fill(mask, 0.0)
.pow(2.0)
.sum(-1) # [B, D, T] -> [B, D]
.div(length.view(-1, 1).sub(1 if self.biased else 0))
.div(length.view(-1, 1).sub(correction))
.clamp(min=self.eps)
.sqrt()
)
Expand Down Expand Up @@ -104,7 +106,7 @@ def make_seq_mask_like(

def lens_to_mask(lens: List[int], max_len: int, device: str = None):
"""
outputs masking labels for list of lengths of audio features, with max length of any
outputs masking labels for list of lengths of audio features, with max length of any
mask as max_len
input:
lens: list of lens
Expand All @@ -124,8 +126,8 @@ def get_statistics_with_mask(x: torch.Tensor, m: torch.Tensor, dim: int = 2, eps
"""
compute mean and standard deviation of input(x) provided with its masking labels (m)
input:
x: feature input
m: averaged mask labels
x: feature input
m: averaged mask labels
output:
mean: mean of input features
std: stadard deviation of input features
Expand All @@ -146,7 +148,7 @@ class TDNNModule(nn.Module):
stride: stride for conv layer
padding: padding for conv layer (default None: chooses padding value such that input and output feature shape matches)
output:
tdnn layer output
tdnn layer output
"""

def __init__(
Expand Down Expand Up @@ -183,7 +185,7 @@ class MaskedSEModule(nn.Module):
"""
Squeeze and Excite module implementation with conv1d layers
input:
inp_filters: input filter channel size
inp_filters: input filter channel size
se_filters: intermediate squeeze and excite channel output and input size
out_filters: output filter channel size
kernel_size: kernel_size for both conv1d layers
Expand All @@ -196,10 +198,20 @@ class MaskedSEModule(nn.Module):
def __init__(self, inp_filters: int, se_filters: int, out_filters: int, kernel_size: int = 1, dilation: int = 1):
super().__init__()
self.se_layer = nn.Sequential(
nn.Conv1d(inp_filters, se_filters, kernel_size=kernel_size, dilation=dilation,),
nn.Conv1d(
inp_filters,
se_filters,
kernel_size=kernel_size,
dilation=dilation,
),
nn.ReLU(),
nn.BatchNorm1d(se_filters),
nn.Conv1d(se_filters, out_filters, kernel_size=kernel_size, dilation=dilation,),
nn.Conv1d(
se_filters,
out_filters,
kernel_size=kernel_size,
dilation=dilation,
),
nn.Sigmoid(),
)

Expand All @@ -220,7 +232,7 @@ class TDNNSEModule(nn.Module):
Modified building SE_TDNN group module block from ECAPA implementation for faster training and inference
Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
inputs:
inp_filters: input filter channel size
inp_filters: input filter channel size
out_filters: output filter channel size
group_scale: scale value to group wider conv channels (deafult:8)
se_channels: squeeze and excite output channel size (deafult: 1024/8= 128)
Expand Down Expand Up @@ -276,7 +288,7 @@ class AttentivePoolLayer(nn.Module):
inp_filters: input feature channel length from encoder
attention_channels: intermediate attention channel size
kernel_size: kernel_size for TDNN and attention conv1d layers (default: 1)
dilation: dilation size for TDNN and attention conv1d layers (default: 1)
dilation: dilation size for TDNN and attention conv1d layers (default: 1)
"""

def __init__(
Expand All @@ -295,7 +307,10 @@ def __init__(
TDNNModule(inp_filters * 3, attention_channels, kernel_size=kernel_size, dilation=dilation),
nn.Tanh(),
nn.Conv1d(
in_channels=attention_channels, out_channels=inp_filters, kernel_size=kernel_size, dilation=dilation,
in_channels=attention_channels,
out_channels=inp_filters,
kernel_size=kernel_size,
dilation=dilation,
),
)
self.eps = eps
Expand Down
Loading