From 6692e28374999357c15fd0b60293bd556df56247 Mon Sep 17 00:00:00 2001
From: Nicholas Broad
Date: Tue, 20 Aug 2024 15:14:12 -0700
Subject: [PATCH] use self.config to check attn

---
 .../models/mbart/modeling_mbart.py            | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index b5ac80c48f6eb3..9455f21b2073ff 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -955,8 +955,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
             embed_dim,
         )
         self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-        self._use_sdpa = config._attn_implementation == "sdpa"
+        self.config = config
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
         self.layer_norm = nn.LayerNorm(config.d_model)
 
@@ -1044,9 +1043,9 @@ def forward(
 
         # expand attention_mask
         if attention_mask is not None:
-            if self._use_flash_attention_2:
+            if self.config._attn_implementation == "flash_attention_2":
                 attention_mask = attention_mask if 0 in attention_mask else None
-            elif self._use_sdpa and head_mask is None and not output_attentions:
+            elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions:
                 # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
                 # the manual implementation that requires a 4D causal mask in all cases.
                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1140,8 +1139,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
             config.d_model,
         )
         self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-        self._use_sdpa = config._attn_implementation == "sdpa"
+        self.config = config
         self.layernorm_embedding = nn.LayerNorm(config.d_model)
         self.layer_norm = nn.LayerNorm(config.d_model)
 
@@ -1262,10 +1260,10 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self._use_flash_attention_2:
+        if self.config._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
+        elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None:
             # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -1282,9 +1280,9 @@ def forward(
 
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            if self._use_flash_attention_2:
+            if self.config._attn_implementation == "flash_attention_2":
                 encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
-            elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
+            elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
                 # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
                 # the manual implementation that requires a 4D causal mask in all cases.
                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
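
For reference, a minimal standalone sketch (not part of the patch, and using hypothetical ToyConfig/ToyEncoder classes rather than the real MBart ones) of the pattern this change adopts: keep a reference to the config on the module and read config._attn_implementation at forward time, instead of caching _use_flash_attention_2 / _use_sdpa booleans in __init__ that go stale if the config is changed after construction.

    # Sketch only: illustrates checking config._attn_implementation at call time.
    import torch
    import torch.nn as nn


    class ToyConfig:
        def __init__(self, d_model=16, _attn_implementation="eager"):
            self.d_model = d_model
            self._attn_implementation = _attn_implementation


    class ToyEncoder(nn.Module):
        def __init__(self, config: ToyConfig):
            super().__init__()
            # Store the config itself; no cached attention-implementation flags.
            self.config = config
            self.proj = nn.Linear(config.d_model, config.d_model)

        def forward(self, hidden_states, attention_mask=None):
            # Branch on the config at call time, so an implementation switched
            # after __init__ is still honored on the next forward pass.
            if self.config._attn_implementation == "flash_attention_2":
                # FA2-style path keeps a 2D mask, or None when nothing is masked.
                attention_mask = (
                    attention_mask
                    if (attention_mask is not None and 0 in attention_mask)
                    else None
                )
            elif self.config._attn_implementation == "sdpa":
                # An SDPA-style path would expand the mask to 4D here.
                pass
            return self.proj(hidden_states)


    config = ToyConfig()
    encoder = ToyEncoder(config)
    config._attn_implementation = "flash_attention_2"  # picked up on the next call
    out = encoder(torch.randn(2, 4, config.d_model), attention_mask=torch.ones(2, 4))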