Commit c1db386

[Kernels Attention] Change fallback logic to error out on explicit kernels request and include FA3 (#41010)
* fix
* be more strict
* change logic to include fa3
* fix the case where nothing is requested
* modify old tests + add kernels related tests
* style
1 parent 5426ede commit c1db386
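
The behavior described above, sketched from the caller's side. This is an illustrative usage sketch, not part of the commit; it assumes an environment where `kernels` is installed but `flash_attn` is not, and reuses the tiny test checkpoint purely as a placeholder.

from transformers import AutoModel

# Requesting FA2 without `flash_attn` installed falls back to the matching hub kernel
# and logs a warning instead of failing:
model = AutoModel.from_pretrained(
    "hf-internal-testing/tiny-random-GPTBigCodeModel",
    attn_implementation="flash_attention_2",  # -> kernels-community/flash-attn
)

# FA3 requests are now covered by the same fallback and map to the FA3 hub kernel:
model = AutoModel.from_pretrained(
    "hf-internal-testing/tiny-random-GPTBigCodeModel",
    attn_implementation="flash_attention_3",  # -> kernels-community/vllm-flash-attn3
)

# An explicitly requested hub kernel that cannot be loaded now raises the underlying
# error instead of silently dropping back to sdpa/eager:
model = AutoModel.from_pretrained(
    "hf-internal-testing/tiny-random-GPTBigCodeModel",
    attn_implementation="kernels-community/flash-attn",
)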

File tree

3 files changed: +67 -18 lines changed


src/transformers/integrations/hub_kernels.py

Lines changed: 4 additions & 1 deletion

@@ -152,7 +152,10 @@ def load_and_register_kernel(attn_implementation: str) -> None:
     if not is_kernel(attn_implementation):
         return
     if not _kernels_available:
-        raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.")
+        raise ImportError(
+            "`kernels` is either not installed or uses an incompatible version. "
+            "Please install the latest version with `pip install -U kernels`."
+        )
 
     # Need to be imported here as otherwise we have a circular import in `modeling_utils`
     from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
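
From the caller's side, the sharper error reads as below. A minimal sketch, assuming an environment where the `kernels` package is missing or too old:

from transformers.integrations.hub_kernels import load_and_register_kernel

try:
    # Explicitly requesting a hub kernel without a usable `kernels` install now raises
    load_and_register_kernel("kernels-community/flash-attn")
except ImportError as err:
    # The new message also covers an incompatible version and points at the fix:
    # "`kernels` is either not installed or uses an incompatible version. "
    # "Please install the latest version with `pip install -U kernels`."
    print(err)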

src/transformers/modeling_utils.py

Lines changed: 21 additions & 17 deletions

@@ -2574,35 +2574,39 @@ def _check_and_adjust_attn_implementation(
         None to sdpa (to potentially eager).
         """
         applicable_attn_implementation = attn_implementation
+
         # If FA not installed, do not fail but use kernels instead
         if (
-            applicable_attn_implementation == "flash_attention_2"
+            attn_implementation is not None
+            and attn_implementation.startswith("flash_attention")
             and self._supports_flash_attn
-            and not is_flash_attn_2_available()
+            and not (is_flash_attn_2_available() or is_flash_attn_3_available())
             and is_kernels_available()
         ):
-            applicable_attn_implementation = "kernels-community/flash-attn"
+            if attn_implementation.endswith("2"):
+                applicable_attn_implementation = "kernels-community/flash-attn"
+            else:
+                applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
+
         if is_kernel(applicable_attn_implementation):
             try:
                 load_and_register_kernel(applicable_attn_implementation)
                 # log that we used kernel fallback if successful
-                if attn_implementation == "flash_attention_2":
+                if attn_implementation.startswith("flash_attention"):
                     logger.warning_once(
-                        "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` "
-                        "library instead!"
+                        f"You do not have `flash_attn` installed, using `{applicable_attn_implementation}` "
+                        "from the `kernels` library instead!"
                     )
             except Exception as e:
-                if attn_implementation == "flash_attention_2":
-                    self._flash_attn_2_can_dispatch()  # will fail as fa2 is not available but raise the proper exception
-                logger.warning_once(
-                    f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the "
-                    f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)."
-                )
-                try:
-                    self._sdpa_can_dispatch(is_init_check)
-                    applicable_attn_implementation = "sdpa"
-                except (ValueError, ImportError):
-                    applicable_attn_implementation = "eager"
+                # raise the proper exception for requested flash attention
+                if attn_implementation.startswith("flash_attention"):
+                    if attn_implementation.endswith("2"):
+                        self._flash_attn_2_can_dispatch()
+                    else:
+                        self._flash_attn_3_can_dispatch()
+
+                # error properly out if a kernel was specifically requested
+                raise e
         else:
             applicable_attn_implementation = self.get_correct_attn_implementation(
                 applicable_attn_implementation, is_init_check
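
For readability, the new fallback decision can be boiled down to the sketch below. This is an illustrative extraction, not code from the diff; the helper name and its boolean parameters are made up for the example.

def pick_fallback_kernel(requested, supports_flash_attn, flash_attn_installed, kernels_installed):
    """Return the hub kernel the new logic would fall back to, or None (illustrative only)."""
    if (
        requested is not None
        and requested.startswith("flash_attention")
        and supports_flash_attn            # the model declares flash attention support
        and not flash_attn_installed       # neither FA2 nor FA3 wheels are present
        and kernels_installed
    ):
        # "flash_attention_2" -> FA2 kernel, otherwise (e.g. "flash_attention_3") -> FA3 kernel
        return "kernels-community/flash-attn" if requested.endswith("2") else "kernels-community/vllm-flash-attn3"
    return None

If the requested implementation is a hub kernel (explicitly or via this fallback) and loading it fails, the error is now re-raised instead of silently switching to sdpa or eager.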

tests/utils/test_modeling_utils.py

Lines changed: 42 additions & 0 deletions

@@ -87,6 +87,7 @@
 from transformers.utils.import_utils import (
     is_flash_attn_2_available,
     is_flash_attn_3_available,
+    is_kernels_available,
     is_torch_npu_available,
 )
 
@@ -2846,6 +2847,9 @@ def test_not_available_flash(self):
                 reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
             )
 
+        if is_kernels_available():
+            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash`")
+
         with self.assertRaises(ImportError) as cm:
             _ = AutoModel.from_pretrained(
                 "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
@@ -2861,6 +2865,9 @@ def test_not_available_flash_with_config(self):
                 reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
             )
 
+        if is_kernels_available():
+            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash_with_config`")
+
         config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
 
         with self.assertRaises(ImportError) as cm:
@@ -2872,6 +2879,41 @@ def test_not_available_flash_with_config(self):
 
         self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
 
+    def test_kernels_fallback(self):
+        if not is_kernels_available():
+            self.skipTest(reason="Please install `kernels` package to run `test_kernels_fallback`")
+
+        if is_flash_attn_2_available():
+            self.skipTest(reason="Please uninstall flash-attn package to run test_kernels_fallback")
+
+        if is_torch_npu_available():
+            self.skipTest(
+                reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
+            )
+
+        logger = logging.get_logger("transformers.modeling_utils")
+        with LoggingLevel(logging.WARNING):
+            with CaptureLogger(logger) as cl:
+                _ = AutoModel.from_pretrained(
+                    "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
+                )
+
+        self.assertTrue(
+            "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!"
+            in cl.out
+        )
+
+    def test_not_available_kernels(self):
+        if is_kernels_available():
+            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`")
+
+        with self.assertRaises(ImportError) as cm:
+            _ = AutoModel.from_pretrained(
+                "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn"
+            )
+
+        self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception))
+
 
 @require_torch
 class TestTensorSharing(TestCasePlus):
