Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mistral: Sliding Window Attention with Flash Attention and Sample Packing #732

Merged
merged 10 commits into from
Oct 16, 2023

Conversation

casper-hansen
Copy link
Collaborator

@casper-hansen casper-hansen commented Oct 14, 2023

Main benefits of this PR:

  • Loss is reduced 5x on long-context datasets due to passing window_size to Flash Attention.
  • Memory usage is reduced by 3GB due to the usage of a sliding window mask.
Memory usage

Memory usage with SWA. The conclusion is that you save 3GB when using a sliding window mask.

  • _prepare_decoder_attention_mask with window_size=(4096, 4096) parameter to flash attention.
    • PR, with sliding window mask: 4.851GB (+24.606GB cache, +0.781GB misc)
    • PR, without sliding window mask: 4.851GB (+27.778GB cache, +0.781GB misc)
  • _prepare_decoder_attention_mask with window_size=(-1, -1) parameter to flash attention.
    • PR, with sliding window mask: 4.851GB (+24.606GB cache, +0.781GB misc)
    • Main, without sliding window mask: 4.851GB (+27.778GB cache, +0.781GB misc)
Long context (casperhansen/longalpaca_1k_test)

I test with a long context dataset, minimum 16k tokens and maximum 32k tokens. Minimum 48GB VRAM needed to run this.

Results after a few steps:

  • PR: Loss starts at 1.9, goes down to 1.54 after 4 steps, 1.396 after 19 steps
  • Main: Loss starts at 9.98, goes down to 9.11 after 4 steps
base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: casperhansen/longalpaca_1k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 32768
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_mode: 
wandb_project: 
wandb_entity: 
wandb_watch: 
wandb_run_id: 
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
Short context (mhenrichsen/alpaca_2k_test)

Loss on short-context datasets is tested to be the same.

Used default config in examples/mistral/qlora.yml.

image
image

Other notes:

  • attention_mask and sliding_window_mask are not broadcastable in the first iteration after eval loss. However, this is only the case when wandb is enabled. This error is handled by attention_mask.shape[0] != 1 so that it does not trigger.
  • I tried to use _expand_mask and it did not work with Flash Attention. I tried other methods too, but same problem.

@casper-hansen casper-hansen marked this pull request as ready for review October 15, 2023 12:31
@winglian winglian merged commit a045db0 into axolotl-ai-cloud:main Oct 16, 2023
4 checks passed
@casper-hansen casper-hansen deleted the mistral_fa_swa branch October 19, 2023 17:28
mkeoliya pushed a commit to mkeoliya/axolotl that referenced this pull request Dec 15, 2023
…king (axolotl-ai-cloud#732)

* Implement Mistral FA + SWA + Sample Packing

* Handle unbroadcastable tensor

* chore: lint

* Simplify _prepare_decoder_attention_mask

* Uncomment window size

* Upgrade flash-attn to minimum of 2.3.0 to support SWA

* Add original condition to avoid error during inference

* chore: lint

* use torchscript to prevent oom

* chore: pylint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
@TJ-Solergibert
Copy link
Contributor

Hi! I'm trying to fine-tune Mistral 7b with a chat dataset, packing, train on competitions only and FA. I got spammed with the warning skipping sliding window mask, not broadcastable with attention mask. I don't understand why this is happening in the first iteration after eval loss and only when wandb is activated. Should I worry for the performance of the model?

Thanks,

Toni

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants