diff --git a/test/wenet/transformer/test_attention.py b/test/wenet/transformer/test_attention.py
index ca7e63d92..8be823664 100644
--- a/test/wenet/transformer/test_attention.py
+++ b/test/wenet/transformer/test_attention.py
@@ -3,7 +3,7 @@
 from wenet.transformer.attention import (MultiHeadedAttention,
                                          RelPositionMultiHeadedAttention,
                                          ShawRelPositionMultiHeadedAttention)
-from wenet.transformer.embedding import RelPositionalEncoding
+from wenet.transformer.embedding import RelPositionalEncoding, W2vbertPositionalEncoding
 from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
                                              TransformerEncoderLayer)
 from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
@@ -233,10 +233,11 @@ def test_shaw_rel_position_multihead_attention():
                                                       256,
                                                       0.0,
                                                       use_sdpa=True)
+    pos_emb_class = W2vbertPositionalEncoding(256, 0.1, 5000)
     q = torch.rand(2, 10, 256)
     k = torch.rand(2, 10, 256)
     v = torch.rand(2, 10, 256)
-    pos_emb = torch.zeros(0, 0, 0)
+    q, pos_emb = pos_emb_class(q)
     mask = torch.ones(2, 10, 10)
     out, _ = module(q, k, v, mask, pos_emb)
     out_sdpa, _ = module_sdpa(q, k, v, mask, pos_emb)
diff --git a/wenet/ssl/w2vbert/convert_w2vbert_to_wenet_config_and_ckpt.py b/wenet/ssl/w2vbert/convert_w2vbert_to_wenet_config_and_ckpt.py
index 1dcf128c3..1dbd79658 100644
--- a/wenet/ssl/w2vbert/convert_w2vbert_to_wenet_config_and_ckpt.py
+++ b/wenet/ssl/w2vbert/convert_w2vbert_to_wenet_config_and_ckpt.py
@@ -43,7 +43,8 @@ def convert_to_wenet_yaml(wenet_yaml_path: str):
     # TODO(Mddct): To use whisper's decoder here
     configs['decoder'] = 'transformer'
     configs['decoder_conf'] = {}
-    configs['decoder_conf']['attention_head'] = 16
+    configs['decoder_conf']['attention_heads'] = 16
+    configs['encoder_conf']['gradient_checkpointing'] = True
     configs['decoder_conf']['linear_units'] = 4096
     configs['decoder_conf']['num_blocks'] = 6
     configs['decoder_conf']['dropout_rate'] = 0.1
@@ -91,14 +92,20 @@ def convert_to_wenet_yaml(wenet_yaml_path: str):
     configs['dataset_conf']['sort_conf'] = {}
     configs['dataset_conf']['sort_conf']['sort_size'] = 500
     configs['dataset_conf']['feats_type'] = "fbank"
+    configs['dataset_conf']['fbank_conf']['num_mel_bins'] = 80
+    configs['dataset_conf']['fbank_conf']['frame_shift'] = 10
+    configs['dataset_conf']['fbank_conf']['frame_length'] = 25
+    configs['dataset_conf']['fbank_conf']['dither'] = 0.1
+
     configs['dataset_conf']['batch_conf'] = {}
     configs['dataset_conf']['batch_conf']['batch_type'] = 'dynamic'
-    configs['dataset_conf']['batch_conf']['batch_size'] = 26
     configs['dataset_conf']['batch_conf']['max_frames_in_batch'] = 12000
 
+    # TODO: Tokenizer or not
+
     configs['grad_clip'] = 5
-    configs['accum_grad'] = 4
-    configs['max_epoch'] = 100
+    configs['accum_grad'] = 1
+    configs['max_epoch'] = 40
     configs['log_interval'] = 100
 
     configs['optim'] = "adam"
diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index 54b76daad..eb16affeb 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -545,13 +545,6 @@ def __init__(self,
         self.rel_k_embed = torch.nn.Embedding(
             self.max_left_rel_pos + self.max_right_rel_pos + 1, self.d_k)
 
-    def _relative_indices(self, length: int, device: torch.device):
-        indices = torch.arange(length, device=device).unsqueeze(0)
-        rel_indices = indices - indices.transpose(0, 1)
-        rel_indices = torch.clamp(rel_indices, -self.max_left_rel_pos,
-                                  self.max_right_rel_pos)
-        return rel_indices + self.max_left_rel_pos
-
     def forward(
         self,
         query: torch.Tensor,
@@ -561,7 +554,6 @@ def forward(
         pos_emb: torch.Tensor = torch.empty(0),
         cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        del pos_emb
         q, k, v = self.forward_qkv(query, key, value)
         if cache.size(0) > 0:
             key_cache, value_cache = torch.split(cache,
@@ -571,8 +563,8 @@ def forward(
             v = torch.cat([value_cache, v], dim=2)
         new_cache = torch.cat((k, v), dim=-1)
 
-        rel_k = self.rel_k_embed(
-            self._relative_indices(k.size(2), query.device))  # (t2, t2, d_k)
+        pos_emb = pos_emb[:k.size(2), :k.size(2)]
+        rel_k = self.rel_k_embed(pos_emb)  # (t2, t2, d_k)
         rel_k = rel_k[-q.size(2):]  # (t1, t2, d_k)
         # b,h,t1,dk
         rel_k = rel_k.unsqueeze(0).unsqueeze(0)  # (1, 1, t1, t2, d_k)
diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py
index 2efa2f5fd..fcdc79f93 100644
--- a/wenet/transformer/embedding.py
+++ b/wenet/transformer/embedding.py
@@ -234,3 +234,39 @@ def dropout_complex(self, x):
             p=self.dropout_rate,
         )
         return x * mask
+
+
+class W2vbertPositionalEncoding(PositionalEncoding):
+
+    def __init__(self,
+                 d_model: int,
+                 dropout_rate: float,
+                 max_len: int = 5000,
+                 reverse: bool = False):
+        super().__init__(d_model, dropout_rate, max_len, reverse)
+        delattr(self, 'pe')
+        self.max_right_rel_pos = 64
+        self.max_left_rel_pos = 8
+        pe = self._relative_indices(max_len)
+        self.register_buffer('pe', pe)
+
+    def _relative_indices(self, length: torch.Tensor):
+        indices = torch.arange(length).unsqueeze(0)
+        rel_indices = indices - indices.transpose(0, 1)
+        rel_indices = torch.clamp(rel_indices, -self.max_left_rel_pos,
+                                  self.max_right_rel_pos)
+        return rel_indices + self.max_left_rel_pos
+
+    def position_encoding(self,
+                          offset: Union[int, torch.Tensor],
+                          size: int,
+                          apply_dropout: bool = True) -> torch.Tensor:
+        return self.pe
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        offset: Union[int,
+                      torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+        _ = offset
+        return x, self.pe
diff --git a/wenet/transformer/subsampling.py b/wenet/transformer/subsampling.py
index 5619e8cf7..aedb9ebf3 100644
--- a/wenet/transformer/subsampling.py
+++ b/wenet/transformer/subsampling.py
@@ -373,18 +373,17 @@ def forward(
                 where time' = time // stride.
             torch.Tensor: positional encoding
         """
-        with torch.no_grad():
-            b, s, _ = x.size()
-
-            seq_len = x_mask.sum(-1).view(b)
-            r = s % self.stride
-            s -= r
-            x = x[:, :s, :]
-            seq_len = torch.where(seq_len > s, s, seq_len)
-            seq_len = seq_len // self.stride
-            new_mask = ~make_pad_mask(seq_len, max_len=s // self.stride)
-            x = x.view(b, s // self.stride, self.idim * self.stride)
-            _, pos_emb = self.pos_enc_class(x, offset)
-            x = self.norm(x)
-            x = self.out(x)
+        b, s, _ = x.size()
+
+        seq_len = x_mask.sum(-1).view(b)
+        r = s % self.stride
+        s -= r
+        x = x[:, :s, :]
+        seq_len = torch.where(seq_len > s, s, seq_len)
+        seq_len = seq_len // self.stride
+        new_mask = ~make_pad_mask(seq_len, max_len=s // self.stride)
+        x = x.view(b, s // self.stride, self.idim * self.stride)
+        _, pos_emb = self.pos_enc_class(x, offset)
+        x = self.norm(x)
+        x = self.out(x)
         return x, pos_emb, new_mask.unsqueeze(1)
diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py
index 7800ec7c8..27c832122 100644
--- a/wenet/utils/class_utils.py
+++ b/wenet/utils/class_utils.py
@@ -20,12 +20,10 @@
 )
 from wenet.efficient_conformer.subsampling import Conv2dSubsampling2
 from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4
-from wenet.transformer.embedding import (PositionalEncoding,
-                                         RelPositionalEncoding,
-                                         RopePositionalEncoding,
-                                         WhisperPositionalEncoding,
-                                         LearnablePositionalEncoding,
-                                         NoPositionalEncoding)
+from wenet.transformer.embedding import (
+    PositionalEncoding, RelPositionalEncoding, RopePositionalEncoding,
+    W2vbertPositionalEncoding, WhisperPositionalEncoding,
+    LearnablePositionalEncoding, NoPositionalEncoding)
 from wenet.transformer.attention import (MultiHeadedAttention,
                                          MultiHeadedCrossAttention,
                                          RelPositionMultiHeadedAttention,
@@ -71,6 +69,7 @@
     "embed_learnable_pe": LearnablePositionalEncoding,
     "abs_pos_paraformer": ParaformerPositinoalEncoding,
     'rope_pos': RopePositionalEncoding,
+    "w2vbert_pos": W2vbertPositionalEncoding,
 }
 
 WENET_ATTENTION_CLASSES = {
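
A minimal standalone sketch of how the new pieces fit together, mirroring the updated test above: W2vbertPositionalEncoding carries the clipped relative-index table that ShawRelPositionMultiHeadedAttention now receives as pos_emb instead of recomputing it internally. The head count of 4 is an arbitrary choice for illustration, and the attention constructor is assumed to follow the usual wenet (n_head, n_feat, dropout_rate) signature.

```python
import torch

from wenet.transformer.attention import ShawRelPositionMultiHeadedAttention
from wenet.transformer.embedding import W2vbertPositionalEncoding

# The encoding class no longer stores sinusoids: its `pe` buffer is a
# (max_len, max_len) table of clipped, shifted relative indices.
pos_emb_class = W2vbertPositionalEncoding(256, dropout_rate=0.1, max_len=5000)
# Assumed signature (n_head, n_feat, dropout_rate); 4 heads is arbitrary here.
attention = ShawRelPositionMultiHeadedAttention(4, 256, 0.0)

q = torch.rand(2, 10, 256)
k = torch.rand(2, 10, 256)
v = torch.rand(2, 10, 256)
mask = torch.ones(2, 10, 10)

# forward() returns the input unchanged plus the relative-index table;
# the attention layer slices it to (t2, t2) and feeds it to rel_k_embed.
q, pos_emb = pos_emb_class(q)
out, _ = attention(q, k, v, mask, pos_emb)
print(out.shape)  # torch.Size([2, 10, 256])
```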