Starcoder has higher eval loss with flash attention 2

System Info
transformers version: 4.36.2
flash-attn: 2.5.2 (flash_attn-2.5.2+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64)
Platform: linux_x86_64 cp310 ubuntu-22.04
Python version: 3.10
Huggingface_hub version: 0.20.3
Safetensors version: 0.4.2
Accelerate version: 0.26.1
Accelerate config: not found
PyTorch version (GPU?): 2.1.2 (True) torch-2.1.2-cu118-cp310-cp310-linux_x86_64.whl
Tensorflow version (GPU?): not installed
Flax version (CPU?/GPU?/TPU?): not installed
Jax version: not installed
JaxLib version: not installed
Using GPU in script?: yes (A100)
CUDA_VERSION: 11.8.0
Using distributed or parallel set-up in script?: yes (deepspeed 0.11.2)
Who can help?
There is a similar GitHub issue, but I also have some additional observations around inference.
After GPTBigCode added support for flash attention 2 in transformers 4.36, I ran inference with flash attention 2 enabled on a fine-tuned starcoderbase-3b that had previously been created with 4.35. The output-label exact-match metric dropped significantly, with some slices as low as 0%. Upon inspection, many outputs simply repeat a single token, suggesting a bug in the attention mechanism.
I then tried fine-tuning a new model with transformers 4.36 and flash attention 2 enabled. While exact match is now a bit higher, all metrics still drop significantly compared with the previous model trained without flash attention 2. For instance, eval_loss increased from 0.53 to 0.75.
However, the final training loss is similar in both cases, at around 0.07. Fine-tuning with flash attention 2 is also very unstable: with a different batch_size, training loss ends up at 0.28. Enabling and disabling padding (batch_size=1, pad_to_multiple_of=None) in the trainer makes no meaningful difference in the metrics.

Information
The official example scripts
My own modified scripts

Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)

Reproduction
The model is loaded the same way for training and inference; the only difference is that inference loads a fine-tuned starcoder model.
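For reference, here is a minimal sketch of what that loading call looks like, assuming the standard transformers 4.36 API; the checkpoint path, dtype, and device placement are assumptions, since the exact script is not included in this issue.

```python
# Minimal sketch (not the exact script from this issue): the same call is used for
# training and inference; only the checkpoint path differs. Path, dtype, and device
# are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoderbase-3b"  # for inference, the fine-tuned checkpoint instead

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,               # flash attention 2 requires fp16/bf16
    attn_implementation="flash_attention_2",  # "eager" reproduces the non-FA2 baseline
).to("cuda")                                  # FA2 needs an NVIDIA GPU (A100 here)
```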
Some important training args (a second sketch, mapping these onto the Trainer and generate APIs, follows this list):
learning_rate: 1e-5
gradient_accumulation_steps: 16
bf16: "True"
torch_compile_mode: max-autotune

Inference args:
beam_size: 5
tokenizer_max_length: 512
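A hedged sketch of how these arguments might map onto TrainingArguments and generate(), continuing from the loading sketch above; everything not listed in the issue (output_dir, the deepspeed config path, the prompt, max_new_tokens) is a placeholder rather than the actual setup.

```python
# Hedged sketch only; placeholders marked below are assumptions, not the real setup.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="starcoder-ft",          # placeholder
    learning_rate=1e-5,
    gradient_accumulation_steps=16,
    bf16=True,
    torch_compile=True,
    torch_compile_mode="max-autotune",
    deepspeed="ds_config.json",         # placeholder; the issue uses deepspeed 0.11.2
)

# Inference: beam search over inputs truncated to 512 tokens
# (tokenizer/model come from the loading sketch above; prompt and max_new_tokens are placeholders).
prompt = "def fibonacci(n):"
inputs = tokenizer(prompt, truncation=True, max_length=512, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, num_beams=5, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```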
Expected behavior
For training, loss should not go up compared with use_flash_attention_2=False.

For inference, a fine-tuned model (regardless of how it was trained) should produce the same, or mostly the same, results whether or not flash attention 2 is enabled.
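A sketch of the equivalence check implied here, assuming a hypothetical fine-tuned checkpoint path: load the same model with and without flash attention 2 and compare logits on one batch; any difference should stay within bf16 numerical noise rather than changing beam-search outputs.

```python
# Sketch of the expected-equivalence check; the checkpoint path and prompt are hypothetical.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "path/to/finetuned-starcoderbase-3b"   # hypothetical
tok = AutoTokenizer.from_pretrained(ckpt)
batch = tok("def hello_world():", return_tensors="pt").to("cuda")

logits = {}
for impl in ("eager", "flash_attention_2"):
    model = AutoModelForCausalLM.from_pretrained(
        ckpt, torch_dtype=torch.bfloat16, attn_implementation=impl
    ).to("cuda").eval()
    with torch.no_grad():
        logits[impl] = model(**batch).logits.float()

# Should be small (bf16 noise), not large enough to flip generated tokens.
print((logits["eager"] - logits["flash_attention_2"]).abs().max())
```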
Thanks a lot @amyeroberts. Indeed, the issue is fixed: I'm getting exactly the same metrics in our batch inference with flash attention 2 enabled. Looking forward to the next release.