style
vasqu committed Jun 1, 2024
1 parent de45320 commit edfc6ed
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -19,17 +19,18 @@
 import os
 import warnings
 from dataclasses import dataclass
-from packaging import version
 from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -38,7 +39,6 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask_for_sdpa
 from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
@@ -580,15 +580,15 @@ def __init__(self, *args, **kwargs):
     # Adapted from GPT2Attention.forward and following other sdpa implementations
     # such as transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
     def forward(
-            self,
-            hidden_states: Optional[Tuple[torch.FloatTensor]],
-            layer_past: Optional[Tuple[torch.Tensor]] = None,
-            attention_mask: Optional[torch.FloatTensor] = None,
-            head_mask: Optional[torch.FloatTensor] = None,
-            encoder_hidden_states: Optional[torch.Tensor] = None,
-            encoder_attention_mask: Optional[torch.FloatTensor] = None,
-            use_cache: Optional[bool] = False,
-            output_attentions: Optional[bool] = False,
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
         if output_attentions or head_mask is not None:
             logger.warning_once(
@@ -641,7 +641,7 @@ def forward(
             present = (key, value)
 
         if attention_mask is not None and not is_cross_attention:
-            attention_mask = attention_mask[:, :, :, : tgt_len]
+            attention_mask = attention_mask[:, :, :, :tgt_len]
 
         # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
         if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
@@ -658,7 +658,7 @@ def forward(
             value,
             attn_mask=attention_mask,
             dropout_p=self.attn_dropout.p if self.training else 0.0,
-            is_causal=is_causal
+            is_causal=is_causal,
         )
 
         # Reshape outputs
Expand Down Expand Up @@ -689,11 +689,7 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states


GPT2_ATTENTION_CLASSES = {
"eager": GPT2Attention,
"flash_attention_2": GPT2FlashAttention2,
"sdpa": GPT2SdpaAttention
}
GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}


class GPT2Block(nn.Module):
@@ -1190,9 +1186,7 @@ def forward(
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
             if _use_sdpa:
                 encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    mask=encoder_attention_mask,
-                    dtype=inputs_embeds.dtype,
-                    tgt_len=seq_len
+                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=seq_len
                 )
             else:
                 encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
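
The style-only changes above all sit inside the SDPA attention path this file adds for GPT-2. As a rough, self-contained sketch of that call pattern (not the library's code: the helper name sdpa_attention and its arguments are made up here, and only packaging.version and torch.nn.functional.scaled_dot_product_attention are real APIs), the snippet below shows the three ideas the touched lines revolve around: slicing the 4D mask to the key length, forcing contiguous q/k/v to avoid the torch==2.1.2 memory-efficient-backend bug, and enabling is_causal only when no explicit mask is passed.

import torch
from packaging import version

# Version gate mirroring the "torch==2.1.2 specific bug" comment in the diff:
# older torch needs contiguous q/k/v for the memory-efficient SDPA backend.
_REQUIRE_CONTIGUOUS_QKV = version.parse(torch.__version__) < version.parse("2.2.0")


def sdpa_attention(query, key, value, attention_mask=None, dropout_p=0.0, training=False):
    # query/key/value: (batch, num_heads, seq_len, head_dim); attention_mask: additive 4D mask or None.
    if attention_mask is not None:
        # Keep only the key-length slice of the mask, as in the diff's attention_mask[:, :, :, :tgt_len].
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]

    # Work around the torch==2.1.2 memory-efficient backend bug on CUDA with non-contiguous inputs.
    if _REQUIRE_CONTIGUOUS_QKV and query.device.type == "cuda" and attention_mask is not None:
        query, key, value = query.contiguous(), key.contiguous(), value.contiguous()

    # SDPA can only infer causality itself when no explicit mask is given and the query spans >1 token.
    is_causal = attention_mask is None and query.shape[-2] > 1

    return torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout_p if training else 0.0,
        is_causal=is_causal,
    )


# Example: causal self-attention over 8 tokens with 12 heads of size 64, no explicit mask.
q = k = v = torch.randn(1, 12, 8, 64)
out = sdpa_attention(q, k, v)  # -> shape (1, 12, 8, 64)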
