[Whisper, Bart, MBart] Add Flash Attention 2 #27203

Merged
merged 18 commits on Nov 1, 2023
284 changes: 272 additions & 12 deletions src/transformers/models/bart/modeling_bart.py

Large diffs are not rendered by default.
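For orientation before the per-file diffs: with this change in place, Flash Attention 2 is requested at model load time. A minimal usage sketch, assuming the `use_flash_attention_2` flag that `from_pretrained` exposes in this release line (the model id and dtype are illustrative; FA2 needs the flash-attn package, a supported GPU, and half-precision weights):

import torch
from transformers import AutoModelForSpeechSeq2Seq

# Illustrative: load Whisper with Flash Attention 2 enabled. Under the hood this is
# expected to set config._flash_attn_2_enabled, which the per-layer dispatch below reads.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v2",
    torch_dtype=torch.float16,    # FA2 kernels run in fp16 or bf16
    use_flash_attention_2=True,
).to("cuda")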

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1174,7 +1174,7 @@ def forward(
return outputs


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BigBirdPegasusDecoder
# Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder
class BigBirdPegasusDecoderAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -1185,12 +1185,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BigBirdPegasusConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -1199,6 +1202,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
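Across all the files below, the same two constructor arguments are added to each copied attention class: `is_causal` and `config`. A hedged sketch of how a decoder layer would now build its self-attention with these arguments (the values are illustrative; the actual call sites sit in parts of the diff not rendered here):

# Illustrative only: decoder self-attention is marked causal, cross-attention is not.
self_attn = BigBirdPegasusDecoderAttention(
    embed_dim=config.d_model,
    num_heads=config.decoder_attention_heads,
    dropout=config.attention_dropout,
    is_decoder=True,
    is_causal=True,   # new in this PR; stored for implementations (such as FA2) that apply the causal mask in-kernel
    config=config,    # new in this PR; gives the attention module access to the model config
)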
4 changes: 4 additions & 0 deletions src/transformers/models/biogpt/modeling_biogpt.py
@@ -90,12 +90,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BioGptConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -104,6 +107,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
24 changes: 19 additions & 5 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -104,12 +104,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BlenderbotConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -118,6 +121,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -248,15 +252,21 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot
BLENDERBOT_ATTENTION_CLASSES = {"default": BlenderbotAttention}


# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BlenderbotAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -317,28 +327,32 @@ def forward(
return outputs


# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BlenderbotAttention(
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout

self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BlenderbotAttention(
self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
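Blenderbot is representative of how the non-FA2 models pick up the new plumbing: attention classes are looked up in a module-level mapping keyed by implementation name. A minimal sketch of that dispatch using the names from the diff (the private `_flash_attn_2_enabled` flag is assumed to be set by the loading code only for architectures that actually register a Flash Attention 2 class):

# Sketch of the per-layer dispatch added above. With only "default" registered,
# the "flash_attention_2" key would raise a KeyError, so the flag is presumably
# rejected earlier in the loading path for this architecture.
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
attn_cls = BLENDERBOT_ATTENTION_CLASSES[attn_type]   # {"default": BlenderbotAttention}
self_attn = attn_cls(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
    config=config,
)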
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -101,12 +101,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BlenderbotSmallConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -115,6 +118,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -245,15 +249,18 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
class BlenderbotSmallEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BlenderbotSmallAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -314,28 +321,35 @@ def forward(
return outputs


# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall
BLENDERBOT_SMALL_ATTENTION_CLASSES = {"default": BlenderbotSmallAttention}


# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
class BlenderbotSmallDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model

self.self_attn = BlenderbotSmallAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout

self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BlenderbotSmallAttention(
self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
4 changes: 4 additions & 0 deletions src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -330,12 +330,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[Data2VecAudioConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -344,6 +347,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
@@ -370,12 +370,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[GPTSanJapaneseConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -384,6 +387,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
4 changes: 4 additions & 0 deletions src/transformers/models/hubert/modeling_hubert.py
@@ -396,12 +396,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[HubertConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -410,6 +413,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
4 changes: 4 additions & 0 deletions src/transformers/models/informer/modeling_informer.py
@@ -287,12 +287,15 @@ def __init__(
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[InformerConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -301,6 +304,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)