From 248bf90f89721a13bf869b7566c398ebc3380d49 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Wed, 2 Aug 2023 20:15:03 +0000
Subject: [PATCH 1/4] ensure flash-attn fixes happen in both adapter/lora
 modes, and use torch_dtype

---
 src/axolotl/utils/models.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2c199a16e..d4bda130c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -331,6 +331,14 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )

+    if cfg.flash_attention:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, adapter)

     if cfg.ddp and not load_in_8bit:
@@ -407,14 +415,6 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)

-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch.float16)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch.float16)
-
     model.print_trainable_parameters()

     return model, peft_config

From 312a9fad079f33ed4d371863b8c36b2b40c7a214 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Thu, 3 Aug 2023 17:20:49 +0000
Subject: [PATCH 2/4] move flash-attn monkey patch alongside the others

---
 .../{flash_attn.py => monkeypatch/llama_attn_hijack_flash.py} | 0
 src/axolotl/utils/models.py                                   | 4 +++-
 2 files changed, 3 insertions(+), 1 deletion(-)
 rename src/axolotl/{flash_attn.py => monkeypatch/llama_attn_hijack_flash.py} (100%)

diff --git a/src/axolotl/flash_attn.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
similarity index 100%
rename from src/axolotl/flash_attn.py
rename to src/axolotl/monkeypatch/llama_attn_hijack_flash.py
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d4bda130c..23d7716a0 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -92,7 +92,9 @@ def load_model(

     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )

             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()

From 78b9efb7f4ae508ccaa2954e6b326d7bb938945c Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Thu, 3 Aug 2023 19:19:39 +0000
Subject: [PATCH 3/4] scope flash-attn+qlora fix correctly, scope to llama,
 add comment

---
 src/axolotl/utils/models.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 23d7716a0..0dcb4cd77 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -333,13 +333,15 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )

-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
+    # LlamaRMSNorm layers are in fp32 after kit call, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
                 module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)

     model, lora_config = load_adapter(model, cfg, adapter)

From 2eda9e02a9d15a7a3f92b41f257d9844d72fc220 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Thu, 3 Aug 2023 21:04:12 +0000
Subject: [PATCH 4/4] fix typo

---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0dcb4cd77..253bdcbd8 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -333,7 +333,7 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )

-    # LlamaRMSNorm layers are in fp32 after kit call, so we need to
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
     if cfg.flash_attention and cfg.is_llama_derived_model:
         for name, module in model.named_modules():