This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

fix issues of legacy gemma models
mikecovlee committed Jul 17, 2024
1 parent e5b5d63 commit edbb95c
Showing 2 changed files with 12 additions and 2 deletions.
5 changes: 4 additions & 1 deletion mlora/common/feed_forward.py
@@ -30,7 +30,10 @@ def forward(
 
     # MixLoRA
     def init_moe_weight(
-        self, args: LLMModelConfig, config: MixConfig, gate: Optional[torch.Tensor] = None
+        self,
+        args: LLMModelConfig,
+        config: MixConfig,
+        gate: Optional[torch.Tensor] = None,
     ):
         self.moes_[config.adapter_name] = moe_layer_factory(args, config)
         if gate is None:
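
The hunk above only rewraps the init_moe_weight signature, but the surrounding context shows the MixLoRA registration pattern: each adapter gets its own MoE block from moe_layer_factory(args, config), stored in self.moes_ under the adapter name and optionally seeded with a pre-trained router gate. The sketch below illustrates that pattern in isolation; the SimpleMoE stand-in, its gate shape, and the register_moe helper are assumptions for illustration, not mLoRA's actual moe_layer_factory API.

from typing import Optional

import torch


class SimpleMoE(torch.nn.Module):
    # Stand-in for the module returned by moe_layer_factory (assumed shape).
    def __init__(self, hidden_size: int, num_experts: int) -> None:
        super().__init__()
        # Router gate maps hidden states to expert logits.
        self.gate_ = torch.nn.Linear(hidden_size, num_experts, bias=False)


def register_moe(
    moes: dict,
    adapter_name: str,
    hidden_size: int,
    num_experts: int,
    gate: Optional[torch.Tensor] = None,
) -> dict:
    # Register one MoE block per adapter, keyed by adapter name.
    moes[adapter_name] = SimpleMoE(hidden_size, num_experts)
    if gate is not None:
        # Seed the router with a pre-trained gate weight when one is provided.
        with torch.no_grad():
            moes[adapter_name].gate_.weight.copy_(gate)
    return moes


moes = register_moe({}, "mixlora_0", hidden_size=2048, num_experts=8)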
9 changes: 8 additions & 1 deletion mlora/models/modeling_gemma.py
@@ -50,6 +50,13 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         return data * normalizer
 
 
+def _patch_hidden_act(config: modeling_gemma.GemmaConfig) -> str:
+    if hasattr(config, "hidden_activation") and config.hidden_activation is not None:
+        return config.hidden_activation
+    else:
+        return config.hidden_act
+
+
 class GemmaForCausalLM(LlamaForCausalLM):
     def __init__(self, config: LlamaConfig) -> None:
         super().__init__(config)
@@ -71,7 +78,7 @@ def from_pretrained(
             n_layers_=llm_config.num_hidden_layers,
             n_heads_=llm_config.num_attention_heads,
             n_kv_heads_=llm_config.num_key_value_heads,
-            hidden_act_=llm_config.hidden_activation,
+            hidden_act_=_patch_hidden_act(llm_config),
             rms_norm_eps_=llm_config.rms_norm_eps,
             max_seq_len_=llm_config.max_position_embeddings,
             rope_theta_=llm_config.rope_theta,
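
The substance of the fix is _patch_hidden_act: newer transformers Gemma configs populate hidden_activation, while legacy Gemma configs only carry hidden_act, so the helper prefers the former and falls back to the latter instead of reading llm_config.hidden_activation unconditionally. A minimal sketch of that fallback, using lightweight stand-in objects instead of real GemmaConfig instances (the activation strings are illustrative):

from types import SimpleNamespace


def _patch_hidden_act(config) -> str:
    # Same logic as the helper above, without the GemmaConfig type hint.
    if hasattr(config, "hidden_activation") and config.hidden_activation is not None:
        return config.hidden_activation
    else:
        return config.hidden_act


# Legacy-style config: only hidden_act is present.
legacy = SimpleNamespace(hidden_act="gelu")
# Newer-style config: hidden_activation is populated as well.
current = SimpleNamespace(hidden_act="gelu", hidden_activation="gelu_pytorch_tanh")

assert _patch_hidden_act(legacy) == "gelu"
assert _patch_hidden_act(current) == "gelu_pytorch_tanh"

Either way, from_pretrained now resolves an activation name for both legacy and current Gemma checkpoints.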
