From 4535b0bb6dd6e00a0d6ad878c3f50713f71082d0 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 28 Jun 2024 16:41:16 +0200
Subject: [PATCH 1/5] softcapping

---
 src/transformers/models/gemma2/configuration_gemma2.py | 2 ++
 src/transformers/models/gemma2/modeling_gemma2.py      | 5 +++++
 2 files changed, 7 insertions(+)

diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py
index 47207d7ca124..760707904249 100644
--- a/src/transformers/models/gemma2/configuration_gemma2.py
+++ b/src/transformers/models/gemma2/configuration_gemma2.py
@@ -116,6 +116,7 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
         final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.00,
         query_pre_attn_scalar=224,
         sliding_window=4096,
         **kwargs,
@@ -135,6 +136,7 @@ def __init__(
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.attn_logit_softcapping = attn_logit_softcapping
 
         super().__init__(
             pad_token_id=pad_token_id,
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 09ce72c8b1b2..e6f4c5e9da96 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -260,6 +260,11 @@ def forward(
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
 
+        if self.config.attn_logit_softcapping is not None:
+            attn_weights = attn_weights / self.config.attn_logit_softcapping
+            attn_weights = torch.tanh(attn_weights)
+            attn_weights = attn_weights * self.config.attn_logit_softcapping
+
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

From 6ee8e53da8b1206a9a80a0bf52b0ee2db4ec8915 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 28 Jun 2024 16:43:59 +0200
Subject: [PATCH 2/5] soft cap before the mask

---
 src/transformers/models/gemma2/modeling_gemma2.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 94fe54869280..6b2b47b5159e 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -256,15 +256,15 @@ def forward(
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
 
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
         if self.config.attn_logit_softcapping is not None:
             attn_weights = attn_weights / self.config.attn_logit_softcapping
             attn_weights = torch.tanh(attn_weights)
             attn_weights = attn_weights * self.config.attn_logit_softcapping
-
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

From 60935ba823002aab186858b19efe72ef5518aa2d Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 28 Jun 2024 17:02:14 +0200
Subject: [PATCH 3/5] style

---
 src/transformers/models/gemma2/configuration_gemma2.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py
index 760707904249..b11f4c621a82 100644
--- a/src/transformers/models/gemma2/configuration_gemma2.py
+++ b/src/transformers/models/gemma2/configuration_gemma2.py
@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the attention scores.
         query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the size of
             the sliding window.

From 33ce5958a52869b9e5a45ec0b479467eb7ea42f9 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 28 Jun 2024 17:08:08 +0200
Subject: [PATCH 4/5] ...

---
 src/transformers/models/gemma2/configuration_gemma2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py
index b11f4c621a82..63cb04a0b5a0 100644
--- a/src/transformers/models/gemma2/configuration_gemma2.py
+++ b/src/transformers/models/gemma2/configuration_gemma2.py
@@ -78,7 +78,7 @@ class Gemma2Config(PretrainedConfig):
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
-        attn_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the attention scores.
+        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
         query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the size of
             the sliding window.

From 957c98e84087534f9d968edcd2d8f02d782af93f Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 28 Jun 2024 17:09:50 +0200
Subject: [PATCH 5/5] super nit

---
 src/transformers/models/gemma2/configuration_gemma2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py
index 63cb04a0b5a0..7da541207bfe 100644
--- a/src/transformers/models/gemma2/configuration_gemma2.py
+++ b/src/transformers/models/gemma2/configuration_gemma2.py
@@ -117,7 +117,7 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
         final_logit_softcapping=30.0,
-        attn_logit_softcapping=50.00,
+        attn_logit_softcapping=50.0,
         query_pre_attn_scalar=224,
         sliding_window=4096,
         **kwargs,
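
Note, separate from the patch series above: the operation these commits add is a tanh softcap of the raw attention scores, tanh(scores / cap) * cap, applied before the additive causal mask (the point of PATCH 2/5) so that the mask's large negative values are never pushed through the tanh. Below is a minimal, self-contained PyTorch sketch of that transform; the softcap helper name, the toy tensor shapes, and the explicit -inf mask are illustrative assumptions that do not appear in the diff, while the 50.0 cap mirrors the attn_logit_softcapping default introduced in PATCH 1/5.

import torch

def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly bound every score to the open interval (-cap, cap).
    return torch.tanh(scores / cap) * cap

# Toy usage: exaggerated raw scores are capped first, then the causal mask
# (large negative values above the diagonal) is added, mirroring the ordering
# chosen in PATCH 2/5.
scores = 100.0 * torch.randn(1, 1, 4, 4)                       # (batch, heads, q_len, kv_len)
causal_mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
attn_weights = softcap(scores, 50.0) + causal_mask             # masked positions stay at -inf
probs = torch.softmax(attn_weights, dim=-1)                    # each row sums to 1, future positions get 0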