Feat: Add support for upstream FA2 (#626)
* Feat: Add support for upstream FA2

* chore: add is_falcon_derived_model: true to examples

* chore: add config to readme for documentation

* feat: add extra model types

* fix: remove old falcon flash patch

* chore: pin transformers and accelerate
NanoCode012 authored Sep 26, 2023
1 parent 5e5296a commit 19a600a
Showing 8 changed files with 31 additions and 117 deletions.
README.md: 4 additions & 0 deletions
@@ -408,6 +408,10 @@ tokenizer_legacy:
 # this is reported to improve training speed on some models
 resize_token_embeddings_to_32x:

+# used to identify if the model is falcon/llama based
+is_falcon_derived_model:
+is_llama_derived_model:
+
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
 gptq_groupsize: 128 # group size
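For orientation, here is a minimal sketch (not part of this commit; the config values and the print statement are invented for illustration) of how the documented flags feed the condition added to src/axolotl/utils/models.py further down:

# Illustrative sketch only: a parsed axolotl config behaves like a plain dict here.
cfg = {
    "base_model": "tiiuae/falcon-7b",
    "is_falcon_derived_model": True,  # flag documented above
    "is_llama_derived_model": None,
    "flash_attention": True,
    "sample_packing": False,
}

# Upstream FlashAttention-2 is requested only for llama/falcon derived models
# when flash attention is enabled and sample packing is off.
use_upstream_fa2 = (
    cfg["flash_attention"]
    and not cfg["sample_packing"]
    and (cfg["is_llama_derived_model"] or cfg["is_falcon_derived_model"])
)
print(f"use_flash_attention_2 -> {bool(use_upstream_fa2)}")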
examples/falcon/config-7b-lora.yml: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: true
 load_in_4bit: false
 gptq: false
examples/falcon/config-7b-qlora.yml: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: false
 # enable 4bit for QLoRA
 load_in_4bit: true
examples/falcon/config-7b.yml: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: false
 load_in_4bit: false
 gptq: false
requirements.txt: 2 additions & 2 deletions
@@ -4,9 +4,9 @@ torch==2.0.1
 auto-gptq
 packaging
 peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git
+transformers @ git+https://github.com/huggingface/transformers.git@0ac3875011d32dc85e0e83970507e3afe8f0febb
 bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate
+accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
 deepspeed
 addict
 evaluate
src/axolotl/monkeypatch/falcon_attn_hijack_flash.py: 0 additions & 101 deletions

This file was deleted.

src/axolotl/utils/config.py: 16 additions & 0 deletions
@@ -86,6 +86,22 @@ def normalize_config(cfg):
         or (cfg.model_type and "llama" in cfg.model_type.lower())
     )

+    # figure out if the model is falcon
+    cfg.is_falcon_derived_model = (
+        (
+            hasattr(model_config, "model_type")
+            and model_config.model_type
+            in [
+                "falcon",
+                "RefinedWebModel",
+                "RefinedWeb",
+            ]
+        )
+        or cfg.is_falcon_derived_model
+        or "falcon" in cfg.base_model
+        or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
+    )
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)

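Restated outside of normalize_config, the detection above reduces to the following standalone sketch (the function and argument names are invented for illustration and do not exist in the codebase):

from typing import Optional

# Model types that identify falcon-family architectures, as listed above.
FALCON_MODEL_TYPES = {"falcon", "RefinedWebModel", "RefinedWeb"}


def looks_falcon_derived(
    base_model: str,
    hf_config_model_type: Optional[str] = None,
    cfg_model_type: Optional[str] = None,
    explicit_flag: bool = False,
) -> bool:
    """Mirror of the checks combined into cfg.is_falcon_derived_model."""
    return (
        hf_config_model_type in FALCON_MODEL_TYPES
        or explicit_flag
        or "falcon" in base_model
        or bool(cfg_model_type and "rwforcausallm" in cfg_model_type.lower())
    )


# The repo id, the HF config model_type, the legacy RWForCausalLM model_type,
# or the explicit YAML flag can each trigger the falcon code path.
assert looks_falcon_derived("tiiuae/falcon-7b", hf_config_model_type="falcon")
assert looks_falcon_derived("some/custom-model", cfg_model_type="RWForCausalLM")
assert not looks_falcon_derived("meta-llama/Llama-2-7b-hf", hf_config_model_type="llama")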
src/axolotl/utils/models.py: 6 additions & 14 deletions
Expand Up @@ -114,25 +114,13 @@ def load_model(

replace_btlm_attn_with_flash_attn(cfg.base_model)

if hasattr(model_config, "model_type") and model_config.model_type in [
"falcon",
"RefinedWebModel",
"RefinedWeb",
]:
if cfg.flash_attention:
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
replace_falcon_attn_with_flash_attn,
)

replace_falcon_attn_with_flash_attn()

if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)

LOG.info("patching with flash attention")
LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -213,6 +201,10 @@ def load_model(
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
+    # sample packing uses custom FA2 patch
+    if cfg.flash_attention and not cfg.sample_packing:
+        if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
+            model_kwargs["use_flash_attention_2"] = True
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
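With sample packing off, the new model_kwargs branch hands the work to upstream transformers, which is why the pin in requirements.txt matters. A hedged sketch of the end result for a falcon model (model id, dtype, and device_map are example values; flash-attn and a supported GPU are assumed, and this is not code from the commit):

# Rough equivalent of what load_model ends up doing for a falcon/llama derived
# model with flash_attention: true and sample_packing: false.
import torch
from transformers import AutoModelForCausalLM

model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "use_flash_attention_2": True,  # the kwarg set in the new branch above
}

model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b", **model_kwargs)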
