[w2vbert] finetune: make finetune run #2409

Draft · wants to merge 6 commits into main
5 changes: 3 additions & 2 deletions test/wenet/transformer/test_attention.py
@@ -3,7 +3,7 @@
from wenet.transformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention,
ShawRelPositionMultiHeadedAttention)
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.embedding import RelPositionalEncoding, W2vbertPositionalEncoding
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
TransformerEncoderLayer)
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
@@ -233,10 +233,11 @@ def test_shaw_rel_position_multihead_attention():
256,
0.0,
use_sdpa=True)
pos_emb_class = W2vbertPositionalEncoding(256, 0.1, 5000)
q = torch.rand(2, 10, 256)
k = torch.rand(2, 10, 256)
v = torch.rand(2, 10, 256)
pos_emb = torch.zeros(0, 0, 0)
q, pos_emb = pos_emb_class(q)
mask = torch.ones(2, 10, 10)
out, _ = module(q, k, v, mask, pos_emb)
out_sdpa, _ = module_sdpa(q, k, v, mask, pos_emb)
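For reference, a condensed sketch of what the updated test now exercises: instead of passing an empty `pos_emb`, the relative-position table comes from the new `W2vbertPositionalEncoding`. This assumes the branch in this PR; the head count of 8 is an assumption, since the constructor arguments are only partially visible in the hunk.

```python
import torch

from wenet.transformer.attention import ShawRelPositionMultiHeadedAttention
from wenet.transformer.embedding import W2vbertPositionalEncoding

# n_head=8 is assumed; the hunk only shows n_feat=256 and dropout 0.0.
module = ShawRelPositionMultiHeadedAttention(8, 256, 0.0)
pos_emb_class = W2vbertPositionalEncoding(256, 0.1, 5000)

q = torch.rand(2, 10, 256)
k = torch.rand(2, 10, 256)
v = torch.rand(2, 10, 256)
q, pos_emb = pos_emb_class(q)   # pos_emb: (5000, 5000) matrix of relative indices
mask = torch.ones(2, 10, 10)

out, _ = module(q, k, v, mask, pos_emb)
print(out.shape)                # torch.Size([2, 10, 256])
```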
15 changes: 11 additions & 4 deletions wenet/ssl/w2vbert/convert_w2vbert_to_wenet_config_and_ckpt.py
@@ -43,7 +43,8 @@ def convert_to_wenet_yaml(wenet_yaml_path: str):
# TODO(Mddct): To use whisper's decoder here
configs['decoder'] = 'transformer'
configs['decoder_conf'] = {}
configs['decoder_conf']['attention_head'] = 16
configs['decoder_conf']['attention_heads'] = 16
configs['encoder_conf']['gradient_checkpointing'] = True
configs['decoder_conf']['linear_units'] = 4096
configs['decoder_conf']['num_blocks'] = 6
configs['decoder_conf']['dropout_rate'] = 0.1
@@ -91,14 +92,20 @@ def convert_to_wenet_yaml(wenet_yaml_path: str):
configs['dataset_conf']['sort_conf'] = {}
configs['dataset_conf']['sort_conf']['sort_size'] = 500
configs['dataset_conf']['feats_type'] = "fbank"
configs['dataset_conf']['fbank_conf']['num_mel_bins'] = 80
configs['dataset_conf']['fbank_conf']['frame_shift'] = 10
configs['dataset_conf']['fbank_conf']['frame_length'] = 25
configs['dataset_conf']['fbank_conf']['dither'] = 0.1

configs['dataset_conf']['batch_conf'] = {}
configs['dataset_conf']['batch_conf']['batch_type'] = 'dynamic'
configs['dataset_conf']['batch_conf']['batch_size'] = 26
configs['dataset_conf']['batch_conf']['max_frames_in_batch'] = 12000

# TODO: Tokenizer or not

configs['grad_clip'] = 5
configs['accum_grad'] = 4
configs['max_epoch'] = 100
configs['accum_grad'] = 1
configs['max_epoch'] = 40
configs['log_interval'] = 100

configs['optim'] = "adam"
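Read together, the hunks above fill in the emitted fine-tuning config: explicit fbank settings, a `max_frames_in_batch` cap for dynamic batching, and (reading each duplicated pair as old line followed by new line) `accum_grad` going from 4 to 1 and `max_epoch` from 100 to 40. A sketch of the resulting fields as the script would set them; the enclosing dict initialization is assumed, since it sits outside the hunk:

```python
# Sketch of the fields set in this hunk; values are taken from the diff,
# the enclosing dict structure is assumed (fbank_conf is initialized elsewhere).
configs = {'dataset_conf': {'fbank_conf': {}}}

configs['dataset_conf']['feats_type'] = 'fbank'
configs['dataset_conf']['fbank_conf'].update({
    'num_mel_bins': 80,    # 80-dim log-mel filterbanks
    'frame_shift': 10,     # ms
    'frame_length': 25,    # ms
    'dither': 0.1,
})
configs['dataset_conf']['batch_conf'] = {
    'batch_type': 'dynamic',           # batch by total frames, not utterance count
    'max_frames_in_batch': 12000,
}
configs['grad_clip'] = 5
configs['accum_grad'] = 1              # previously 4, per the adjacent removed line
configs['max_epoch'] = 40              # previously 100, per the adjacent removed line
configs['log_interval'] = 100
```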
12 changes: 2 additions & 10 deletions wenet/transformer/attention.py
@@ -545,13 +545,6 @@ def __init__(self,
self.rel_k_embed = torch.nn.Embedding(
self.max_left_rel_pos + self.max_right_rel_pos + 1, self.d_k)

def _relative_indices(self, length: int, device: torch.device):
indices = torch.arange(length, device=device).unsqueeze(0)
rel_indices = indices - indices.transpose(0, 1)
rel_indices = torch.clamp(rel_indices, -self.max_left_rel_pos,
self.max_right_rel_pos)
return rel_indices + self.max_left_rel_pos

def forward(
self,
query: torch.Tensor,
@@ -561,7 +554,6 @@ def forward(
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
del pos_emb
q, k, v = self.forward_qkv(query, key, value)
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
@@ -571,8 +563,8 @@ def forward(
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1)

rel_k = self.rel_k_embed(
self._relative_indices(k.size(2), query.device)) # (t2, t2, d_k)
pos_emb = pos_emb[:k.size(2), :k.size(2)]
rel_k = self.rel_k_embed(pos_emb) # (t2, t2, d_k)
rel_k = rel_k[-q.size(2):] # (t1, t2, d_k)
# b,h,t1,dk
rel_k = rel_k.unsqueeze(0).unsqueeze(0) # (1, 1, t1, t2, d_k)
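The net effect of this hunk: `ShawRelPositionMultiHeadedAttention` no longer builds the relative-index matrix itself (and no longer discards `pos_emb`); it now expects `pos_emb` to already hold the clamped relative indices and simply slices it to the key length before the embedding lookup. A minimal standalone sketch of that lookup, with hypothetical sizes:

```python
import torch

# Hypothetical sizes for illustration.
d_k, t2 = 4, 6
max_left_rel_pos, max_right_rel_pos = 8, 64
rel_k_embed = torch.nn.Embedding(max_left_rel_pos + max_right_rel_pos + 1, d_k)

# pos_emb as W2vbertPositionalEncoding would provide it: clamped, shifted
# relative offsets of shape (max_len, max_len).
indices = torch.arange(10).unsqueeze(0)
pos_emb = torch.clamp(indices - indices.transpose(0, 1),
                      -max_left_rel_pos, max_right_rel_pos) + max_left_rel_pos

# What the rewritten forward() now does:
pos_emb = pos_emb[:t2, :t2]        # slice to the key length k.size(2)
rel_k = rel_k_embed(pos_emb)       # (t2, t2, d_k)
print(rel_k.shape)                 # torch.Size([6, 6, 4])
```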
36 changes: 36 additions & 0 deletions wenet/transformer/embedding.py
@@ -234,3 +234,39 @@ def dropout_complex(self, x):
            p=self.dropout_rate,
        )
        return x * mask


class W2vbertPositionalEncoding(PositionalEncoding):

    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        super().__init__(d_model, dropout_rate, max_len, reverse)
        delattr(self, 'pe')
        self.max_right_rel_pos = 64
        self.max_left_rel_pos = 8
        pe = self._relative_indices(max_len)
        self.register_buffer('pe', pe)

    def _relative_indices(self, length: int):
        indices = torch.arange(length).unsqueeze(0)
        rel_indices = indices - indices.transpose(0, 1)
        rel_indices = torch.clamp(rel_indices, -self.max_left_rel_pos,
                                  self.max_right_rel_pos)
        return rel_indices + self.max_left_rel_pos

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        return self.pe

    def forward(
        self,
        x: torch.Tensor,
        offset: Union[int,
                      torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        _ = offset
        return x, self.pe
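As a sanity check on the new class, the registered `pe` buffer is not a sinusoidal table but a matrix of clamped relative offsets, reproducible with plain torch using the defaults from the diff (`max_left_rel_pos=8`, `max_right_rel_pos=64`):

```python
import torch

max_left_rel_pos, max_right_rel_pos, max_len = 8, 64, 5000

indices = torch.arange(max_len).unsqueeze(0)
pe = torch.clamp(indices - indices.transpose(0, 1),
                 -max_left_rel_pos, max_right_rel_pos) + max_left_rel_pos

print(pe.shape)                           # torch.Size([5000, 5000])
print(pe.min().item(), pe.max().item())   # 0 72
print(pe[0, :4])                          # tensor([ 8,  9, 10, 11])
```

Every entry lies in [0, 72], which matches the 73-row `rel_k_embed` table (`max_left_rel_pos + max_right_rel_pos + 1`) kept in the attention module above.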
27 changes: 13 additions & 14 deletions wenet/transformer/subsampling.py
@@ -373,18 +373,17 @@ def forward(
where time' = time // stride.
torch.Tensor: positional encoding
"""
with torch.no_grad():
b, s, _ = x.size()

seq_len = x_mask.sum(-1).view(b)
r = s % self.stride
s -= r
x = x[:, :s, :]
seq_len = torch.where(seq_len > s, s, seq_len)
seq_len = seq_len // self.stride
new_mask = ~make_pad_mask(seq_len, max_len=s // self.stride)
x = x.view(b, s // self.stride, self.idim * self.stride)
_, pos_emb = self.pos_enc_class(x, offset)
x = self.norm(x)
x = self.out(x)
b, s, _ = x.size()

seq_len = x_mask.sum(-1).view(b)
r = s % self.stride
s -= r
x = x[:, :s, :]
seq_len = torch.where(seq_len > s, s, seq_len)
seq_len = seq_len // self.stride
new_mask = ~make_pad_mask(seq_len, max_len=s // self.stride)
x = x.view(b, s // self.stride, self.idim * self.stride)
_, pos_emb = self.pos_enc_class(x, offset)
x = self.norm(x)
x = self.out(x)
return x, pos_emb, new_mask.unsqueeze(1)
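The duplicated-looking lines above differ only in indentation: the body used to sit under `with torch.no_grad():`, and that wrapper is removed, so the frame-stacking projection and its norm now receive gradients during fine-tuning. A tiny self-contained illustration of why that matters, with hypothetical sizes (`idim=80`, `stride=4`, `odim=256`):

```python
import torch

out = torch.nn.Linear(80 * 4, 256)   # stands in for self.out
x = torch.rand(2, 12, 80)            # (batch, time, idim)

b, s, _ = x.size()
s -= s % 4                                   # trim time to a multiple of stride
x = x[:, :s, :].reshape(b, s // 4, 80 * 4)   # stack `stride` frames

loss = out(x).sum()
loss.backward()                              # possible only without torch.no_grad()
print(out.weight.grad is not None)           # True -> the layer is trainable
```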
11 changes: 5 additions & 6 deletions wenet/utils/class_utils.py
@@ -20,12 +20,10 @@
)
from wenet.efficient_conformer.subsampling import Conv2dSubsampling2
from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4
from wenet.transformer.embedding import (PositionalEncoding,
RelPositionalEncoding,
RopePositionalEncoding,
WhisperPositionalEncoding,
LearnablePositionalEncoding,
NoPositionalEncoding)
from wenet.transformer.embedding import (
PositionalEncoding, RelPositionalEncoding, RopePositionalEncoding,
W2vbertPositionalEncoding, WhisperPositionalEncoding,
LearnablePositionalEncoding, NoPositionalEncoding)
from wenet.transformer.attention import (MultiHeadedAttention,
MultiHeadedCrossAttention,
RelPositionMultiHeadedAttention,
@@ -71,6 +69,7 @@
"embed_learnable_pe": LearnablePositionalEncoding,
"abs_pos_paraformer": ParaformerPositinoalEncoding,
'rope_pos': RopePositionalEncoding,
"w2vbert_pos": W2vbertPositionalEncoding,
}

WENET_ATTENTION_CLASSES = {
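With the registry entry added, the new encoding is selectable from a config string like the other position-encoding types. A minimal sketch of the lookup, assuming this branch is installed and the usual wenet convention that `pos_enc_layer_type` names a key in `WENET_EMB_CLASSES`:

```python
from wenet.utils.class_utils import WENET_EMB_CLASSES

# Assumed encoder_conf fragment; key names follow wenet's existing convention.
encoder_conf = {
    'pos_enc_layer_type': 'w2vbert_pos',
    'output_size': 256,
    'positional_dropout_rate': 0.1,
}

pos_enc = WENET_EMB_CLASSES[encoder_conf['pos_enc_layer_type']](
    encoder_conf['output_size'],
    encoder_conf['positional_dropout_rate'],
)
print(type(pos_enc).__name__)   # W2vbertPositionalEncoding
```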