From 803416e65ceac39741b5bb0001acc45648a65984 Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 20 Feb 2024 00:30:30 +0800 Subject: [PATCH] support attention mask bias in encocder --- test/wenet/transformer/test_attention.py | 118 ++++++++++++----------- wenet/transformer/encoder_layer.py | 4 - 2 files changed, 62 insertions(+), 60 deletions(-) diff --git a/test/wenet/transformer/test_attention.py b/test/wenet/transformer/test_attention.py index 13f5fce2c0..99cc415a9f 100644 --- a/test/wenet/transformer/test_attention.py +++ b/test/wenet/transformer/test_attention.py @@ -12,28 +12,30 @@ from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask -@pytest.mark.parametrize("args", [ - { - "n_feat": 256, - "n_head": 4, - "dropout_rate": 0.0 - }, - { - "n_feat": 512, - "n_head": 8, - "dropout_rate": 0.0 - }, - { - "n_feat": 1280, - "n_head": 20, - "dropout_rate": 0.0 - }, - { - "n_feat": 512, - "n_head": 4, - "dropout_rate": 0.0 - }, -]) +@pytest.mark.parametrize( + "args", + [ + { + "n_feat": 256, + "n_head": 4, + "dropout_rate": 0.0 + }, + # { + # "n_feat": 512, + # "n_head": 8, + # "dropout_rate": 0.0 + # }, + # { + # "n_feat": 1280, + # "n_head": 20, + # "dropout_rate": 0.0 + # }, + # { + # "n_feat": 512, + # "n_head": 4, + # "dropout_rate": 0.0 + # }, + ]) def test_sdpa(args): torch.manual_seed(777) mha_module = MultiHeadedAttention(use_sdpa=False, **args) @@ -59,7 +61,6 @@ def test_sdpa(args): att_mask_bias = (1.0 - att_mask.float()) * torch.finfo(torch.float).min output_with_sdpa, cache_with_sdpa = mha_module_with_sdpa( q, k, v, mask=att_mask_bias) - assert torch.allclose( output * mask.transpose(1, 2), output_with_sdpa * mask.transpose(1, 2), @@ -67,44 +68,49 @@ def test_sdpa(args): ) assert torch.allclose(cache, cache_with_sdpa) + n_blocks = 12 torch.manual_seed(777) - mha_layer = TransformerEncoderLayer( - args['n_feat'], - mha_module, - PositionwiseFeedForward( + mha_layers = [ + TransformerEncoderLayer( args['n_feat'], - 2048, + MultiHeadedAttention(use_sdpa=False, **args), + PositionwiseFeedForward( + args['n_feat'], + 2048, + 0.0, + WENET_ACTIVATION_CLASSES['swish'](), + ), 0.0, - WENET_ACTIVATION_CLASSES['swish'](), - ), - 0.0, - normalize_before=True, - ) + normalize_before=True, + ) for _ in range(n_blocks) + ] torch.manual_seed(777) - mha_layer_with_sdpa = TransformerEncoderLayer( - args['n_feat'], - mha_module_with_sdpa, - PositionwiseFeedForward( + mha_layers_with_sdpa = [ + TransformerEncoderLayer( args['n_feat'], - 2048, + MultiHeadedAttention(use_sdpa=True, **args), + PositionwiseFeedForward( + args['n_feat'], + 2048, + 0.0, + WENET_ACTIVATION_CLASSES['swish'](), + ), 0.0, - WENET_ACTIVATION_CLASSES['swish'](), - ), - 0.0, - normalize_before=True, - ) - mha_layer.eval() - mha_layer_with_sdpa.eval() - output, _, cache, _ = mha_layer(q, att_mask, None, mask) - output_with_sdpa, _, cache_with_sdpa, _ = mha_layer_with_sdpa( - q, att_mask_bias, None, mask) + normalize_before=True, + ) for _ in range(n_blocks) + ] - print(output) - print(output_with_sdpa) - assert torch.allclose( - output, - output_with_sdpa, - atol=9e-7, - ) - assert torch.allclose(cache, cache_with_sdpa) + for i in range(n_blocks): + output, _, cache, _ = mha_layers[i](q, att_mask, None, mask) + output_with_sdpa, _, cache_with_sdpa, _ = mha_layers_with_sdpa[i]( + q, att_mask_bias, None, mask) + + assert torch.allclose( + output * mask.transpose(1, 2), + output_with_sdpa * mask.transpose(1, 2), + atol=9e-7, + ) + assert torch.allclose(cache, cache_with_sdpa) + + q = output diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py index c517864fd4..aafcec4123 100644 --- a/wenet/transformer/encoder_layer.py +++ b/wenet/transformer/encoder_layer.py @@ -91,8 +91,6 @@ def forward( if self.normalize_before: x = self.norm1(x) x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache) - if mask_pad.size(2) > 0: - x_att = x_att * mask_pad.transpose(1, 2) x = residual + self.dropout(x_att) if not self.normalize_before: x = self.norm1(x) @@ -206,8 +204,6 @@ def forward( x = self.norm_mha(x) x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) - if mask_pad.size(2) > 0: - x_att = x_att * mask_pad.transpose(1, 2) x = residual + self.dropout(x_att) if not self.normalize_before: x = self.norm_mha(x)