Feat: Add support for upstream FA2 #626

Merged: 6 commits merged on Sep 26, 2023
4 changes: 4 additions & 0 deletions README.md
@@ -408,6 +408,10 @@ tokenizer_legacy:
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:

+# used to identify if the model is falcon/llama based
+is_falcon_derived_model:
+is_llama_derived_model:
+
# whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
1 change: 1 addition & 0 deletions examples/falcon/config-7b-lora.yml
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
load_in_8bit: true
load_in_4bit: false
gptq: false
1 change: 1 addition & 0 deletions examples/falcon/config-7b-qlora.yml
@@ -6,6 +6,7 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
1 change: 1 addition & 0 deletions examples/falcon/config-7b.yml
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
load_in_8bit: false
load_in_4bit: false
gptq: false
4 changes: 2 additions & 2 deletions requirements.txt
@@ -4,9 +4,9 @@ torch==2.0.1
auto-gptq
packaging
peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git
+transformers @ git+https://github.com/huggingface/transformers.git@0ac3875011d32dc85e0e83970507e3afe8f0febb
bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate
+accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
deepspeed
addict
evaluate
101 changes: 0 additions & 101 deletions src/axolotl/monkeypatch/falcon_attn_hijack_flash.py

This file was deleted.

16 changes: 16 additions & 0 deletions src/axolotl/utils/config.py
@@ -86,6 +86,22 @@ def normalize_config(cfg):
        or (cfg.model_type and "llama" in cfg.model_type.lower())
    )

+    # figure out if the model is falcon
+    cfg.is_falcon_derived_model = (
+        (
+            hasattr(model_config, "model_type")
+            and model_config.model_type
+            in [
+                "falcon",
+                "RefinedWebModel",
+                "RefinedWeb",
+            ]
+        )
+        or cfg.is_falcon_derived_model
+        or "falcon" in cfg.base_model
+        or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
+    )
+
    log_gpu_memory_usage(LOG, "baseline", cfg.device)


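As with the existing llama check, the new block treats a model as falcon-derived when the HF config's `model_type` is a known falcon variant, when the user sets `is_falcon_derived_model` explicitly, or when the base model name or `model_type` string suggests it. A minimal standalone sketch of how that check resolves for a stock Falcon checkpoint (illustrative only, not axolotl's exact code path; assumes the `tiiuae/falcon-7b` config is reachable):

```python
# Illustrative sketch: how the falcon-derived check above resolves for falcon-7b.
# AutoConfig is standard transformers usage; `base_model` is just an example input.
from transformers import AutoConfig

base_model = "tiiuae/falcon-7b"
model_config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)

is_falcon_derived_model = (
    getattr(model_config, "model_type", None)
    in ["falcon", "RefinedWebModel", "RefinedWeb"]
    or "falcon" in base_model
)
print(is_falcon_derived_model)  # True, so the FA2 handling in models.py below applies
```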
20 changes: 6 additions & 14 deletions src/axolotl/utils/models.py
@@ -114,25 +114,13 @@ def load_model(

    replace_btlm_attn_with_flash_attn(cfg.base_model)

-    if hasattr(model_config, "model_type") and model_config.model_type in [
-        "falcon",
-        "RefinedWebModel",
-        "RefinedWeb",
-    ]:
-        if cfg.flash_attention:
-            from axolotl.monkeypatch.falcon_attn_hijack_flash import (
-                replace_falcon_attn_with_flash_attn,
-            )
-
-            replace_falcon_attn_with_flash_attn()
-
-    if cfg.is_llama_derived_model and cfg.flash_attention:
+    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
        if cfg.device not in ["mps", "cpu"] and not inference:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                replace_llama_attn_with_flash_attn,
            )

-            LOG.info("patching with flash attention")
+            LOG.info("patching with flash attention for sample packing")
            replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
    elif cfg.is_llama_derived_model and cfg.xformers_attention:
        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -213,6 +201,10 @@ def load_model(
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
+    # sample packing uses custom FA2 patch
+    if cfg.flash_attention and not cfg.sample_packing:
+        if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
+            model_kwargs["use_flash_attention_2"] = True
    try:
        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
            from transformers import LlamaForCausalLM
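Taken together, the two hunks above split flash attention into two paths: when `sample_packing` is enabled, llama models keep axolotl's custom monkeypatch (which understands packed sequences); otherwise, llama- and falcon-derived models use transformers' upstream FlashAttention-2 by passing `use_flash_attention_2=True` through `model_kwargs` into `from_pretrained`. A minimal sketch of what that flag amounts to (assumes the pinned transformers commit from requirements.txt and the `flash-attn` package are installed; the checkpoint name is illustrative):

```python
# Illustrative sketch, not axolotl's exact call: loading a model with
# transformers' built-in FlashAttention-2 enabled.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",          # example checkpoint
    torch_dtype=torch.bfloat16,  # FA2 expects fp16/bf16 weights
    use_flash_attention_2=True,  # what model_kwargs["use_flash_attention_2"] = True requests
)
```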