Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ssl/bestrq] norm the input and codebooks #2556

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions wenet/ssl/bestrq/bestrq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,6 @@ def __init__(

# encoder
self.encoder = encoder
assert self.encoder.global_cmvn is not None
self.register_buffer('signal_mean', self.encoder.global_cmvn.mean)
self.register_buffer('signal_istd', self.encoder.global_cmvn.istd)
self.signal_norm_var = self.encoder.global_cmvn.norm_var
# NOTE(Mddct): disable encoder's global_cmvn
self.encoder.global_cmvn = None

# n softmax
self.encoder_top_n_out = torch.nn.parameter.Parameter(
Expand Down Expand Up @@ -169,10 +163,6 @@ def forward(
):
xs = batch['feats'].to(device)
xs_lens = batch['feats_lengths'].to(device)
# force global cmvn
xs = xs - self.signal_mean
if self.signal_norm_var:
xs = xs * self.signal_istd
input = xs

features_pen: Optional[torch.Tensor] = None
Expand All @@ -186,6 +176,8 @@ def forward(
subsampling_masks = masked_masks.unfold(1,
size=self.stack_frames,
step=self.stride)
# NOTE(Mddct): you can try torch.max(subsampling_masks, 2) if
# subsampling rate == 2 or mask probs is smaller
code_ids_mask, _ = torch.min(subsampling_masks, 2)

# 2.0 stack fbank
Expand Down Expand Up @@ -267,10 +259,16 @@ def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
return loss

def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
    """Assign each frame of `xs` to its nearest codebook entry.

    Args:
        xs: input features; indexed as [B, T, C] below, so presumably
            (batch, time, feature_dim) — TODO confirm against callers.

    Returns:
        Long tensor of codebook indices reshaped to [B, T, num_codebooks].

    NOTE(review): this text was recovered from a rendered GitHub diff
    (PR #2556) whose +/- markers were lost, so some pairs of lines below
    appear to be a deleted line fused with its replacement. Confirm
    against the merged file before trusting the exact control flow.
    """
    # NOTE(review): the next two statements look like the old
    # unconditional norm and its new conditional replacement from the
    # diff; as written, xs is normed once, then possibly a second time.
    xs = self.norm(xs)
    if self.encoder.global_cmvn is None:
        xs = self.norm(xs)
    # Project features into the codebook space (projection moved to the
    # input's device on every call).
    xs = torch.matmul(xs, self.projection.to(xs.device))

    # L2-normalize both the projected features and the codebook entries
    # (eps 1e-8 guards against division by zero), i.e. cosine-style
    # nearest-neighbour search.
    xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
    codebooks = self.embeddings / (
        self.embeddings.norm(dim=-1, p=2, keepdim=True) + 1e-8)
    B, T, C = xs.size()
    xs_flatten = xs.view(B * T, C)
    # NOTE(review): the result of this quantize_vector call is discarded
    # — `codes` is overwritten by the argmin below. This is almost
    # certainly the diff's deleted line (its replacement is the
    # commented-out call that follows); verify before removing.
    _, codes, _ = quantize_vector(xs_flatten, self.embeddings)
    # _, codes, _ = quantize_vector(xs_flatten, codebooks)
    # Brute-force nearest neighbour: pairwise differences between every
    # flattened frame and every (normalized) codebook entry, then pick
    # the entry with the smallest Euclidean distance along dim 1.
    distance = xs_flatten.unsqueeze(1).unsqueeze(1) - codebooks.unsqueeze(
        0)
    codes = torch.linalg.vector_norm(distance, dim=-1).argmin(dim=1)
    return codes.reshape(B, T, -1)  # [B, T, num_codebooks]
Loading