diff --git a/requirements.txt b/requirements.txt
index e9d17ea2f..accad47d9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,8 +13,8 @@ flake8-pyi==20.5.0
 mccabe
 pycodestyle==2.6.0
 pyflakes==2.2.0
-torch==2.1.2
-torchaudio==2.1.2
+torch>=2.1.2
+torchaudio>=2.1.2
 tqdm
 deepspeed<0.13.0
 librosa
diff --git a/test/wenet/ssl/w2vbert/test_w2vbert.py b/test/wenet/ssl/w2vbert/test_w2vbert.py
deleted file mode 100644
index e26052d23..000000000
--- a/test/wenet/ssl/w2vbert/test_w2vbert.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from pathlib import Path
-import pytest
-import torch
-import torchaudio
-
-from wenet.dataset import processor
-
-try:
-    import fairseq2  # noqa
-    from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
-    from fairseq2.memory import MemoryBlock
-except ImportError:
-    import os
-    os.system('pip install --no-input fairseq2')
-    import fairseq2  # noqa
-    from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
-    from fairseq2.memory import MemoryBlock
-
-
-@pytest.mark.parametrize(
-    "wav_file",
-    [
-        # "test/resources/aishell-BAC009S0724W0121.wav",
-        "test/resources/librispeech-1995-1837-0001.wav",
-    ])
-def test_w2vbert_fbank(wav_file):
-    fbank_convert = WaveformToFbankConverter(
-        num_mel_bins=80,
-        waveform_scale=2**15,
-        channel_last=True,
-        standardize=True,
-    )
-    audio_decoder = AudioDecoder(dtype=torch.float32)
-    with Path(wav_file).open("rb") as fb:
-        block = MemoryBlock(fb.read())
-    decode_audio = audio_decoder(block)
-    w2vbert_waveform = decode_audio['waveform']
-    w2vbert_mat = fbank_convert(decode_audio)['fbank']
-
-    wenet_waveform, _ = torchaudio.load(wav_file)
-    fbank_args = {
-        "num_mel_bins": 80,
-        "frame_length": 25,
-        "frame_shift": 10,
-        "dither": 0.0,
-    }
-    sample = {'sample_rate': 16000, "wav": wenet_waveform, 'key': wav_file}
-    wenet_mat = processor.compute_w2vbert_fbank(sample, **fbank_args)['feat']
-    assert torch.allclose(w2vbert_waveform.transpose(0, 1), wenet_waveform)
-    assert torch.allclose(w2vbert_mat, wenet_mat, atol=9e-5, rtol=9e-4)
diff --git a/test/wenet/transformer/test_attention.py b/test/wenet/transformer/test_attention.py
new file mode 100644
index 000000000..f1ea491db
--- /dev/null
+++ b/test/wenet/transformer/test_attention.py
@@ -0,0 +1,111 @@
+import torch
+import pytest
+from wenet.transformer.attention import MultiHeadedAttention
+from wenet.transformer.encoder_layer import TransformerEncoderLayer
+from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES
+
+from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask
+
+
+@pytest.mark.parametrize("args", [
+    {
+        "n_feat": 256,
+        "n_head": 4,
+        "dropout_rate": 0.0
+    },
+    {
+        "n_feat": 512,
+        "n_head": 8,
+        "dropout_rate": 0.0
+    },
+    {
+        "n_feat": 1280,
+        "n_head": 20,
+        "dropout_rate": 0.0
+    },
+    {
+        "n_feat": 512,
+        "n_head": 4,
+        "dropout_rate": 0.0
+    },
+])
+def test_sdpa(args):
+    torch.manual_seed(777)
+    mha_module = MultiHeadedAttention(use_sdpa=False, **args)
+    torch.manual_seed(777)
+    mha_module_with_sdpa = MultiHeadedAttention(use_sdpa=True, **args)
+    mha_module.eval()
+    mha_module_with_sdpa.eval()
+
+    q = torch.rand(10, 100, args['n_feat'])
+    k = torch.rand(10, 100, args['n_feat'])
+    v = torch.rand(10, 100, args['n_feat'])
+    input_lens = torch.tensor([100, 90, 80, 79, 60, 51, 40, 30, 10, 5])
+    mask = make_non_pad_mask(input_lens).unsqueeze(1)
+    att_mask = add_optional_chunk_mask(q,
+                                       mask,
+                                       use_dynamic_chunk=True,
+                                       decoding_chunk_size=0,
+                                       static_chunk_size=0,
+                                       use_dynamic_left_chunk=True,
+                                       num_decoding_left_chunks=-1)
+    output, cache = mha_module(q, k, v, mask=att_mask)
+
+    att_mask_bias = (1.0 - att_mask.float()) * torch.finfo(torch.float).min
+    output_with_sdpa, cache_with_sdpa = mha_module_with_sdpa(
+        q, k, v, mask=att_mask_bias)
+    assert torch.allclose(
+        output * mask.transpose(1, 2),
+        output_with_sdpa * mask.transpose(1, 2),
+        atol=9e-7,
+    )
+    assert torch.allclose(cache, cache_with_sdpa)
+
+    n_blocks = 12
+    torch.manual_seed(777)
+    mha_layers = [
+        TransformerEncoderLayer(
+            args['n_feat'],
+            MultiHeadedAttention(use_sdpa=False, **args),
+            PositionwiseFeedForward(
+                args['n_feat'],
+                2048,
+                0.0,
+                WENET_ACTIVATION_CLASSES['swish'](),
+            ),
+            0.0,
+            normalize_before=True,
+        ) for _ in range(n_blocks)
+    ]
+
+    torch.manual_seed(777)
+    mha_layers_with_sdpa = [
+        TransformerEncoderLayer(
+            args['n_feat'],
+            MultiHeadedAttention(use_sdpa=True, **args),
+            PositionwiseFeedForward(
+                args['n_feat'],
+                2048,
+                0.0,
+                WENET_ACTIVATION_CLASSES['swish'](),
+            ),
+            0.0,
+            normalize_before=True,
+        ) for _ in range(n_blocks)
+    ]
+
+    for i in range(n_blocks):
+        output, _, cache, _ = mha_layers[i](q, att_mask, None, mask)
+        output_with_sdpa, _, cache_with_sdpa, _ = mha_layers_with_sdpa[i](
+            q, att_mask_bias, None, mask)
+
+        assert torch.allclose(
+            output * mask.transpose(1, 2),
+            output_with_sdpa * mask.transpose(1, 2),
+            atol=9e-7,
+            rtol=9e-4,
+        )
+        # assert torch.allclose(cache, cache_with_sdpa)
+
+        q = output
diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index 3b215c10b..836a22cbd 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -36,7 +36,8 @@ def __init__(self,
                  n_head: int,
                  n_feat: int,
                  dropout_rate: float,
-                 key_bias: bool = True):
+                 key_bias: bool = True,
+                 use_sdpa: bool = False):
         """Construct an MultiHeadedAttention object."""
         super().__init__()
         assert n_feat % n_head == 0
@@ -49,6 +50,9 @@
         self.linear_out = nn.Linear(n_feat, n_feat)
         self.dropout = nn.Dropout(p=dropout_rate)
 
+        self.use_sdpa = use_sdpa
+        self.dropout_rate = dropout_rate
+
     def forward_qkv(
         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -192,8 +196,22 @@
         #   non-trivial to calculate `next_cache_start` here.
         new_cache = torch.cat((k, v), dim=-1)
 
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask), new_cache
+        if not self.use_sdpa:
+            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+            return self.forward_attention(v, scores, mask), new_cache
+        else:
+            output = torch.nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask.unsqueeze(1),
+                dropout_p=self.dropout_rate,
+                scale=1 / math.sqrt(self.d_k),
+            )
+            output = (output.transpose(1, 2).contiguous().view(
+                query.size(0), -1,
+                self.h * self.d_k))  # (batch, time1, d_model)
+            return self.linear_out(output), new_cache
 
 
 class RelPositionMultiHeadedAttention(MultiHeadedAttention):
diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index 02c098c8c..ec467ee43 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -26,6 +26,7 @@
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
 )
+from wenet.utils.common import mask_to_bias
 from wenet.utils.mask import (subsequent_mask, make_pad_mask)
 
 
@@ -73,6 +74,7 @@
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
+        use_sdpa: bool = False,
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -98,10 +100,10 @@
                 attention_dim,
                 WENET_ATTENTION_CLASSES["selfattn"](
                     attention_heads, attention_dim,
-                    self_attention_dropout_rate, key_bias),
+                    self_attention_dropout_rate, key_bias, use_sdpa),
                 WENET_ATTENTION_CLASSES["selfattn"](
                     attention_heads, attention_dim, src_attention_dropout_rate,
-                    key_bias) if src_attention else None,
+                    key_bias, use_sdpa) if src_attention else None,
                 PositionwiseFeedForward(attention_dim, linear_units,
                                         dropout_rate, activation),
                 dropout_rate,
@@ -111,6 +113,7 @@
 
         self.gradient_checkpointing = gradient_checkpointing
         self.tie_word_embedding = tie_word_embedding
+        self.use_sdpa = use_sdpa
 
     def forward(
         self,
@@ -152,6 +155,10 @@
                             device=tgt_mask.device).unsqueeze(0)
         # tgt_mask: (B, L, L)
         tgt_mask = tgt_mask & m
+        if self.use_sdpa:
+            tgt_mask = mask_to_bias(tgt_mask, memory.dtype)
+            memory_mask = mask_to_bias(memory_mask, memory.dtype)
+
         x, _ = self.embed(tgt)
         if self.gradient_checkpointing and self.training:
             x = self.forward_layers_checkpointed(x, tgt_mask, memory,
@@ -290,6 +297,7 @@
         key_bias: bool = True,
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
+        use_sdpa: bool = False,
     ):
 
         super().__init__()
@@ -309,7 +317,8 @@
             normalize_before,
             key_bias=key_bias,
             gradient_checkpointing=gradient_checkpointing,
-            tie_word_embedding=tie_word_embedding)
+            tie_word_embedding=tie_word_embedding,
+            use_sdpa=use_sdpa)
 
         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -326,7 +335,8 @@
             normalize_before,
             key_bias=key_bias,
             gradient_checkpointing=gradient_checkpointing,
-            tie_word_embedding=tie_word_embedding)
+            tie_word_embedding=tie_word_embedding,
+            use_sdpa=use_sdpa)
 
     def forward(
         self,
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 894caf59d..b41188fbb 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -31,6 +31,7 @@
 )
 from wenet.utils.mask import make_pad_mask
 from wenet.utils.mask import add_optional_chunk_mask
+from wenet.utils.common import mask_to_bias
 
 
 class BaseEncoder(torch.nn.Module):
@@ -53,6 +54,7 @@ def __init__(
         global_cmvn: torch.nn.Module = None,
         use_dynamic_left_chunk: bool = False,
         gradient_checkpointing: bool = False,
+        use_sdpa: bool = False,
     ):
         """
         Args:
@@ -84,6 +86,7 @@
             key_bias: whether use bias in attention.linear_k, False for whisper models.
             gradient_checkpointing: rerunning a forward-pass segment for each
                 checkpointed segment during backward.
+            use_sdpa: whether to use SDPA; currently only the transformer encoder/decoder support it.
         """
         super().__init__()
         self._output_size = output_size
@@ -103,6 +106,7 @@
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk
         self.gradient_checkpointing = gradient_checkpointing
+        self.use_sdpa = use_sdpa
 
     def output_size(self) -> int:
         return self._output_size
@@ -149,6 +153,8 @@
                                               decoding_chunk_size,
                                               self.static_chunk_size,
                                               num_decoding_left_chunks)
+        if self.use_sdpa:
+            chunk_masks = mask_to_bias(chunk_masks, xs.dtype)
         if self.gradient_checkpointing and self.training:
             xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
                                                   mask_pad)
@@ -355,6 +361,7 @@
         key_bias: bool = True,
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
+        use_sdpa: bool = False,
     ):
         """ Construct TransformerEncoder
 
@@ -365,7 +372,8 @@
                          positional_dropout_rate, attention_dropout_rate,
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
-                         use_dynamic_left_chunk, gradient_checkpointing)
+                         use_dynamic_left_chunk, gradient_checkpointing,
+                         use_sdpa)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
         self.encoders = torch.nn.ModuleList([
             TransformerEncoderLayer(
@@ -373,7 +381,7 @@
                 WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
                                                     output_size,
                                                     attention_dropout_rate,
-                                                    key_bias),
+                                                    key_bias, use_sdpa),
                 PositionwiseFeedForward(output_size, linear_units,
                                         dropout_rate, activation),
                 dropout_rate, normalize_before) for _ in range(num_blocks)
@@ -433,7 +441,7 @@
                          positional_dropout_rate, attention_dropout_rate,
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
-                         use_dynamic_left_chunk, gradient_checkpointing)
+                         use_dynamic_left_chunk, gradient_checkpointing, False)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
 
         # self-attention module definition
diff --git a/wenet/transformer/search.py b/wenet/transformer/search.py
index 7c5efe7c4..958442906 100644
--- a/wenet/transformer/search.py
+++ b/wenet/transformer/search.py
@@ -19,7 +19,8 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence
 
-from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens)
+from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens,
+                                mask_to_bias)
 from wenet.utils.ctc_utils import remove_duplicates_and_blank
 from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                               mask_finished_scores, subsequent_mask)
@@ -289,6 +290,8 @@
     ]).unsqueeze(1).to(device)  # (B*N, 1)
     end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
     cache: Optional[List[torch.Tensor]] = None
+    if model.decoder.use_sdpa:
+        encoder_mask = mask_to_bias(encoder_mask, encoder_out.dtype)
     # 2. Decoder forward step by step
     for i in range(prefix_len, maxlen + 1):
         # Stop if all batch and all beam produce eos
@@ -297,6 +300,8 @@
         # 2.1 Forward decoder step
         hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
             running_size, 1, 1).to(device)  # (B*N, i, i)
+        if model.decoder.use_sdpa:
+            hyps_mask = mask_to_bias(hyps_mask, encoder_out.dtype)
         # logp: (B*N, vocab)
         logp, cache = model.decoder.forward_one_step(encoder_out, encoder_mask,
                                                      hyps, hyps_mask, cache)
diff --git a/wenet/utils/common.py b/wenet/utils/common.py
index 13f216674..c5398de02 100644
--- a/wenet/utils/common.py
+++ b/wenet/utils/common.py
@@ -307,3 +307,32 @@
     a_max = max(args)
     lsp = math.log(sum(math.exp(a - a_max) for a in args))
     return a_max + lsp
+
+
+def get_dtype_min(
+    dtype: torch.dtype,
+    eps16: float = torch.finfo(torch.float16).min,
+    eps32: float = torch.finfo(torch.float32).min,
+    eps64: float = torch.finfo(torch.float64).min,
+    epsbf16: float = torch.finfo(torch.bfloat16).min,
+):
+    if dtype == torch.float16:
+        return eps16
+    elif dtype == torch.float32:
+        return eps32
+    elif dtype == torch.float64:
+        return eps64
+    elif dtype == torch.bfloat16:
+        return epsbf16
+    else:
+        raise RuntimeError(f"expected x to be floating-point, got {dtype}")
+
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    assert mask.dtype == torch.bool
+    mask = mask.to(dtype)
+    # attention mask bias
+    # NOTE(Mddct): torch.finfo jit issues
+    #     chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * get_dtype_min(dtype)
+    return mask
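
Note for reviewers (not part of the patch): the sketch below is a minimal, self-contained check of the masking convention this diff relies on. A boolean "keep" mask is turned into an additive float bias, in the spirit of the new `mask_to_bias` helper in `wenet/utils/common.py`, and passed to `torch.nn.functional.scaled_dot_product_attention`; the output should match an explicit masked-softmax attention. Both `manual_attention` and the inline `mask_to_bias` here are illustrative stand-ins written for this note, not imports from the repo.

```python
import math

import torch


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Illustrative re-implementation of the helper added in this patch:
    # True (keep) -> 0.0, False (masked) -> dtype min, so the result can be
    # added to the attention logits before the softmax.
    assert mask.dtype == torch.bool
    return (1.0 - mask.to(dtype)) * torch.finfo(dtype).min


def manual_attention(q, k, v, mask):
    # Explicit masked-softmax attention, analogous to the non-SDPA branch of
    # MultiHeadedAttention.forward (simplified: no dropout, no projections).
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
    return torch.matmul(torch.softmax(scores, dim=-1), v)


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, heads, time, d_k = 2, 4, 6, 8
    q = torch.rand(batch, heads, time, d_k)
    k = torch.rand(batch, heads, time, d_k)
    v = torch.rand(batch, heads, time, d_k)
    # Boolean mask, True = attend; last two key frames of item 1 are padding.
    mask = torch.ones(batch, 1, time, time, dtype=torch.bool)
    mask[1, :, :, -2:] = False

    ref = manual_attention(q, k, v, mask)
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask_to_bias(mask, q.dtype), dropout_p=0.0)
    print(torch.allclose(ref, out, atol=1e-6))  # expected: True
```

This is the same equivalence that test/wenet/transformer/test_attention.py asserts at module level, just without the wenet-specific dependencies.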