From 15c0b62807e51bd00d38e772a222e7b773d9e7d0 Mon Sep 17 00:00:00 2001
From: Zhaoheng Ni
Date: Wed, 12 Apr 2023 04:55:37 -0700
Subject: [PATCH] Merge key_padding_mask into attn_mask_rel_pos in WavLM
 (#3265)

Summary:
When `key_padding_mask` is not `None`, it needs to be combined with
`attn_mask_rel_pos` into a single mask for the `scaled_dot_product_attention`
function.

Pull Request resolved: https://github.com/pytorch/audio/pull/3265

Reviewed By: hwangjeff

Differential Revision: D44901093

Pulled By: nateanl

fbshipit-source-id: 73ca7af48faf7f4eb36b35b603187a11e5582c70
---
 torchaudio/models/wav2vec2/wavlm_attention.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/torchaudio/models/wav2vec2/wavlm_attention.py b/torchaudio/models/wav2vec2/wavlm_attention.py
index be81fc9702..fafddfeb95 100644
--- a/torchaudio/models/wav2vec2/wavlm_attention.py
+++ b/torchaudio/models/wav2vec2/wavlm_attention.py
@@ -183,6 +183,17 @@ def forward(
             attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len))
 
+        if attn_mask_rel_pos is not None and key_padding_mask is not None:
+            key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
+            key_padding_mask = torch.nn.functional._canonical_mask(
+                mask=key_padding_mask,
+                mask_name="key_padding_mask",
+                other_type=torch.nn.functional._none_or_dtype(attn_mask_rel_pos),
+                other_name="",
+                target_type=query.dtype,
+            )
+        if attn_mask_rel_pos is not None and key_padding_mask is not None:
+            attn_mask_rel_pos = attn_mask_rel_pos + key_padding_mask
         query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias)
         query, key, value = query_projected.chunk(3, -1)
         shape = (bsz, seq_len, self.num_heads, self.head_dim)
 
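Note: below is a minimal, standalone sketch of the mask-merging idea this patch implements. It is not torchaudio code; it uses the public `masked_fill_` API in place of the private `torch.nn.functional._canonical_mask` helper, and all tensor names and shapes are illustrative.

```python
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 4, 5, 8

query = torch.randn(bsz, num_heads, seq_len, head_dim)
key = torch.randn(bsz, num_heads, seq_len, head_dim)
value = torch.randn(bsz, num_heads, seq_len, head_dim)

# Additive float mask, e.g. derived from a relative position bias:
# shape (bsz, num_heads, seq_len, seq_len).
attn_mask_rel_pos = torch.randn(bsz, num_heads, seq_len, seq_len)

# Boolean padding mask: True marks padded key positions to be ignored.
key_padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
key_padding_mask[:, -1] = True  # pretend the last frame is padding

# Convert the boolean mask to an additive float mask of the query's dtype:
# 0.0 for valid keys, -inf for padded keys, broadcastable over heads/queries.
float_padding_mask = torch.zeros(bsz, num_heads, 1, seq_len, dtype=query.dtype)
float_padding_mask.masked_fill_(key_padding_mask.view(bsz, 1, 1, seq_len), float("-inf"))

# scaled_dot_product_attention accepts only a single attn_mask, so the two
# masks are summed; -inf entries drive the attention weights for padded
# keys to zero after the softmax.
combined_mask = attn_mask_rel_pos + float_padding_mask
out = F.scaled_dot_product_attention(query, key, value, attn_mask=combined_mask)
print(out.shape)  # torch.Size([2, 4, 5, 8])
```

In the patch itself, the boolean-to-float conversion is done with PyTorch's private `_canonical_mask` helper, which produces a float mask of the same dtype as `attn_mask_rel_pos`; that is what allows the two masks to be merged with a simple addition.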