
try to add n subsequent tokens with the original BestRQ n-softmax, e.g. n_softmax x n_sub
Mddct committed Sep 20, 2024
1 parent d7e2190 commit 74a6ebc
Showing 1 changed file with 24 additions and 16 deletions.
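In tensor terms, the change widens the NEST-RQ prediction head from one softmax projection per codebook (the original BestRQ-style n-softmax) to one per (subsequent step, codebook) pair. A minimal shape sketch of that layout, with purely illustrative sizes rather than values from the repository:

import torch

# Illustrative sizes only; real values come from the model config.
dim, num_embeddings, num_codebooks, n_subsequent = 256, 8192, 1, 3

# Before this commit: one projection per codebook.
head_before = torch.empty(num_codebooks, dim, num_embeddings)
# After: one projection per (subsequent step, codebook) pair,
# i.e. "n_softmax x n_sub" heads in total.
head_after = torch.empty(n_subsequent, num_codebooks, dim, num_embeddings)
print(head_after.shape)  # torch.Size([3, 1, 256, 8192])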
40 changes: 24 additions & 16 deletions wenet/ssl/nestrq/nestrq_model.py
@@ -5,6 +5,7 @@
 
 from wenet.transformer.attention import RelPositionMultiHeadedAttention
 from wenet.transformer.encoder_layer import ConformerEncoderLayer
+from wenet.utils.mask import make_non_pad_mask
 
 
 class NestRQModel(torch.nn.Module):
@@ -18,6 +19,7 @@ def __init__(
         embedding_dim: int = 16,
         num_embeddings: int = 8192,
         num_codebooks: int = 1,
+        n_subsequent: int = 1,
         out_bias: bool = False,
     ) -> None:
         super().__init__()
@@ -28,13 +30,13 @@ def __init__(
         self.encoder = encoder
         # n softmax
         self.encoder_top_n_out = torch.nn.parameter.Parameter(
-            torch.empty(self.num_codebooks, self.encoder.output_size(),
-                        num_embeddings))
+            torch.empty(n_subsequent, self.num_codebooks,
+                        self.encoder.output_size(), num_embeddings))
         torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
         self.out_bias = out_bias
         if self.out_bias:
             self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
-                torch.empty(self.num_codebooks, num_embeddings))
+                torch.empty(n_subsequent, self.num_codebooks, num_embeddings))
             torch.nn.init.zeros_(self.encoder_top_n_out_bias)
 
         # stack input: eg: fbank
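How the reshaped parameter combines with the encoder output in forward (see the hunks below) is pure broadcasting; a standalone sketch with toy sizes, not the module itself:

import torch

B, T, dim = 2, 50, 256
num_embeddings, num_codebooks, n_subsequent = 8192, 1, 3

top_n_out = torch.empty(n_subsequent, num_codebooks, dim, num_embeddings)
torch.nn.init.trunc_normal_(top_n_out, std=0.02)

enc_out = torch.randn(B, T, dim).unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T, dim]
logits = torch.matmul(enc_out, top_n_out.unsqueeze(0))
# batch dims broadcast: (B,1,1) x (1,n_sub,C) -> (B,n_sub,C)
print(logits.shape)  # torch.Size([2, 3, 1, 50, 8192])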
@@ -52,6 +54,8 @@ def __init__(
                                        eps=1e-6,
                                        elementwise_affine=False,
                                        bias=False)
+        # Section: 1B
+        self.n_subsequent = n_subsequent
 
         # codebook
         # [num_embeddings, num_codebooks, num_embeddings] means
@@ -114,28 +118,31 @@ def forward(
 
         # 1 stack fbank, out_mask is for compute loss (NPT)
         stack_input, stack_out_mask = self._stack_features(input, xs_lens)
-        masked_xs = xs
 
         # 2 get nearest embedding
         target_ids = self._nearest_embedding_idx(stack_input)
-        target_ids = target_ids[:, :out_mask.size(1), :]
+        target_ids = target_ids[:, :stack_out_mask.size(1), :]
+        target_ids = target_ids.unfold(1, size=self.n_subsequent,
+                                       step=1).transpose(-1,
+                                                         -2)  # (B,T,-1, vocab)
 
         # 3 forward xxx-formaer block and its subsampling layer
         # TODO(mddct): encoder causal mask
-        out, out_mask = self.encoder(masked_xs, xs_lens)
+        out, out_mask = self.encoder(xs, xs_lens)
 
         # 4 get logits
-        out = out.unsqueeze(1)  # [B, 1, T', dim]
+        out = out.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T', dim]
         top_n_out = self.encoder_top_n_out.unsqueeze(
-            0)  # [1, num_codebooks, dim, num_embeddings]
-        out = torch.matmul(out,
-                           top_n_out)  # [B, num_codebooks, T', num_embeddings]
+            0)  # [1, n_subsequent, num_codebooks, dim, num_embeddings]
+        out = torch.matmul(
+            out,
+            top_n_out)  # [B, n_subsequent, num_codebooks, T', num_embeddings]
         if self.out_bias:
-            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)
+            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(3)
 
         # shift input and target for next token prediction
-        out = out[:, :, :-1:]
-        target_ids = target_ids[:, 1:, :]
+        out = out[:, :, :, :target_ids.size(1), :]
+        target_ids = target_ids[:, 1:, :, :]
         masks = out_mask.squeeze(1) * stack_out_mask
         masks = masks[:, 1:]
 
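The unfold added above is what builds the n-subsequent targets: each frame is paired with a window of n_subsequent quantized ids, and the target_ids[:, 1:, :, :] shift then makes those windows strictly subsequent. A toy check with illustrative sizes (note the last axis is num_codebooks, despite the "vocab" in the inline comment):

import torch

B, T, num_codebooks, n_subsequent = 1, 6, 1, 3
target_ids = torch.arange(B * T * num_codebooks).view(B, T, num_codebooks)

targets = target_ids.unfold(1, size=n_subsequent, step=1).transpose(-1, -2)
print(targets.shape)  # (B, T - n + 1, n_subsequent, num_codebooks) = (1, 4, 3, 1)
print(targets[0, 0].squeeze(-1))  # tensor([0, 1, 2]): the window anchored at frame 0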
@@ -160,7 +167,6 @@ def forward(
     def _stack_features(
             self, input: torch.Tensor,
             input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-
         mask = make_non_pad_mask(input_lens)
         mask_stride = mask.unfold(
             1,
@@ -178,8 +184,10 @@
 
     def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
                       mask: torch.Tensor) -> torch.Tensor:
-        logits = input.transpose(1, 2).contiguous().view(-1, input.size(-1))
-        mask = mask.unsqueeze(2).repeat(1, 1, self.num_codebooks)
+        logits = input.contiguous().permute(
+            (0, 3, 1, 2, 4)).view(-1, input.size(-1))
+        mask = mask.unsqueeze(2).unsqueeze(2).repeat(1, 1, self.n_subsequent,
+                                                     self.num_codebooks)
         loss = torch.nn.functional.cross_entropy(
             logits,
             target.contiguous().view(-1),
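The _compute_loss change flattens the [B, n_subsequent, num_codebooks, T', num_embeddings] logits so a single cross-entropy covers every step and codebook, with the padding mask repeated to match. The reduction is collapsed out of this diff, so the sketch below assumes a masked mean; note too that view after permute generally requires contiguous memory, which is why the sketch uses reshape:

import torch
import torch.nn.functional as F

B, n_subsequent, num_codebooks, T, V = 2, 3, 1, 5, 11
logits = torch.randn(B, n_subsequent, num_codebooks, T, V)
target = torch.randint(V, (B, T, n_subsequent, num_codebooks))
mask = torch.ones(B, T, dtype=torch.bool)  # stand-in for the real padding mask

flat_logits = logits.permute(0, 3, 1, 2, 4).reshape(-1, V)  # [B*T*n_sub*C, V]
flat_mask = mask.unsqueeze(2).unsqueeze(2).repeat(1, 1, n_subsequent, num_codebooks)
loss = F.cross_entropy(flat_logits, target.reshape(-1), reduction='none')
loss = (loss * flat_mask.reshape(-1)).sum() / flat_mask.sum()
print(loss)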
