
Commit 07b5011

Replace flash attention2 with kernels-community/flash-attn2 (#4426)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent 7a57fd4 commit 07b5011

File tree: 7 files changed (+27, -23 lines changed)

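Context (not part of the diff): the commit switches every documented and tested use of `attn_implementation="flash_attention_2"` (which requires the `flash-attn` package) to the hub-served kernel `kernels-community/flash-attn2`, loaded through the `kernels` package. A minimal sketch of the resulting usage, assuming a `transformers` version with kernels support and `kernels` installed (e.g. via the new `trl[kernels]` extra); the model id is only an example:

```python
import torch
from transformers import AutoModelForCausalLM

# Request FlashAttention-2 by hub repo id instead of the flash-attn package.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # example model, not from the commit
    attn_implementation="kernels-community/flash-attn2",
    dtype=torch.bfloat16,  # FlashAttention-2 supports only bf16 and fp16
    device_map="auto",
)
```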

docs/source/dpo_trainer.md

Lines changed: 1 addition & 1 deletion
@@ -253,7 +253,7 @@ model = AutoModelForCausalLM.from_pretrained(
     "mistralai/mixtral-8x7b-v0.1",
     load_in_4bit=True,
     quantization_config=bnb_config,
-    attn_implementation="flash_attention_2",
+    attn_implementation="kernels-community/flash-attn2",
     dtype=torch.bfloat16,
     device_map="auto",
 )

docs/source/gkd_trainer.md

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a
 The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.
 
 > [!WARNING]
-> Make sure that `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture.
+> Make sure that `attn_implementation="kernels-community/flash-attn2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture.
 
 The basic API is as follows:

docs/source/reducing_memory_usage.md

Lines changed: 2 additions & 2 deletions
@@ -188,7 +188,7 @@ Padding-free batching is an alternative approach for reducing memory usage. In t
 ```python
 from trl import DPOConfig
 
-training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
+training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"})
 ```
 
 </hfoption>
@@ -197,7 +197,7 @@ training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_imple
 ```python
 from trl import SFTConfig
 
-training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
+training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"})
 ```
 
 </hfoption>
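Not part of the diff: a small end-to-end sketch of the padding-free SFT setup the updated docs describe, assuming `kernels` is installed and a bf16-capable GPU; the dataset and model ids are placeholders:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")  # placeholder dataset

# Padding-free batching needs a flash-attention-style kernel; after this commit
# the docs point at the kernels-community implementation.
training_args = SFTConfig(
    output_dir="sft-padding-free",
    padding_free=True,
    model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"},
    bf16=True,  # flash attention 2 only supports bf16 and fp16
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",  # placeholder model id, loaded with model_init_kwargs
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```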

pyproject.toml

Lines changed: 5 additions & 0 deletions
@@ -52,6 +52,9 @@ judges = [
     "openai>=1.23.2",
     "llm-blender>=0.0.2"
 ]
+kernels = [
+    "kernels"
+]
 liger = [
     "liger-kernel>=0.6.2"
 ]
@@ -98,6 +101,8 @@ dev = [
     # judges
     "openai>=1.23.2",
     "llm-blender>=0.0.2",
+    # kernels
+    "kernels",
     # liger
     "liger-kernel>=0.6.2",
     # peft
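Not part of the diff: with this extra, the dependency can be pulled in via `pip install trl[kernels]`. A minimal sketch of a runtime fallback, assuming `transformers.utils.is_kernels_available` (already imported in the test utilities below) as the check:

```python
from transformers.utils import is_kernels_available

# Prefer the hub-provided FlashAttention-2 kernel when the `kernels` package is
# installed; otherwise fall back to PyTorch's SDPA attention implementation.
attn_implementation = "kernels-community/flash-attn2" if is_kernels_available() else "sdpa"
```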

tests/test_grpo_trainer.py

Lines changed: 5 additions & 3 deletions
@@ -41,8 +41,9 @@
 
 from .testing_utils import (
     TrlTestCase,
+    require_ampere_or_newer,
     require_bitsandbytes,
-    require_flash_attn,
+    require_kernels,
     require_liger_kernel,
     require_peft,
     require_torch_accelerator,
@@ -1987,7 +1988,8 @@ def test_training_with_transformers_paged(self, model_name):
             "HuggingFaceTB/SmolVLM-Instruct",  # Only test the smaller model to avoid OOM
         ],
     )
-    @require_flash_attn
+    @require_kernels
+    @require_ampere_or_newer  # Flash attention 2 requires Ampere or newer GPUs
     @require_bitsandbytes
     @require_peft
     def test_vlm_training(self, model_name):
@@ -2040,7 +2042,7 @@ def data_gen(num_samples):
         )
         model = AutoModelForImageTextToText.from_pretrained(
             model_name,
-            attn_implementation="flash_attention_2",
+            attn_implementation="kernels-community/flash-attn2",
             dtype="bfloat16",
             device_map=get_kbit_device_map(),
             quantization_config=quantization_config,

tests/test_sft_trainer.py

Lines changed: 5 additions & 3 deletions
@@ -33,8 +33,9 @@
 from .testing_utils import (
     TrlTestCase,
     ignore_warnings,
+    require_ampere_or_newer,
     require_bitsandbytes,
-    require_flash_attn,
+    require_kernels,
     require_liger_kernel,
     require_peft,
     require_torch_accelerator,
@@ -870,7 +871,8 @@ def test_train_with_iterable_dataset(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
 
-    @require_flash_attn
+    @require_kernels
+    @require_ampere_or_newer  # Flash attention 2 requires Ampere or newer GPUs
     def test_train_padding_free(self):
         # Get the dataset
         dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
@@ -879,7 +881,7 @@ def test_train_padding_free(self):
         training_args = SFTConfig(
             output_dir=self.tmp_dir,
             padding_free=True,
-            model_init_kwargs={"attn_implementation": "flash_attention_2"},
+            model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"},
             bf16=True,  # flash_attention_2 only supports bf16 and fp16
             report_to="none",
         )

tests/testing_utils.py

Lines changed: 8 additions & 13 deletions
@@ -24,7 +24,6 @@
 from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
 from transformers.testing_utils import backend_device_count, torch_device
 from transformers.utils import (
-    is_flash_attn_2_available,
     is_kernels_available,
     is_peft_available,
     is_rich_available,
@@ -45,6 +44,7 @@
 
 require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
 require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
+require_kernels = pytest.mark.skipif(not is_kernels_available(), reason="test requires kernels")
 require_liger_kernel = pytest.mark.skipif(not is_liger_kernel_available(), reason="test requires liger-kernel")
 require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
 require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
@@ -85,21 +85,16 @@ def is_bitsandbytes_multi_backend_available() -> bool:
 )
 
 
-def is_flash_attn_available():
-    flash_attn_available = is_flash_attn_2_available()
-    kernels_available = is_kernels_available()
-    try:
-        from kernels import get_kernel
-
-        get_kernel("kernels-community/flash-attn")
-    except Exception:
-        kernels_available = False
+def is_ampere_or_newer(device_index=0):
+    if not torch.cuda.is_available():
+        return False
 
-    return kernels_available or flash_attn_available
+    major, minor = torch.cuda.get_device_capability(device_index)
+    # Ampere starts at compute capability 8.0 (e.g., A100 = 8.0, RTX 30xx = 8.6)
+    return (major, minor) >= (8, 0)
 
 
-# Function ported from transformers.testing_utils
-require_flash_attn = pytest.mark.skipif(not is_flash_attn_available(), reason="test requires Flash Attention")
+require_ampere_or_newer = pytest.mark.skipif(not is_ampere_or_newer(), reason="test requires Ampere or newer GPU")
 
 
 class RandomBinaryJudge(BaseBinaryJudge):
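Not part of the diff: a hedged sketch of the same capability gate applied outside the test suite, e.g. to pick an attention backend at runtime; the helper name is made up for illustration:

```python
import torch

def pick_attn_implementation() -> str:
    # Hypothetical helper: the kernels-community FlashAttention-2 kernel needs an
    # Ampere-or-newer GPU (compute capability >= 8.0), mirroring is_ampere_or_newer above.
    if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0):
        return "kernels-community/flash-attn2"
    return "sdpa"
```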
