Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parts for the BestRQ mask + quantizer #63

Merged
merged 5 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions i6_models/parts/best_rq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .mask import *
from .quantizer import *
74 changes: 74 additions & 0 deletions i6_models/parts/best_rq/mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import numpy as np

__all__ = ["RandomMask"]


class RandomMask(nn.Module):
"""
randomly mask out consecutive frames time dimension, the masked frames can be either
replaced with zeros or with learnable embeddings.
simplified version from Fairseq compute_mask_indices function,
C.f. https://github.com/facebookresearch/fairseq/blob/ecbf110e1eb43861214b05fa001eff584954f65a/fairseq/data/data_utils.py#L399
"""

def __init__(
self,
input_dim: int,
mask_replace_val: str,
mask_percentage: float,
mask_length: int,
):
"""
:param input_dim: number of feature dimension of input
:param mask_replace_val: the way to replace masked frames, either with zeros or lernable embeddings
:param mask_percentage: percentage of frames to be masked out
:param mask_length: the length of each mask span
michelwi marked this conversation as resolved.
Show resolved Hide resolved
"""
super().__init__()

assert mask_replace_val in ["lernable", "zero"], "not implemented yet"
if mask_replace_val == "lernable":
self.mask_emb = nn.Parameter(torch.FloatTensor(input_dim).uniform_())
elif mask_replace_val == "zero":
self.mask_emb = torch.zeros(input_dim)
self.mask_percentage = mask_percentage
self.mask_length = mask_length

def forward(
self,
tensor: torch.tensor,
padding_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
ndim_batch, ndim_time, _ = tensor.size()

mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool)

mask_idcs = []
for i in range(ndim_batch):
if padding_mask is not None:
seq_len = ndim_time - padding_mask[i].long().sum().item()
assert seq_len >= 0
else:
seq_len = ndim_time

num_mask = int(
# add a random number for probabilistic rounding
self.mask_percentage * seq_len / float(self.mask_length)
+ np.random.rand()
)

min_len = self.mask_length
if seq_len - min_len <= num_mask:
min_len = seq_len - num_mask - 1
michelwi marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems to be incorrect. Let's say seq_len = 5, mask_percentage = 1.0, mask_length = 5, then num_mask could be 1. In this case, seq_len - mask_length = 0 <= 1, then min_len is assigned 5 - 1 -1 = 3. Through code in next line, np.random.choice(2, 1, replace=False) makes it possible that not 100% frames are masked out. Pls correct me if I'm wrong.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can follow your logic and think that

if seq_len - min_len < num_mask:
     min_len = seq_len - num_mask

would be more true to the intended implementation.

But this case will most likely not appear as nobody would want to mask the entire sequence.

mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)
michelwi marked this conversation as resolved.
Show resolved Hide resolved

for j in mask_idc:
mask[i, j : j + self.mask_length] = True

tensor[mask] = self.mask_emb.to(tensor.device)

return tensor, torch.tensor(mask).to(tensor.device)
37 changes: 37 additions & 0 deletions i6_models/parts/best_rq/quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.linalg import vector_norm

__all__ = [
"RandomProjectionQuantizer",
]


class RandomProjectionQuantizer(nn.Module):
"""
implement the fixed random projection quantizer from BestRQ
C.f. https://arxiv.org/pdf/2202.01855 for theoretic background
code adapted from https://github.com/speechbrain/speechbrain/blob/16b6420d4ff23210cfca2e888be8853264e0cb17/speechbrain/nnet/quantisers.py#L127
"""

def __init__(self, input_dim, codebook_dim, codebook_num_vars):
"""
:param input_dim: number of feature dimension of input
:param codebook_dim: number of dimension for vocab in the codebook
:param codebook_num_vars: vocab size of the codebook
"""
super().__init__()

self.input_dim = input_dim

# projection matrix use Xavier initialization
P_init = torch.empty((input_dim, codebook_dim))
self.register_buffer("P", nn.init.xavier_uniform_(P_init))

# normalize random matrix for codebook
self.register_buffer("CB", F.normalize(torch.randn(codebook_num_vars, codebook_dim)))

def forward(self, x: torch.tensor) -> torch.tensor:
x = F.normalize(x @ self.P)
return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)
michelwi marked this conversation as resolved.
Show resolved Hide resolved
Loading