diff --git a/wenet/ssl/bestrq/bestrq_model.py b/wenet/ssl/bestrq/bestrq_model.py
index f5791fc0e..97d9ecdab 100644
--- a/wenet/ssl/bestrq/bestrq_model.py
+++ b/wenet/ssl/bestrq/bestrq_model.py
@@ -81,12 +81,6 @@ def __init__(
 
         # encoder
         self.encoder = encoder
-        assert self.encoder.global_cmvn is not None
-        self.register_buffer('signal_mean', self.encoder.global_cmvn.mean)
-        self.register_buffer('signal_istd', self.encoder.global_cmvn.istd)
-        self.signal_norm_var = self.encoder.global_cmvn.norm_var
-        # NOTE(Mddct): disable encoder's global_cmvn
-        self.encoder.global_cmvn = None
 
         # n softmax
         self.encoder_top_n_out = torch.nn.parameter.Parameter(
@@ -169,10 +163,6 @@ def forward(
     ):
         xs = batch['feats'].to(device)
         xs_lens = batch['feats_lengths'].to(device)
-        # force global cmvn
-        xs = xs - self.signal_mean
-        if self.signal_norm_var:
-            xs = xs * self.signal_istd
         input = xs
 
         features_pen: Optional[torch.Tensor] = None
@@ -186,6 +176,8 @@ def forward(
         subsampling_masks = masked_masks.unfold(1,
                                                 size=self.stack_frames,
                                                 step=self.stride)
+        # NOTE(Mddct): you can try torch.max(subsampling_masks, 2) if
+        # subsampling rate == 2 or mask prob is smaller
        code_ids_mask, _ = torch.min(subsampling_masks, 2)
 
         # 2.0 stack fbank
@@ -267,10 +259,16 @@ def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
         return loss
 
     def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
-        xs = self.norm(xs)
+        if self.encoder.global_cmvn is None:
+            xs = self.norm(xs)
         xs = torch.matmul(xs, self.projection.to(xs.device))
-
+        xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
+        codebooks = self.embeddings / (
+            self.embeddings.norm(dim=-1, p=2, keepdim=True) + 1e-8)
         B, T, C = xs.size()
         xs_flatten = xs.view(B * T, C)
-        _, codes, _ = quantize_vector(xs_flatten, self.embeddings)
+        # _, codes, _ = quantize_vector(xs_flatten, codebooks)
+        distance = xs_flatten.unsqueeze(1).unsqueeze(1) - codebooks.unsqueeze(
+            0)
+        codes = torch.linalg.vector_norm(distance, dim=-1).argmin(dim=-1)
         return codes.reshape(B, T, -1)  # [B, T, num_codebooks]
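
A minimal standalone sketch (not part of the patch) of the nearest-codebook lookup that the last hunk switches to: both the projected features and the codebooks are L2-normalized, distances are formed by broadcasting, and the argmin over the embedding axis yields one index per codebook. The shape values (B, T, dim, num_codebooks, num_embeddings) below are assumed toy values, not the repository's configuration.

    import torch

    B, T, dim = 2, 5, 16                     # assumed toy shapes
    num_codebooks, num_embeddings = 4, 8192  # assumed codebook layout

    xs = torch.randn(B, T, dim)              # stands in for the projected features
    embeddings = torch.randn(num_codebooks, num_embeddings, dim)

    # L2-normalize both sides so the comparison ignores vector scale
    xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
    codebooks = embeddings / (embeddings.norm(dim=-1, p=2, keepdim=True) + 1e-8)

    xs_flatten = xs.view(B * T, dim)
    # broadcast to [B*T, num_codebooks, num_embeddings, dim], reduce the feature
    # axis, then pick the closest entry per codebook
    distance = xs_flatten.unsqueeze(1).unsqueeze(1) - codebooks.unsqueeze(0)
    codes = torch.linalg.vector_norm(distance, dim=-1).argmin(dim=-1)
    print(codes.reshape(B, T, -1).shape)  # torch.Size([2, 5, 4]) == [B, T, num_codebooks]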