Fix DPT /Dinov2 sdpa regression on main (huggingface#33660)
* fall back to eager attention when `output_attentions=True`.

* fix copies
molbap authored and amyeroberts committed Oct 2, 2024
1 parent 2e90d56 commit fef3be0
Showing 1 changed file with 10 additions and 1 deletion.
src/transformers/models/dinov2/modeling_dinov2.py: 10 additions & 1 deletion
@@ -231,7 +231,6 @@ def forward(
         return outputs


-# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Dinov2
 class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
     def __init__(self, config: Dinov2Config) -> None:
         super().__init__(config)
@@ -240,6 +239,16 @@ def __init__(self, config: Dinov2Config) -> None:
     def forward(
         self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "Dinov2Model is using Dinov2SdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions
+            )
+
         mixed_query_layer = self.query(hidden_states)

         key_layer = self.transpose_for_scores(self.key(hidden_states))
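For context, a minimal sketch (not part of this commit) of how the fixed code path is reached. It assumes the facebook/dinov2-base checkpoint and a dummy input tensor purely for illustration; the exact setup may differ from the original regression report.

import torch
from transformers import AutoModel

# Load Dinov2 with the SDPA attention implementation.
# "facebook/dinov2-base" is an assumed checkpoint chosen for illustration.
model = AutoModel.from_pretrained("facebook/dinov2-base", attn_implementation="sdpa")
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch

with torch.no_grad():
    # Requesting attention weights: SDPA cannot return them, so with this fix the
    # layer logs a one-time warning and falls back to the eager (manual) attention path
    # instead of breaking.
    outputs = model(pixel_values=pixel_values, output_attentions=True)

print(len(outputs.attentions), outputs.attentions[0].shape)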
