diff --git a/torchaudio/models/wav2vec2/components.py b/torchaudio/models/wav2vec2/components.py
index d42cf1bd3c..f953adfeaf 100644
--- a/torchaudio/models/wav2vec2/components.py
+++ b/torchaudio/models/wav2vec2/components.py
@@ -262,7 +262,7 @@ def __init__(
 
         self.embed_dim = embed_dim
         self.num_heads = num_heads
-        self.dropout = torch.nn.Dropout(dropout)
+        self.dropout = dropout
         self.head_dim = head_dim
 
         self.scaling = self.head_dim**-0.5
@@ -304,25 +304,14 @@ def forward(
 
         shape = (batch_size, length, self.num_heads, self.head_dim)
         q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
+        k = self.k_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
         v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-
-        # scale down q to avoid value overflow.
-        weights = (self.scaling * q) @ k  # B, nH, L, L
-        if attention_mask is not None:
-            weights += attention_mask
-        # subtracting a constant value from the tensor won't change the output of softmax.
-        # apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
-        # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
-        weights = weights - weights.max(dim=-1, keepdim=True)[0]
-
-        weights = torch.nn.functional.softmax(weights, dim=-1)
-        weights = self.dropout(weights)
-
-        output = weights @ v  # B, nH, L, Hd
-        output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
-
-        output = self.out_proj(output)
+        dropout = self.dropout if self.training else 0.0
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
+        )
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        output = self.out_proj(attn_output)
         return output, None  # Necessary for compatibility with WavLMSelAttention
 
 
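Sanity-check sketch (not part of the patch): a minimal standalone comparison of the removed manual attention math against torch.nn.functional.scaled_dot_product_attention in eval mode (dropout_p=0.0). The concrete B/nH/L/Hd sizes and the all-zero additive attention_mask of shape (B, 1, L, L) are illustrative assumptions; only the computation mirrors the diff.

# Parity check: removed manual attention path vs. the new SDPA call (no dropout).
import torch

torch.manual_seed(0)
B, nH, L, Hd = 2, 4, 8, 16                      # illustrative sizes
q = torch.randn(B, nH, L, Hd)
k = torch.randn(B, nH, L, Hd)
v = torch.randn(B, nH, L, Hd)
attention_mask = torch.zeros(B, 1, L, L)        # additive float mask, broadcast over heads

# Old path: explicit scaling, additive mask, max-subtraction (a no-op for softmax),
# softmax, then the weighted sum over values.
scaling = Hd**-0.5
weights = (scaling * q) @ k.transpose(-2, -1) + attention_mask   # B, nH, L, L
weights = weights - weights.max(dim=-1, keepdim=True)[0]
weights = torch.nn.functional.softmax(weights, dim=-1)
old = weights @ v                               # B, nH, L, Hd

# New path: fused kernel; the 1/sqrt(Hd) scaling is applied internally.
new = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

print(torch.allclose(old, new, atol=1e-5))      # expected: True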