From 5bdb49448d737feddb2598e0798b8bcad7495927 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sat, 23 Sep 2023 14:19:40 +0900
Subject: [PATCH 1/6] Feat: Add support for upstream FA2

---
 src/axolotl/utils/config.py | 11 +++++++++++
 src/axolotl/utils/models.py |  8 ++++++--
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 1c0487ff8..c54c00840 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -86,6 +86,17 @@ def normalize_config(cfg):
         or (cfg.model_type and "llama" in cfg.model_type.lower())
     )
 
+    # figure out if the model is falcon
+    cfg.is_falcon_derived_model = (
+        (
+            hasattr(model_config, "model_type")
+            and model_config.model_type == "RefinedWebModel"
+        )
+        or cfg.is_falcon_derived_model
+        or "falcon" in cfg.base_model
+        or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
+    )
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 543a0e1a1..5bc3fce58 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -126,13 +126,13 @@ def load_model(
 
             replace_falcon_attn_with_flash_attn()
 
-    if cfg.is_llama_derived_model and cfg.flash_attention:
+    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
         if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
             )
 
-            LOG.info("patching with flash attention")
+            LOG.info("patching with flash attention for sample packing")
             replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -213,6 +213,10 @@ def load_model(
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
+    # sample packing uses custom FA2 patch
+    if cfg.flash_attention and not cfg.sample_packing:
+        if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
+            model_kwargs["use_flash_attention_2"] = True
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
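
For context: the branch added above defers to the transformers-native FA2 implementation whenever sample packing is off, and keeps axolotl's custom monkeypatch only for the packed path. A minimal sketch of what the resulting load amounts to, assuming model_kwargs is ultimately forwarded to from_pretrained() as elsewhere in load_model(); the helper name and the unconditional bfloat16 dtype are illustrative, not axolotl's exact code:

    # Sketch only: mirrors the gating introduced in this patch.
    import torch
    from transformers import AutoModelForCausalLM

    def load_with_upstream_fa2(cfg, base_model: str):
        model_kwargs = {"torch_dtype": torch.bfloat16}  # FA2 expects fp16/bf16 weights
        # sample packing keeps axolotl's custom FA2 monkeypatch; otherwise defer to upstream FA2
        if cfg.flash_attention and not cfg.sample_packing:
            if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
                model_kwargs["use_flash_attention_2"] = True  # transformers FA2 flag of this era
        return AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
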
From 97d8aaf9f6a38aa2af27b197d8f63236d40a7c1b Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 25 Sep 2023 19:50:52 +0900
Subject: [PATCH 2/6] chore: add is_falcon_derived_model: true to examples

---
 examples/falcon/config-7b-lora.yml  | 1 +
 examples/falcon/config-7b-qlora.yml | 1 +
 examples/falcon/config-7b.yml       | 1 +
 3 files changed, 3 insertions(+)

diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml
index a5cbdc00d..738068a47 100644
--- a/examples/falcon/config-7b-lora.yml
+++ b/examples/falcon/config-7b-lora.yml
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: true
 load_in_4bit: false
 gptq: false
diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml
index 72b09b87d..554081fcb 100644
--- a/examples/falcon/config-7b-qlora.yml
+++ b/examples/falcon/config-7b-qlora.yml
@@ -6,6 +6,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: false
 # enable 4bit for QLoRA
 load_in_4bit: true
diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml
index 46f4caff1..25e67a53b 100644
--- a/examples/falcon/config-7b.yml
+++ b/examples/falcon/config-7b.yml
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: false
 load_in_4bit: false
 gptq: false

From 354ee7c66747945036a9cc34188e0928d743f190 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 25 Sep 2023 19:52:56 +0900
Subject: [PATCH 3/6] chore: add config to readme for documentation

---
 README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/README.md b/README.md
index 6ec81eed3..53185a2b4 100644
--- a/README.md
+++ b/README.md
@@ -408,6 +408,10 @@ tokenizer_legacy:
 # this is reported to improve training speed on some models
 resize_token_embeddings_to_32x:
 
+# used to identify if the model is falcon/llama based
+is_falcon_derived_model:
+is_llama_derived_model:
+
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
 gptq_groupsize: 128 # group size

From ea4e804636bfd953618f7529761952faed43db50 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 26 Sep 2023 03:16:56 +0900
Subject: [PATCH 4/6] feat: add extra model types

---
 src/axolotl/utils/config.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index c54c00840..9b861551f 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -90,7 +90,12 @@ def normalize_config(cfg):
     cfg.is_falcon_derived_model = (
         (
             hasattr(model_config, "model_type")
-            and model_config.model_type == "RefinedWebModel"
+            and model_config.model_type
+            in [
+                "falcon",
+                "RefinedWebModel",
+                "RefinedWeb",
+            ]
         )
         or cfg.is_falcon_derived_model
         or "falcon" in cfg.base_model
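
For context: the three model_type values above cover how Falcon checkpoints have been published; older trust_remote_code releases report "RefinedWebModel" (7B) or "RefinedWeb" (40B), while the transformers-native port reports "falcon". A rough standalone sketch of the same detection using AutoConfig; the helper is hypothetical, and axolotl's normalize_config() additionally honors an explicit is_falcon_derived_model override plus the base_model and model_type strings:

    # Sketch: approximates the is_falcon_derived_model check from normalize_config().
    from transformers import AutoConfig

    FALCON_MODEL_TYPES = {"falcon", "RefinedWebModel", "RefinedWeb"}

    def looks_like_falcon(base_model: str) -> bool:
        config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
        model_type = getattr(config, "model_type", "")
        return model_type in FALCON_MODEL_TYPES or "falcon" in base_model

    # e.g. looks_like_falcon("tiiuae/falcon-7b") should return True
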
From e44993eb70f92367be80ee40b7d1502fabfae328 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 26 Sep 2023 03:17:11 +0900
Subject: [PATCH 5/6] fix: remove old falcon flash patch

---
 .../monkeypatch/falcon_attn_hijack_flash.py   | 101 ------------------
 src/axolotl/utils/models.py                   |  12 ---
 2 files changed, 113 deletions(-)
 delete mode 100644 src/axolotl/monkeypatch/falcon_attn_hijack_flash.py

diff --git a/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py b/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
deleted file mode 100644
index ed11c5523..000000000
--- a/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""
-Flash Attention monkey patch for Falcon
-
-copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
-"""
-
-from typing import Optional, Tuple
-
-import torch
-import transformers
-from flash_attn import flash_attn_func
-
-
-def forward(
-    self,
-    hidden_states: torch.Tensor,
-    alibi: Optional[torch.Tensor],
-    attention_mask: torch.Tensor,  # pylint: disable=unused-argument
-    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    head_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
-    use_cache: bool = False,
-    output_attentions: bool = False,  # pylint: disable=unused-argument
-):
-    fused_qkv = self.query_key_value(
-        hidden_states
-    )  # [batch_size, seq_length, 3 x hidden_size]
-    num_kv_heads = (
-        self.num_heads if self.new_decoder_architecture else self.num_kv_heads
-    )
-    # 3 x [batch_size, seq_length, num_heads, head_dim]
-    (
-        query_layer,
-        key_layer,
-        value_layer,
-    ) = self._split_heads(  # pylint: disable=protected-access
-        fused_qkv
-    )
-
-    batch_size, query_length, _, _ = query_layer.shape
-
-    query_layer = query_layer.transpose(1, 2).reshape(
-        batch_size * self.num_heads, query_length, self.head_dim
-    )
-    key_layer = key_layer.transpose(1, 2).reshape(
-        batch_size * num_kv_heads,
-        query_length,
-        self.head_dim,
-    )
-    value_layer = value_layer.transpose(1, 2).reshape(
-        batch_size * num_kv_heads, query_length, self.head_dim
-    )
-
-    past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
-
-    if layer_past is not None:
-        past_key, past_value = layer_past
-        # concatenate along seq_length dimension:
-        # - key: [batch_size * self.num_heads, kv_length, head_dim]
-        # - value: [batch_size * self.num_heads, kv_length, head_dim]
-        key_layer = torch.cat((past_key, key_layer), dim=1)
-        value_layer = torch.cat((past_value, value_layer), dim=1)
-
-    # unused
-    # _, kv_length, _ = key_layer.shape
-    if use_cache:
-        present = (key_layer, value_layer)
-    else:
-        present = None
-    # unused
-    # attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
-    query_layer_ = (
-        query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-        .transpose(1, 2)
-        .to(torch.bfloat16)
-    )
-    key_layer_ = (
-        key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-        .transpose(1, 2)
-        .to(torch.bfloat16)
-    )
-    value_layer_ = (
-        value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-        .transpose(1, 2)
-        .to(torch.bfloat16)
-    )
-
-    if alibi is not None:
-        raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
-
-    # below output will have shape (batch_size, seqlen, nheads, headdim)
-    attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
-    attn_output = attn_output.reshape(
-        batch_size, query_length, self.num_heads * self.head_dim
-    )
-    output_tensor = self.dense(attn_output)
-    return output_tensor, present
-
-
-def replace_falcon_attn_with_flash_attn():
-    transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 5bc3fce58..361440931 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -114,18 +114,6 @@ def load_model(
 
         replace_btlm_attn_with_flash_attn(cfg.base_model)
 
-    if hasattr(model_config, "model_type") and model_config.model_type in [
-        "falcon",
-        "RefinedWebModel",
-        "RefinedWeb",
-    ]:
-        if cfg.flash_attention:
-            from axolotl.monkeypatch.falcon_attn_hijack_flash import (
-                replace_falcon_attn_with_flash_attn,
-            )
-
-            replace_falcon_attn_with_flash_attn()
-
     if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
         if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (

From fd4cea2c824429a155bf32a31e98f515f47a3d27 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 26 Sep 2023 22:23:09 +0900
Subject: [PATCH 6/6] chore: pin transformers and accelerate

---
 requirements.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 5aba20b16..33a2157d9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,9 +4,9 @@ torch==2.0.1
 auto-gptq
 packaging
 peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git
+transformers @ git+https://github.com/huggingface/transformers.git@0ac3875011d32dc85e0e83970507e3afe8f0febb
 bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate
+accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
 deepspeed
 addict
 evaluate
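
For context: with the monkeypatch removed, Falcon flash attention comes entirely from upstream transformers, which is presumably why the dependency is pinned to a specific commit rather than a tagged release. A hypothetical smoke test, not part of these patches, for checking that the pinned install accepts the flag; it assumes a CUDA GPU, the flash-attn package, and half-precision weights, which FA2 requires:

    # Sketch: confirm upstream FA2 loads for Falcon with the pinned transformers commit.
    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "tiiuae/falcon-7b",
        torch_dtype=torch.bfloat16,   # FA2 requires fp16/bf16
        use_flash_attention_2=True,   # raises if this transformers build lacks FA2 for the model
        device_map="auto",            # needs accelerate
    )
    print(type(model).__name__)  # expected: FalconForCausalLM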