Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flash-attn + qlora not working with llama models #336

Merged
merged 4 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
22 changes: 13 additions & 9 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def load_model(

if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)

LOG.info("patching with flash attention")
replace_llama_attn_with_flash_attn()
Expand Down Expand Up @@ -331,6 +333,16 @@ def load_model(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)

# LlamaRMSNorm layers are in fp32 after kit call, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if cfg.flash_attention and cfg.is_llama_derived_model:
for name, module in model.named_modules():
if "norm" in name:
module.to(torch_dtype)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch_dtype)

model, lora_config = load_adapter(model, cfg, adapter)

if cfg.ddp and not load_in_8bit:
Expand Down Expand Up @@ -407,14 +419,6 @@ def load_llama_adapter(model, cfg):
else:
model = get_peft_model(model, peft_config)

if cfg.flash_attention:
for name, module in model.named_modules():
if "norm" in name:
module.to(torch.float16)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch.float16)

model.print_trainable_parameters()

return model, peft_config
Expand Down