From 1972d99409c2a48467bc7eb570a0e3a1a4d4b3e4 Mon Sep 17 00:00:00 2001
From: Judyxujj
Date: Mon, 11 Nov 2024 17:16:18 +0100
Subject: [PATCH 1/5] add best rq part

---
 i6_models/parts/best_rq/mask.py      | 55 ++++++++++++++++++++++++++++
 i6_models/parts/best_rq/quantizer.py | 24 ++++++++++++
 2 files changed, 79 insertions(+)
 create mode 100644 i6_models/parts/best_rq/mask.py
 create mode 100644 i6_models/parts/best_rq/quantizer.py

diff --git a/i6_models/parts/best_rq/mask.py b/i6_models/parts/best_rq/mask.py
new file mode 100644
index 00000000..59106a8c
--- /dev/null
+++ b/i6_models/parts/best_rq/mask.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+from typing import Optional
+import numpy as np
+
+
+class RandomMask(nn.Module):
+    def __init__(self, input_dim, mask_replace_val):
+
+        if mask_replace_val == "learnable":
+            self.mask_emb = nn.Parameter(torch.FloatTensor(input_dim).uniform_())
+        elif mask_replace_val == 0:
+            self.mask_emb = torch.zeros(input_dim)
+
+    def forward(
+        self,
+        tensor: torch.tensor,
+        padding_mask: Optional[torch.Tensor],
+        mask_prob: float,
+        mask_length: int,
+        min_masks: int = 0,
+    ):
+        ndim_batch, ndim_time, _ = tensor.size()
+
+        mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool)
+
+        mask_idcs = []
+        for i in range(ndim_batch):
+            if padding_mask is not None:
+                seq_len = ndim_time - padding_mask[i].long().sum().item()
+                assert seq_len >= 0
+            else:
+                seq_len = ndim_time
+
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                mask_prob * seq_len / float(mask_length)
+                + np.random.rand()
+            )
+            num_mask = max(min_masks, num_mask)
+
+            min_len = mask_length
+            if seq_len - min_len <= num_mask:
+                min_len = seq_len - num_mask - 1
+            mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray([mask_idc[j] + mask_length for j in range(len(mask_idc))])
+            mask_idcs.append(mask_idc)
+
+        for i, mask_idc in enumerate(mask_idcs):
+            mask[i, mask_idc] = True
+
+        tensor[mask] = self.mask_emb
+
+        return tensor
diff --git a/i6_models/parts/best_rq/quantizer.py b/i6_models/parts/best_rq/quantizer.py
new file mode 100644
index 00000000..0d039ba5
--- /dev/null
+++ b/i6_models/parts/best_rq/quantizer.py
@@ -0,0 +1,24 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.linalg import vector_norm
+
+
+class RandomProjectionQuantizer(nn.Module):
+    def __init__(self, input_dim, cb_dim, cb_vocab):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.cb_dim = cb_dim
+        self.cb_vocab = cb_vocab
+
+        # Section 3.1 "projection matrix A use Xavier initialization"
+        P_init = torch.empty((input_dim, cb_dim))
+        self.register_buffer("P", nn.init.xavier_uniform_(P_init))
+
+        # normalize random matrix for codebook
+        self.register_buffer("CB", F.normalize(torch.randn(cb_vocab, cb_dim)))
+
+    def forward(self, x):
+        x = F.normalize(x @ self.P)
+        return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)
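For orientation between the patches: RandomProjectionQuantizer maps each feature frame to the index of its nearest codebook entry. The input is multiplied with the fixed projection P, normalized, and compared against the normalized codebook CB by Euclidean distance, with the argmin taken over the codebook axis. A minimal sketch of the intended call, assuming a [batch, time, feature] input; the shapes and sizes are illustrative, not taken from the patch:

import torch
from i6_models.parts.best_rq.quantizer import RandomProjectionQuantizer

quantizer = RandomProjectionQuantizer(input_dim=80, cb_dim=16, cb_vocab=8192)
features = torch.randn(4, 100, 80)  # [batch, time, feature]
targets = quantizer(features)       # [batch, time], codebook indices in [0, cb_vocab)
assert targets.shape == (4, 100)

Since P and CB are registered as buffers rather than parameters, they are stored in the state dict and moved across devices with the module but never receive gradients, matching the frozen, randomly initialized quantizer of the BestRQ paper. One caveat: F.normalize defaults to dim=1, so for a three-dimensional input the projected features are normalized along the time axis; if per-frame normalization is intended, passing dim=-1 explicitly would be the safer choice.

From 8b3ed957eac8a7f373627c11b04523f36c1a944f Mon Sep 17 00:00:00 2001
From: Jingjing Xu
Date: Mon, 18 Nov 2024 10:13:19 -0500
Subject: [PATCH 2/5] update

---
 i6_models/parts/best_rq/__init__.py  |  2 +
 i6_models/parts/best_rq/mask.py      | 56 +++++++++++++++++++++-------
 i6_models/parts/best_rq/quantizer.py | 27 ++++++++++----
 3 files changed, 64 insertions(+), 21 deletions(-)
 create mode 100644 i6_models/parts/best_rq/__init__.py

diff --git a/i6_models/parts/best_rq/__init__.py b/i6_models/parts/best_rq/__init__.py
new file mode 100644
index 00000000..fb9a81db
--- /dev/null
+++ b/i6_models/parts/best_rq/__init__.py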
@@ -0,0 +1,2 @@
+from .mask import *
+from .quantizer import *
diff --git a/i6_models/parts/best_rq/mask.py b/i6_models/parts/best_rq/mask.py
index 59106a8c..176ec2c3 100644
--- a/i6_models/parts/best_rq/mask.py
+++ b/i6_models/parts/best_rq/mask.py
@@ -1,25 +1,51 @@
+from typing import Optional, Tuple
+
 import torch
 import torch.nn as nn
-from typing import Optional
 import numpy as np
 
+__all__ = ["RandomMask"]
+
 
 class RandomMask(nn.Module):
-    def __init__(self, input_dim, mask_replace_val):
+    """
+    randomly mask out consecutive frames in the time dimension; the masked frames can be
+    replaced either with zeros or with learnable embeddings.
+    simplified version of the Fairseq compute_mask_indices function,
+    cf. https://github.com/facebookresearch/fairseq/blob/ecbf110e1eb43861214b05fa001eff584954f65a/fairseq/data/data_utils.py#L399
+    """
 
+    def __init__(
+        self,
+        input_dim: int,
+        mask_replace_val: str,
+        mask_prob: float,
+        mask_length: int,
+        min_masks: int = 0,
+    ):
+        """
+        :param input_dim: feature dimension of the input
+        :param mask_replace_val: how to replace masked frames, either with zeros or with learnable embeddings
+        :param mask_prob: percentage of frames to be masked out
+        :param mask_length: the length of each mask span
+        :param min_masks: minimum number of masks
+        """
+        super().__init__()
+
+        assert mask_replace_val in ["learnable", "zero"], "not implemented yet"
         if mask_replace_val == "learnable":
             self.mask_emb = nn.Parameter(torch.FloatTensor(input_dim).uniform_())
-        elif mask_replace_val == 0:
+        elif mask_replace_val == "zero":
             self.mask_emb = torch.zeros(input_dim)
+        self.mask_prob = mask_prob
+        self.mask_length = mask_length
+        self.min_masks = min_masks
 
     def forward(
         self,
         tensor: torch.tensor,
         padding_mask: Optional[torch.Tensor],
-        mask_prob: float,
-        mask_length: int,
-        min_masks: int = 0,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         ndim_batch, ndim_time, _ = tensor.size()
 
         mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool)
@@ -34,22 +60,24 @@ def forward(
 
             num_mask = int(
                 # add a random number for probabilistic rounding
-                mask_prob * seq_len / float(mask_length)
+                self.mask_prob * seq_len / float(self.mask_length)
                 + np.random.rand()
             )
-            num_mask = max(min_masks, num_mask)
+            num_mask = max(self.min_masks, num_mask)
 
-            min_len = mask_length
+            min_len = self.mask_length
             if seq_len - min_len <= num_mask:
                 min_len = seq_len - num_mask - 1
             mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)
 
-            mask_idc = np.asarray([mask_idc[j] + mask_length for j in range(len(mask_idc))])
-            mask_idcs.append(mask_idc)
+            mask_idc = np.asarray(
+                [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(self.mask_length)]
+            )
+            mask_idcs.append(mask_idc)
 
         for i, mask_idc in enumerate(mask_idcs):
             mask[i, mask_idc] = True
 
-        tensor[mask] = self.mask_emb
+        tensor[mask] = self.mask_emb.to(tensor.device)
 
-        return tensor
+        return tensor, torch.tensor(mask).to(tensor.device)
diff --git a/i6_models/parts/best_rq/quantizer.py b/i6_models/parts/best_rq/quantizer.py
index 0d039ba5..8633eb42 100644
--- a/i6_models/parts/best_rq/quantizer.py
+++ b/i6_models/parts/best_rq/quantizer.py
@@ -3,22 +3,35 @@
 import torch.nn.functional as F
 from torch.linalg import vector_norm
 
+__all__ = [
+    "RandomProjectionQuantizer",
+]
+
 
 class RandomProjectionQuantizer(nn.Module):
-    def __init__(self, input_dim, cb_dim, cb_vocab):
+    """
+    implements the fixed random projection quantizer from BestRQ,
+    cf. https://arxiv.org/pdf/2202.01855 for the theoretical background;
+    code adapted from https://github.com/speechbrain/speechbrain/blob/16b6420d4ff23210cfca2e888be8853264e0cb17/speechbrain/nnet/quantisers.py#L127
+    """
+
+    def __init__(self, input_dim, codebook_dim, codebook_num_vars):
+        """
+        :param input_dim: feature dimension of the input
+        :param codebook_dim: dimension of the codebook vectors
+        :param codebook_num_vars: vocab size (number of entries) of the codebook
+        """
         super().__init__()
 
         self.input_dim = input_dim
-        self.cb_dim = cb_dim
-        self.cb_vocab = cb_vocab
 
-        # Section 3.1 "projection matrix A use Xavier initialization"
-        P_init = torch.empty((input_dim, cb_dim))
+        # projection matrix uses Xavier initialization
+        P_init = torch.empty((input_dim, codebook_dim))
         self.register_buffer("P", nn.init.xavier_uniform_(P_init))
 
         # normalize random matrix for codebook
-        self.register_buffer("CB", F.normalize(torch.randn(cb_vocab, cb_dim)))
+        self.register_buffer("CB", F.normalize(torch.randn(codebook_num_vars, codebook_dim)))
 
-    def forward(self, x):
+    def forward(self, x: torch.tensor) -> torch.tensor:
         x = F.normalize(x @ self.P)
         return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)
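The span count in the forward pass above uses probabilistic rounding: adding a uniform random number before truncating to int makes the expected number of spans exactly mask_prob * seq_len / mask_length instead of always rounding down. A small sketch of the arithmetic, with illustrative numbers that are not taken from the patch:

import numpy as np

mask_prob, mask_length, seq_len = 0.15, 10, 937  # illustrative values
expected = mask_prob * seq_len / float(mask_length)  # 14.055
num_mask = int(expected + np.random.rand())  # 14 with probability 0.945, else 15
# each chosen start index is expanded to mask_length consecutive frames, so on
# average roughly mask_prob * seq_len frames end up masked (spans may overlap,
# so the realized fraction can be slightly lower)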
From b34b1dcd52ef211588bf81279c139dd96c843bc3 Mon Sep 17 00:00:00 2001
From: Jingjing Xu
Date: Wed, 4 Dec 2024 14:35:13 +0000
Subject: [PATCH 3/5] update

---
 i6_models/parts/best_rq/mask.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/i6_models/parts/best_rq/mask.py b/i6_models/parts/best_rq/mask.py
index 176ec2c3..85346ebf 100644
--- a/i6_models/parts/best_rq/mask.py
+++ b/i6_models/parts/best_rq/mask.py
@@ -19,16 +19,14 @@ def __init__(
         self,
         input_dim: int,
         mask_replace_val: str,
-        mask_prob: float,
+        mask_percentage: float,
         mask_length: int,
-        min_masks: int = 0,
     ):
         """
         :param input_dim: feature dimension of the input
         :param mask_replace_val: how to replace masked frames, either with zeros or with learnable embeddings
-        :param mask_prob: percentage of frames to be masked out
+        :param mask_percentage: percentage of frames to be masked out
         :param mask_length: the length of each mask span
-        :param min_masks: minimum number of masks
         """
         super().__init__()
 
@@ -37,9 +35,8 @@ def __init__(
             self.mask_emb = nn.Parameter(torch.FloatTensor(input_dim).uniform_())
         elif mask_replace_val == "zero":
             self.mask_emb = torch.zeros(input_dim)
-        self.mask_prob = mask_prob
+        self.mask_percentage = mask_percentage
         self.mask_length = mask_length
-        self.min_masks = min_masks
 
     def forward(
         self,
@@ -60,10 +57,9 @@ def forward(
 
             num_mask = int(
                 # add a random number for probabilistic rounding
-                self.mask_prob * seq_len / float(self.mask_length)
+                self.mask_percentage * seq_len / float(self.mask_length)
                 + np.random.rand()
             )
-            num_mask = max(self.min_masks, num_mask)
 
             min_len = self.mask_length
             if seq_len - min_len <= num_mask:

From b50b66eba5758de555d893c307f377e959d8ae90 Mon Sep 17 00:00:00 2001
From: Judyxujj
Date: Wed, 4 Dec 2024 22:59:31 +0800
Subject: [PATCH 4/5] Update i6_models/parts/best_rq/mask.py

Co-authored-by: michelwi
---
 i6_models/parts/best_rq/mask.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/i6_models/parts/best_rq/mask.py b/i6_models/parts/best_rq/mask.py
index 85346ebf..d830e702 100644
--- a/i6_models/parts/best_rq/mask.py
+++ b/i6_models/parts/best_rq/mask.py
@@ -66,13 +66,8 @@ def forward(
                 min_len = seq_len - num_mask - 1
             mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)
 
-            mask_idc = np.asarray(
-                [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(self.mask_length)]
-            )
-            mask_idcs.append(mask_idc)
-
-        for i, mask_idc in enumerate(mask_idcs):
-            mask[i, mask_idc] = True
+            for j in mask_idc:
+                mask[i, j : j+self.mask_length] = True
 
         tensor[mask] = self.mask_emb.to(tensor.device)
 
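The slice-based masking introduced in the commit above is both simpler and safer than expanding explicit index lists: a slice like mask[i, j : j + mask_length] silently clips at the end of the time axis, whereas out-of-range entries in an index tensor would raise an IndexError. A small standalone sketch of the boundary behaviour, with made-up sizes:

import torch

mask = torch.zeros((1, 100), dtype=torch.bool)
mask_length = 10
mask[0, 95 : 95 + mask_length] = True  # slice clips at index 99
assert int(mask.sum()) == 5
# index-based assignment over the same span, e.g. mask[0, torch.arange(95, 105)] = True,
# would raise an IndexError instead of clipping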
From 40c369bf6a44718b3e5b2a69f0c3ab64f3fbb539 Mon Sep 17 00:00:00 2001
From: Jingjing Xu
Date: Wed, 4 Dec 2024 15:08:30 +0000
Subject: [PATCH 5/5] black

---
 i6_models/parts/best_rq/mask.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/i6_models/parts/best_rq/mask.py b/i6_models/parts/best_rq/mask.py
index d830e702..2e88479b 100644
--- a/i6_models/parts/best_rq/mask.py
+++ b/i6_models/parts/best_rq/mask.py
@@ -67,7 +67,7 @@ def forward(
             mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)
 
             for j in mask_idc:
-                mask[i, j : j+self.mask_length] = True
+                mask[i, j : j + self.mask_length] = True
 
         tensor[mask] = self.mask_emb.to(tensor.device)
 
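Taken together, the two parts cover the data side of BestRQ pre-training: the quantizer turns the clean features into discrete targets, and the masker hides spans of the input that an encoder is then trained to classify. A minimal end-to-end sketch of how the final versions of the modules compose — the linear layer stands in for a real encoder and output projection, and all sizes are illustrative, not taken from the patches:

import torch
import torch.nn as nn
from i6_models.parts.best_rq import RandomMask, RandomProjectionQuantizer

features = torch.randn(4, 100, 80)  # [batch, time, feature]
padding_mask = torch.zeros(4, 100, dtype=torch.bool)  # True marks padded frames

quantizer = RandomProjectionQuantizer(input_dim=80, codebook_dim=16, codebook_num_vars=8192)
masker = RandomMask(input_dim=80, mask_replace_val="learnable", mask_percentage=0.15, mask_length=10)

# targets come from the unmasked features; clone because forward writes into the tensor in place
targets = quantizer(features)  # [batch, time] codebook indices
masked_features, mask = masker(features.clone(), padding_mask)  # mask: [batch, time] bool

# stand-in for a Conformer-style encoder plus classification head
logits = nn.Linear(80, 8192)(masked_features)

# BestRQ-style objective: predict the codebook index at the masked positions only
loss = nn.functional.cross_entropy(logits[mask], targets[mask])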