Commit 389115c

vasqu authored and ArthurZucker committed
[Kernels Attention] Change fallback logic to error out on explicit kernels request and include FA3 (#41010)
* fix
* be more strict
* change logic to include fa3
* fix the case where nothing is requested
* modify old tests + add kernels related tests
* style
1 parent 3017f04 commit 389115c
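
To illustrate the user-facing effect (a sketch only; the tiny test checkpoint and the environment assumptions are borrowed from the updated tests below):

from transformers import AutoModel

# Environment: flash-attn / flash-attn-3 not installed, `kernels` installed.
# A flash attention request now resolves to a hub kernel and a warning is logged:
#   "flash_attention_2" -> "kernels-community/flash-attn"
#   "flash_attention_3" -> "kernels-community/vllm-flash-attn3"
model = AutoModel.from_pretrained(
    "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
)

# Environment: `kernels` not installed (or incompatible).
# An explicitly requested kernel no longer falls back to sdpa/eager; it raises an ImportError.
model = AutoModel.from_pretrained(
    "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="kernels-community/flash-attn"
)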

File tree

3 files changed: +67 additions, -18 deletions

src/transformers/integrations/hub_kernels.py

Lines changed: 4 additions & 1 deletion
@@ -152,7 +152,10 @@ def load_and_register_kernel(attn_implementation: str) -> None:
     if not is_kernel(attn_implementation):
         return
     if not _kernels_available:
-        raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.")
+        raise ImportError(
+            "`kernels` is either not installed or uses an incompatible version. "
+            "Please install the latest version with `pip install -U kernels`."
+        )
 
     # Need to be imported here as otherwise we have a circular import in `modeling_utils`
     from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
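
The reworded error also covers installed-but-incompatible versions because `_kernels_available` is typically set by a guarded import at module load time, so an outdated `kernels` package that lacks the required symbols is flagged exactly like a missing one. A minimal sketch of that pattern, assuming a guarded import (the symbols actually imported in `hub_kernels.py` may differ):

# Hypothetical sketch of how an availability flag like `_kernels_available` is usually derived.
try:
    from kernels import get_kernel  # an outdated `kernels` release may not expose newer symbols

    _kernels_available = True
except ImportError:
    _kernels_available = False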

src/transformers/modeling_utils.py

Lines changed: 21 additions & 17 deletions
@@ -2666,35 +2666,39 @@ def _check_and_adjust_attn_implementation(
         None to sdpa (to potentially eager).
         """
         applicable_attn_implementation = attn_implementation
+
         # If FA not installed, do not fail but use kernels instead
         if (
-            applicable_attn_implementation == "flash_attention_2"
+            attn_implementation is not None
+            and attn_implementation.startswith("flash_attention")
             and self._supports_flash_attn
-            and not is_flash_attn_2_available()
+            and not (is_flash_attn_2_available() or is_flash_attn_3_available())
             and is_kernels_available()
         ):
-            applicable_attn_implementation = "kernels-community/flash-attn"
+            if attn_implementation.endswith("2"):
+                applicable_attn_implementation = "kernels-community/flash-attn"
+            else:
+                applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
+
         if is_kernel(applicable_attn_implementation):
             try:
                 load_and_register_kernel(applicable_attn_implementation)
                 # log that we used kernel fallback if successful
-                if attn_implementation == "flash_attention_2":
+                if attn_implementation.startswith("flash_attention"):
                     logger.warning_once(
-                        "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` "
-                        "library instead!"
+                        f"You do not have `flash_attn` installed, using `{applicable_attn_implementation}` "
+                        "from the `kernels` library instead!"
                     )
             except Exception as e:
-                if attn_implementation == "flash_attention_2":
-                    self._flash_attn_2_can_dispatch()  # will fail as fa2 is not available but raise the proper exception
-                logger.warning_once(
-                    f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the "
-                    f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)."
-                )
-                try:
-                    self._sdpa_can_dispatch(is_init_check)
-                    applicable_attn_implementation = "sdpa"
-                except (ValueError, ImportError):
-                    applicable_attn_implementation = "eager"
+                # raise the proper exception for requested flash attention
+                if attn_implementation.startswith("flash_attention"):
+                    if attn_implementation.endswith("2"):
+                        self._flash_attn_2_can_dispatch()
+                    else:
+                        self._flash_attn_3_can_dispatch()
+
+                # error properly out if a kernel was specifically requested
+                raise e
         else:
             applicable_attn_implementation = self.get_correct_attn_implementation(
                 applicable_attn_implementation, is_init_check
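
For illustration only, the new resolution step can be condensed into a standalone function; the boolean flags stand in for `is_flash_attn_2_available()`, `is_flash_attn_3_available()`, `is_kernels_available()`, and `self._supports_flash_attn`:

def resolve_attn_implementation(requested, fa2_ok, fa3_ok, kernels_ok, supports_flash_attn=True):
    """Sketch: map a flash attention request to a hub kernel when flash-attn is not installed."""
    if (
        requested is not None
        and requested.startswith("flash_attention")
        and supports_flash_attn
        and not (fa2_ok or fa3_ok)
        and kernels_ok
    ):
        # "flash_attention_2" falls back to the flash-attn kernel, "flash_attention_3" to the vLLM FA3 kernel
        return "kernels-community/flash-attn" if requested.endswith("2") else "kernels-community/vllm-flash-attn3"
    return requested

If loading the resolved kernel subsequently fails, the method now re-raises the error, first calling `_flash_attn_2_can_dispatch()` or `_flash_attn_3_can_dispatch()` to surface a flash-attention specific message when flash attention was what the user requested, instead of silently downgrading to sdpa or eager.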

tests/utils/test_modeling_utils.py

Lines changed: 42 additions & 0 deletions
@@ -88,6 +88,7 @@
 from transformers.utils.import_utils import (
     is_flash_attn_2_available,
     is_flash_attn_3_available,
+    is_kernels_available,
     is_torch_npu_available,
 )
 
@@ -2849,6 +2850,9 @@ def test_not_available_flash(self):
                 reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
             )
 
+        if is_kernels_available():
+            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash`")
+
         with self.assertRaises(ImportError) as cm:
             _ = AutoModel.from_pretrained(
                 "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
@@ -2864,6 +2868,9 @@ def test_not_available_flash_with_config(self):
                 reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
             )
 
+        if is_kernels_available():
+            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash_with_config`")
+
         config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
 
         with self.assertRaises(ImportError) as cm:
@@ -2875,6 +2882,41 @@ def test_not_available_flash_with_config(self):
 
         self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
 
+    def test_kernels_fallback(self):
+        if not is_kernels_available():
+            self.skipTest(reason="Please install `kernels` package to run `test_kernels_fallback`")
+
+        if is_flash_attn_2_available():
+            self.skipTest(reason="Please uninstall flash-attn package to run test_kernels_fallback")
+
+        if is_torch_npu_available():
+            self.skipTest(
+                reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
+            )
+
+        logger = logging.get_logger("transformers.modeling_utils")
+        with LoggingLevel(logging.WARNING):
+            with CaptureLogger(logger) as cl:
+                _ = AutoModel.from_pretrained(
+                    "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
+                )
+
+        self.assertTrue(
+            "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!"
+            in cl.out
+        )
+
+    def test_not_available_kernels(self):
+        if is_kernels_available():
+            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`")
+
+        with self.assertRaises(ImportError) as cm:
+            _ = AutoModel.from_pretrained(
+                "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn"
+            )
+
+        self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception))
+
 
 @require_torch
 class TestTensorSharing(TestCasePlus):
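
Assuming the repository's usual pytest setup, the new and adjusted tests can be selected with, for example:

pytest tests/utils/test_modeling_utils.py -k "kernels_fallback or not_available"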
