From f28600f9a3a9a6f81c4cff55308ddffc4e40b25c Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 3 Apr 2024 22:22:24 -0400
Subject: [PATCH 1/2] add bias

---
 src/transformers/models/llama/configuration_llama.py | 2 ++
 src/transformers/models/llama/modeling_llama.py      | 6 +++---
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index 6d0f68162cce43..3bbd80552cc232 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -137,6 +137,7 @@ def __init__(
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
+        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -161,6 +162,7 @@ def __init__(
         self._rope_scaling_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.mlp_bias = mlp_bias
 
         super().__init__(
             pad_token_id=pad_token_id,
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 8d0baf63c7b3fe..95a4ac8ccdb326 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -214,9 +214,9 @@ def __init__(self, config):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):

From a41eca11dfd9472253b6379fce2abeaadcefe91b Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Fri, 3 May 2024 14:14:03 +0530
Subject: [PATCH 2/2] fix quality

---
 src/transformers/models/cohere/modeling_cohere.py    | 1 -
 src/transformers/models/llama/configuration_llama.py | 4 +++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index a184fe4450f3ef..ce6fe29b18859c 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -161,7 +161,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index 3bbd80552cc232..66d668be882658 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -94,10 +94,12 @@ class LlamaConfig(PretrainedConfig):
             these scaling strategies behave:
             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
-        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+        attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        mlp_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
 
     ```python
     >>> from transformers import LlamaModel, LlamaConfig
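
The two patches above add an `mlp_bias` switch to `LlamaConfig` (default `False`, so existing checkpoints keep their bias-free MLPs) and thread it through the `gate_proj`, `up_proj`, and `down_proj` projections in `LlamaMLP`. Below is a minimal usage sketch, not part of the patch, assuming a transformers build with these changes applied; the tiny config values are arbitrary and only keep the example cheap to run.

```python
# Usage sketch (assumes this patch is applied); config sizes are arbitrary toy values.
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    mlp_bias=True,  # the new flag; defaults to False
)
model = LlamaModel(config)

# With mlp_bias=True, every MLP projection now carries a bias parameter.
mlp = model.layers[0].mlp
assert mlp.gate_proj.bias is not None
assert mlp.up_proj.bias is not None
assert mlp.down_proj.bias is not None

# The default keeps the original bias-free behaviour, so old configs load unchanged.
assert LlamaConfig().mlp_bias is False
```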