From 5c6b83536e4a2f4593c17dd9d770d956a4a46f16 Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Mon, 10 Apr 2023 16:49:43 -0700 Subject: [PATCH] Use scaled_dot_product_attention in WavLM attention (#3252) Summary: Fix https://github.com/pytorch/audio/issues/3219. `torch.nn.MultiheadAttention` will throw an error if `torch.no_grad()` and mask are both given. The pull request fixes it by replacing the forward method with `torch.nn.functional.scaled_dot_product_attention`. Pull Request resolved: https://github.com/pytorch/audio/pull/3252 Reviewed By: mthrok Differential Revision: D44798634 Pulled By: nateanl fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665 --- torchaudio/models/wav2vec2/wavlm_attention.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/torchaudio/models/wav2vec2/wavlm_attention.py b/torchaudio/models/wav2vec2/wavlm_attention.py index 4fc723f78a..be81fc9702 100644 --- a/torchaudio/models/wav2vec2/wavlm_attention.py +++ b/torchaudio/models/wav2vec2/wavlm_attention.py @@ -73,6 +73,7 @@ def __init__( self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.dropout = dropout self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True) self.gru_rel_pos = gru_rel_pos @@ -165,7 +166,7 @@ def forward( if self.rel_attn_embed is not None and position_bias is None: position_bias = self.compute_bias(seq_len, seq_len) - position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, seq_len, seq_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1) attn_mask_rel_pos: Optional[Tensor] = None if position_bias is not None: @@ -178,11 +179,25 @@ def forward( self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False) ).chunk(2, dim=-1) gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0 - attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias - - attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len)) - - attn_output, _ = self.attention( - query, query, query, key_padding_mask=key_padding_mask, attn_mask=attn_mask_rel_pos, need_weights=False + attn_mask_rel_pos = gate_a_1.view(bsz, self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len)) + + query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias) + query, key, value = query_projected.chunk(3, -1) + shape = (bsz, seq_len, self.num_heads, self.head_dim) + query = query.view(shape).transpose(2, 1) # (batch, num_heads, seq_len, head_dim) + key = key.view(shape).transpose(2, 1) # (batch, num_heads, seq_len, head_dim) + value = value.view(shape).transpose(2, 1) # (batch, num_heads, seq_len, head_dim) + dropout = self.dropout if self.training else 0.0 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask_rel_pos, + dropout_p=dropout, + is_causal=False, ) + attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.num_heads * self.head_dim) + attn_output = self.attention.out_proj(attn_output) return attn_output, position_bias