use self.config to check attn
nbroad1881 committed Aug 20, 2024
1 parent acf2065 commit 6692e28
Showing 1 changed file with 8 additions and 10 deletions.
src/transformers/models/mbart/modeling_mbart.py: 18 changes (8 additions, 10 deletions)
@@ -955,8 +955,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
             embed_dim,
         )
         self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-        self._use_sdpa = config._attn_implementation == "sdpa"
+        self.config = config
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
         self.layer_norm = nn.LayerNorm(config.d_model)
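
Note: the point of the change above is that the attention backend is read from the live config at call time instead of being frozen into booleans in __init__, so a later change to config._attn_implementation is also picked up by the module. A minimal sketch of the pattern (not from the commit), using a hypothetical TinyEncoder rather than the real MBartEncoder:

from types import SimpleNamespace

import torch.nn as nn


class TinyEncoder(nn.Module):
    """Toy module illustrating the config-based check; not the real MBartEncoder."""

    def __init__(self, config):
        super().__init__()
        # Keep a handle to the live config instead of caching
        # `config._attn_implementation == "flash_attention_2"` as a boolean.
        self.config = config

    def uses_flash_attention(self) -> bool:
        # Read the implementation at call time, so changing the backend on the
        # config after the module was built is reflected here.
        return self.config._attn_implementation == "flash_attention_2"


config = SimpleNamespace(_attn_implementation="sdpa")
enc = TinyEncoder(config)
print(enc.uses_flash_attention())  # False
config._attn_implementation = "flash_attention_2"
print(enc.uses_flash_attention())  # True; a boolean cached in __init__ would still be False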

@@ -1044,9 +1043,9 @@ def forward(
 
         # expand attention_mask
         if attention_mask is not None:
-            if self._use_flash_attention_2:
+            if self.config._attn_implementation == "flash_attention_2":
                 attention_mask = attention_mask if 0 in attention_mask else None
-            elif self._use_sdpa and head_mask is None and not output_attentions:
+            elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions:
                 # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
                 # the manual implementation that requires a 4D causal mask in all cases.
                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1140,8 +1139,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
             config.d_model,
         )
         self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-        self._use_sdpa = config._attn_implementation == "sdpa"
+        self.config = config
 
         self.layernorm_embedding = nn.LayerNorm(config.d_model)
         self.layer_norm = nn.LayerNorm(config.d_model)
@@ -1262,10 +1260,10 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self._use_flash_attention_2:
+        if self.config._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
+        elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None:
             # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -1282,9 +1280,9 @@ def forward(
 
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            if self._use_flash_attention_2:
+            if self.config._attn_implementation == "flash_attention_2":
                 encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
-            elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
+            elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
                 # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
                 # the manual implementation that requires a 4D causal mask in all cases.
                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
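Note: in both forward passes above, the mask handling dispatches on self.config._attn_implementation: flash_attention_2 keeps the 2D padding mask (and drops it entirely when no position is masked), SDPA expands it to a 4D mask through a helper unless output_attentions or a head mask forces the manual path, and the remaining (eager) path expands it to the 4D mask used by the manual attention implementation. A standalone sketch of that dispatch (not from the commit), with a hypothetical expand_mask helper standing in for the library's private mask utilities:

import torch


def expand_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical stand-in for the library's mask utilities:
    # [bsz, src_len] -> [bsz, 1, 1, src_len], 0 where attended, large negative where masked.
    expanded = mask[:, None, None, :].to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min


def prepare_attention_mask(attn_implementation, mask, dtype, output_attentions=False, head_mask=None):
    if attn_implementation == "flash_attention_2":
        # The 2D mask is passed through as-is; drop it when nothing is padded.
        return mask if (mask is not None and 0 in mask) else None
    if attn_implementation == "sdpa" and head_mask is None and not output_attentions:
        # SDPA path: expand to 4D (the real helper can also decide the mask is unneeded).
        return None if mask is None else expand_mask(mask, dtype)
    # Eager / fallback path: always expand to the 4D mask.
    return None if mask is None else expand_mask(mask, dtype)


mask = torch.tensor([[1, 1, 1, 0]])
print(prepare_attention_mask("flash_attention_2", mask, torch.float32).shape)  # torch.Size([1, 4])
print(prepare_attention_mask("sdpa", mask, torch.float32).shape)               # torch.Size([1, 1, 1, 4])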
