Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ssl/wav2vec2] support wav2vec2 #2034

Merged
merged 1 commit into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def __init__(
torch.nn.init.xavier_uniform_(self.projection)

# codebooks
# [num_codebooks, embedding_dim, num_embeddings]
# [num_embeddings, num_codebooks, num_embeddings] means
# [C, G, D] see quantize_vector
self.embeddings = torch.nn.parameter.Parameter(
torch.empty(num_embeddings, self.num_codebooks, embedding_dim),
requires_grad=False,
Expand Down
109 changes: 109 additions & 0 deletions wenet/ssl/wav2vec2/quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch


def gumbel(shape: torch.Size, dtype: torch.dtype, device: torch.device):
"""Sample Gumbel random values with given shape and float dtype.

The values are distributed according to the probability density function:

.. math::
f(x) = e^{-(x + e^{-x})}

Args:
shape (torch.Size): pdf shape
dtype (torch.dtype): pdf value dtype

Returns:
A random array with the specified shape and dtype.
"""
# see https://www.cnblogs.com/initial-h/p/9468974.html for more details
return -torch.log(-torch.log(
torch.empty(shape, device=device).uniform_(
torch.finfo(dtype).tiny, 1.)))


class Wav2vecGumbelVectorQuantizer(torch.nn.Module):

def __init__(self,
features_dim: int = 256,
num_codebooks: int = 2,
num_embeddings: int = 8192,
embedding_dim: int = 16,
hard: bool = False) -> None:

super().__init__()

self.num_groups = num_codebooks
self.num_codevectors_per_group = num_embeddings
# codebooks
# means [C, G, D] see quantize_vector in bestrq_model.py
assert embedding_dim % num_codebooks == 0.0
self.embeddings = torch.nn.parameter.Parameter(
torch.empty(1, num_codebooks * num_embeddings,
embedding_dim // num_codebooks),
requires_grad=True,
)
torch.nn.init.uniform_(self.embeddings)

self.weight_proj = torch.nn.Linear(features_dim,
num_codebooks * num_embeddings)
# use gumbel softmax or argmax(non-differentiable)
self.hard = hard

@staticmethod
def _compute_perplexity(probs, mask=None):
if mask is not None:

mask_extended = torch.broadcast_to(mask.flatten()[:, None, None],
probs.shape)
probs = torch.where(mask_extended.to(torch.bool), probs,
torch.zeros_like(probs))
marginal_probs = probs.sum(dim=0) / mask.sum()
else:
marginal_probs = probs.mean(dim=0)

perplexity = torch.exp(-torch.sum(
marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity

def forward(self,
input: torch.Tensor,
input_mask: torch.Tensor,
temperature: float = 1.):

b, t, _ = input.size()

hidden = self.weight_proj(input)
hidden = hidden.reshape(b * t * self.num_groups, -1)
if not self.hard:
# sample code vector probs via gumbel in differentiateable way
gumbels = gumbel(hidden.size(), hidden.dtype, hidden.device)
codevector_probs = torch.nn.functional.softmax(
(hidden + gumbels) / temperature, dim=-1)

# compute perplexity
codevector_soft_dist = torch.nn.functional.softmax(
hidden.reshape(b * t, self.num_groups, -1),
dim=-1,
) # [B*T, num_codebooks, num_embeddings]
perplexity = self._compute_perplexity(codevector_soft_dist,
input_mask)
else:
# take argmax in non-differentiable way
# comptute hard codevector distribution (one hot)
codevector_idx = hidden.argmax(axis=-1)
codevector_probs = torch.nn.functional.one_hot(
codevector_idx, hidden.shape[-1]) * 1.0
codevector_probs = codevector_probs.reshape(
b * t, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, input_mask)

codevector_probs = codevector_probs.reshape(b * t, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(
-1) * self.embeddings
codevectors = codevectors_per_group.reshape(
b * t, self.num_groups, self.num_codevectors_per_group, -1)

codevectors = codevectors.sum(-2).reshape(b, t, -1)
return codevectors, perplexity
Loading
Loading