Skip to content

Commit

Permalink
deepseekv2 liger support
Browse files Browse the repository at this point in the history
  • Loading branch information
tmm1 committed Aug 27, 2024
1 parent 1e43660 commit 2075298
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/axolotl/integrations/liger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
It is designed to be performant, correct, and light-weight.
"""
import logging
import sys

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
Expand Down Expand Up @@ -117,3 +118,26 @@ def pre_model_load(self, cfg):
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward

elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
)
modeling_mod = sys.modules[model.__class__.__module__]

from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward

if cfg.liger_rope:
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward

0 comments on commit 2075298

Please sign in to comment.