diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index bf4c83af4..13456ea1c 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -19,6 +19,7 @@ It is designed to be performant, correct, and light-weight. """ import logging +import sys from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP @@ -117,3 +118,26 @@ def pre_model_load(self, cfg): modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward + + elif cfg.model_config_type == "deepseek_v2": + from accelerate import init_empty_weights + from transformers import AutoModelForCausalLM + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained( + cfg.base_model, trust_remote_code=cfg.trust_remote_code or False + ) + modeling_mod = sys.modules[model.__class__.__module__] + + from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward + + if cfg.liger_rope: + logging.warning("Fused liger_rope is not supported for DeepseekV2.") + if cfg.liger_rms_norm: + modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward + if cfg.liger_cross_entropy: + modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward