Merge key_padding_mask into attn_mask_rel_pos in WavLM (pytorch#3265)
Summary:
When `key_padding_mask` is not `None`, it needs to be combined with `attn_mask_rel_pos` into a single mask for the `scaled_dot_product_attention` function.
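
A minimal standalone sketch of this mask-merging pattern (illustrative only, not part of the commit; all tensor names and sizes are made up): a boolean `key_padding_mask` is broadcast over the attention heads, converted to an additive float mask (0.0 for valid positions, -inf for padded positions), and summed into the relative-position attention mask before calling `scaled_dot_product_attention`. The diff below performs the boolean-to-float conversion with the internal helper `torch.nn.functional._canonical_mask`; this sketch does it manually with public ops.

    import torch
    import torch.nn.functional as F

    bsz, num_heads, seq_len, head_dim = 2, 4, 6, 8
    query = torch.randn(bsz, num_heads, seq_len, head_dim)
    key = torch.randn(bsz, num_heads, seq_len, head_dim)
    value = torch.randn(bsz, num_heads, seq_len, head_dim)

    # Float attention mask carrying the per-head relative position bias.
    attn_mask_rel_pos = torch.randn(bsz, num_heads, seq_len, seq_len)

    # Boolean padding mask: True marks padded key positions that must be ignored.
    key_padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
    key_padding_mask[:, -2:] = True

    # Convert the boolean mask to an additive float mask broadcastable over heads
    # and query positions: 0.0 where attention is allowed, -inf where it is not.
    additive_padding = torch.zeros(bsz, 1, 1, seq_len, dtype=query.dtype)
    additive_padding.masked_fill_(key_padding_mask.view(bsz, 1, 1, seq_len), float("-inf"))

    # Merge both masks into the single attn_mask accepted by scaled_dot_product_attention.
    combined_mask = attn_mask_rel_pos + additive_padding
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=combined_mask)
    print(out.shape)  # torch.Size([2, 4, 6, 8])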

Pull Request resolved: pytorch#3265

Reviewed By: hwangjeff

Differential Revision: D44901093

Pulled By: nateanl

fbshipit-source-id: 73ca7af48faf7f4eb36b35b603187a11e5582c70
nateanl committed Apr 12, 2023
1 parent 5c6b835 commit 15c0b62
Showing 1 changed file with 11 additions and 0 deletions: torchaudio/models/wav2vec2/wavlm_attention.py
@@ -183,6 +183,17 @@ def forward(
 
             attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len))
 
+        if attn_mask_rel_pos is not None and key_padding_mask is not None:
+            key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
+            key_padding_mask = torch.nn.functional._canonical_mask(
+                mask=key_padding_mask,
+                mask_name="key_padding_mask",
+                other_type=torch.nn.functional._none_or_dtype(attn_mask_rel_pos),
+                other_name="",
+                target_type=query.dtype,
+            )
+        if attn_mask_rel_pos is not None and key_padding_mask is not None:
+            attn_mask_rel_pos = attn_mask_rel_pos + key_padding_mask
         query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias)
         query, key, value = query_projected.chunk(3, -1)
         shape = (bsz, seq_len, self.num_heads, self.head_dim)