
Commit

zero out pad mask att
Mddct committed Feb 19, 2024
1 parent 6f1fb43 commit 48c83e6
Showing 3 changed files with 52 additions and 56 deletions.
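
In short, this commit makes the encoder layers multiply the self-attention output by the padding mask, so padded frames become exactly zero instead of carrying small nonzero values, and it updates the attention test to check that the eager path and the SDPA path then agree through a full encoder layer. A minimal, self-contained sketch of the zeroing idea (plain PyTorch; shapes and values here are chosen for illustration, not taken from the repo):

import torch

# Illustrative shapes: B=2 utterances, T=4 frames, D=8 features.
x_att = torch.randn(2, 4, 8)                              # attention output (B, T, D)
mask_pad = torch.tensor([[[True, True, True, False]],
                         [[True, True, False, False]]])   # (B, 1, T), True = real frame

# Zero out padded frames, mirroring the encoder_layer.py change below.
x_att = x_att * mask_pad.transpose(1, 2)                  # (B, T, 1) broadcasts over D
assert torch.all(x_att[0, 3] == 0) and torch.all(x_att[1, 2:] == 0)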
50 changes: 0 additions & 50 deletions test/wenet/ssl/w2vbert/test_w2vbert.py

This file was deleted.

54 changes: 48 additions & 6 deletions test/wenet/transformer/test_attention.py
@@ -5,6 +5,9 @@
import torch
import pytest
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES

from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask

@@ -53,16 +56,55 @@ def test_sdpa(args):
num_decoding_left_chunks=-1)
output, cache = mha_module(q, k, v, mask=att_mask)

- att_mask = (1.0 - att_mask.float()) * torch.finfo(torch.float).min
- output_with_sdpa, cache_with_sdpa = mha_module_with_sdpa(q,
-                                                          k,
-                                                          v,
-                                                          mask=att_mask)
+ att_mask_bias = (1.0 - att_mask.float()) * torch.finfo(torch.float).min
+ output_with_sdpa, cache_with_sdpa = mha_module_with_sdpa(
+     q, k, v, mask=att_mask_bias)

assert torch.allclose(
output * mask.transpose(1, 2),
output_with_sdpa * mask.transpose(1, 2),
atol=9e-7,
rtol=9e-4,
)
assert torch.allclose(cache, cache_with_sdpa)

torch.manual_seed(777)
mha_layer = TransformerEncoderLayer(
args['n_feat'],
mha_module,
PositionwiseFeedForward(
args['n_feat'],
2048,
0.0,
WENET_ACTIVATION_CLASSES['swish'](),
),
0.0,
normalize_before=True,
)

torch.manual_seed(777)
mha_layer_with_sdpa = TransformerEncoderLayer(
args['n_feat'],
mha_module_with_sdpa,
PositionwiseFeedForward(
args['n_feat'],
2048,
0.0,
WENET_ACTIVATION_CLASSES['swish'](),
),
0.0,
normalize_before=True,
)
mha_layer.eval()
mha_layer_with_sdpa.eval()
output, _, cache, _ = mha_layer(q, att_mask, None, mask)
output_with_sdpa, _, cache_with_sdpa, _ = mha_layer_with_sdpa(
q, att_mask_bias, None, mask)

print(output)
print(output_with_sdpa)
assert torch.allclose(
output,
output_with_sdpa,
atol=9e-7,
)
assert torch.allclose(cache, cache_with_sdpa)
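
For reference, the bias-style mask built above for the SDPA path follows PyTorch's scaled_dot_product_attention convention: a boolean mask (True = attend) and an additive float mask (0 where attention is allowed, a very large negative value where it is masked) produce the same result. A small standalone sketch of that equivalence, using torch.nn.functional directly rather than wenet's MultiHeadedAttention (shapes are illustrative):

import torch
import torch.nn.functional as F

# Illustrative shapes: batch=1, heads=2, T=4, d_head=8.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
bool_mask = torch.tensor([[[[True, True, True, False]]]])  # True = keep this key position

# Additive bias form, as built in the test above: 0 keeps a position, float.min removes it.
bias_mask = (1.0 - bool_mask.float()) * torch.finfo(torch.float).min

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=bias_mask)
assert torch.allclose(out_bool, out_bias, atol=1e-6)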
4 changes: 4 additions & 0 deletions wenet/transformer/encoder_layer.py
@@ -91,6 +91,8 @@ def forward(
if self.normalize_before:
x = self.norm1(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
+ if mask_pad.size(2) > 0:
+     x_att = x_att * mask_pad.transpose(1, 2)
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm1(x)
@@ -204,6 +206,8 @@ def forward(
x = self.norm_mha(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
att_cache)
+ if mask_pad.size(2) > 0:
+     x_att = x_att * mask_pad.transpose(1, 2)
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
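
The mask_pad argument in both hunks is expected to be a (batch, 1, time) boolean mask, and the mask_pad.size(2) > 0 guard skips the zeroing when an empty mask is passed, which is presumably the dummy mask used during cache-based streaming decoding. A hedged sketch of how such a mask can be built from utterance lengths (names here are illustrative; wenet also has its own make_non_pad_mask helper, imported in the test above):

import torch

# Illustrative: per-utterance frame counts in a padded batch.
lengths = torch.tensor([4, 2])
max_len = int(lengths.max())

# (B, 1, T) boolean mask: True for real frames, False for padding.
mask_pad = (torch.arange(max_len)[None, :] < lengths[:, None]).unsqueeze(1)
print(mask_pad)
# tensor([[[ True,  True,  True,  True]],
#         [[ True,  True, False, False]]])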
