From 7e2f006812ffd4323df51c44b3c9e75a4534dccb Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 19 Sep 2025 17:35:53 +0200 Subject: [PATCH 1/6] fix --- src/transformers/integrations/hub_kernels.py | 5 ++++- src/transformers/modeling_utils.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 5be21e2f9a51..71adbb9188e7 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -152,7 +152,10 @@ def load_and_register_kernel(attn_implementation: str) -> None: if not is_kernel(attn_implementation): return if not _kernels_available: - raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.") + raise ImportError( + "`kernels` is either not installed or uses an incompatible version. " + "Please install the latest version with `pip install -U kernels`." + ) # Need to be imported here as otherwise we have a circular import in `modeling_utils` from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 31783d041fe4..9add887f7243 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2708,6 +2708,11 @@ def _check_and_adjust_attn_implementation( except Exception as e: if attn_implementation == "flash_attention_2": self._flash_attn_2_can_dispatch() # will fail as fa2 is not available but raise the proper exception + + # error properly out if a kernel was specifically requested + if isinstance(e, ImportError): + raise e + logger.warning_once( f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the " f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)." From c7aad063b33862b6cb6171f3aea7ff4f79e682d0 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 19 Sep 2025 17:44:17 +0200 Subject: [PATCH 2/6] be more strict --- src/transformers/modeling_utils.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9add887f7243..cf0336517bed 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2710,18 +2710,7 @@ def _check_and_adjust_attn_implementation( self._flash_attn_2_can_dispatch() # will fail as fa2 is not available but raise the proper exception # error properly out if a kernel was specifically requested - if isinstance(e, ImportError): - raise e - - logger.warning_once( - f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the " - f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)." 
- ) - try: - self._sdpa_can_dispatch(is_init_check) - applicable_attn_implementation = "sdpa" - except (ValueError, ImportError) as e: - applicable_attn_implementation = "eager" + raise e else: applicable_attn_implementation = self.get_correct_attn_implementation( applicable_attn_implementation, is_init_check From 6d37ecdfe516f17bd5569bf94012c367fd221bc8 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 19 Sep 2025 17:50:57 +0200 Subject: [PATCH 3/6] change logic to include fa3 --- src/transformers/modeling_utils.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index cf0336517bed..6ad7f3e10c55 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2688,26 +2688,35 @@ def _check_and_adjust_attn_implementation( None to sdpa (to potentially eager). """ applicable_attn_implementation = attn_implementation + # If FA not installed, do not fail but use kernels instead if ( - applicable_attn_implementation == "flash_attention_2" + applicable_attn_implementation.startswith("flash_attention") and self._supports_flash_attn - and not is_flash_attn_2_available() + and not (is_flash_attn_2_available() or is_flash_attn_3_available()) and is_kernels_available() ): - applicable_attn_implementation = "kernels-community/flash-attn" + if applicable_attn_implementation.endswith("2"): + applicable_attn_implementation = "kernels-community/flash-attn" + else: + applicable_attn_implementation = "kernels-community/vllm-flash-attn3" + if is_kernel(applicable_attn_implementation): try: load_and_register_kernel(applicable_attn_implementation) # log that we used kernel fallback if successful - if attn_implementation == "flash_attention_2": + if attn_implementation.startswith("flash_attention"): logger.warning_once( - "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` " - "library instead!" + f"You do not have `flash_attn` installed, using `{applicable_attn_implementation}` " + "from the `kernels` library instead!" 
) except Exception as e: - if attn_implementation == "flash_attention_2": - self._flash_attn_2_can_dispatch() # will fail as fa2 is not available but raise the proper exception + # raise the proper exception for requested flash attention + if attn_implementation.startswith("flash_attention"): + if attn_implementation.endswith("2"): + self._flash_attn_2_can_dispatch() + else: + self._flash_attn_3_can_dispatch() # error properly out if a kernel was specifically requested raise e From a8002cc3a5ca2c86314a0265acaa5d30782dd3fa Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 19 Sep 2025 18:01:02 +0200 Subject: [PATCH 4/6] fix the case where nothing is requested --- src/transformers/modeling_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6ad7f3e10c55..5692feb3af61 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2691,12 +2691,13 @@ def _check_and_adjust_attn_implementation( # If FA not installed, do not fail but use kernels instead if ( - applicable_attn_implementation.startswith("flash_attention") + attn_implementation is not None + and attn_implementation.startswith("flash_attention") and self._supports_flash_attn and not (is_flash_attn_2_available() or is_flash_attn_3_available()) and is_kernels_available() ): - if applicable_attn_implementation.endswith("2"): + if attn_implementation.endswith("2"): applicable_attn_implementation = "kernels-community/flash-attn" else: applicable_attn_implementation = "kernels-community/vllm-flash-attn3" From 08c3644826fc3202b756fc50fa7b16675bea21ab Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 22 Sep 2025 16:38:22 +0000 Subject: [PATCH 5/6] modify old tests + add kernels related tests --- tests/utils/test_modeling_utils.py | 37 ++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index bf6889338b0e..7902c8afe25f 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -88,6 +88,7 @@ from transformers.utils.import_utils import ( is_flash_attn_2_available, is_flash_attn_3_available, + is_kernels_available, is_torch_npu_available, ) @@ -2737,6 +2738,9 @@ def test_not_available_flash(self): reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." ) + if is_kernels_available(): + self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash`") + with self.assertRaises(ImportError) as cm: _ = AutoModel.from_pretrained( "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" @@ -2752,6 +2756,9 @@ def test_not_available_flash_with_config(self): reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." 
) + if is_kernels_available(): + self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash_with_config`") + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel") with self.assertRaises(ImportError) as cm: @@ -2763,6 +2770,36 @@ def test_not_available_flash_with_config(self): self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) + def test_kernels_fallback(self): + if not is_kernels_available(): + self.skipTest(reason="Please install `kernels` package to run `test_kernels_fallback`") + + if is_flash_attn_2_available(): + self.skipTest(reason="Please uninstall flash-attn package to run test_kernels_fallback") + + if is_torch_npu_available(): + self.skipTest( + reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." + ) + + logger = logging.get_logger("transformers.modeling_utils") + with LoggingLevel(logging.WARNING): + with CaptureLogger(logger) as cl: + _ = AutoModel.from_pretrained( + "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" + ) + + self.assertTrue("You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!" in cl.out) + + def test_not_available_kernels(self): + if is_kernels_available(): + self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`") + + with self.assertRaises(ImportError) as cm: + _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn") + + self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception)) + @require_torch class TestTensorSharing(TestCasePlus): From e757821e33aa982ccbc9ed63ace49f2999263a05 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 22 Sep 2025 16:41:25 +0000 Subject: [PATCH 6/6] style --- tests/utils/test_modeling_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 7902c8afe25f..d80ede5aff7a 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -2786,17 +2786,22 @@ def test_kernels_fallback(self): with LoggingLevel(logging.WARNING): with CaptureLogger(logger) as cl: _ = AutoModel.from_pretrained( - "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" - ) + "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" + ) - self.assertTrue("You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!" in cl.out) + self.assertTrue( + "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!" + in cl.out + ) def test_not_available_kernels(self): if is_kernels_available(): self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`") with self.assertRaises(ImportError) as cm: - _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn") + _ = AutoModel.from_pretrained( + "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn" + ) self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception))
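
Note for reviewers (not part of the patch series): below is a minimal usage sketch of the behavior these patches introduce, written against the tiny test checkpoint already used in the tests above. The environment assumptions (no `flash_attn` installed, `kernels` possibly missing) are illustrative only.

from transformers import AutoModel

# Case 1: flash_attention_2 (or flash_attention_3) is requested but `flash_attn` is not
# installed. With these patches, if the `kernels` package is available the implementation
# is swapped to the hub kernel ("kernels-community/flash-attn" for FA2,
# "kernels-community/vllm-flash-attn3" for FA3) and a warning is logged. If the kernel
# cannot be loaded either, the proper flash-attention ImportError is raised instead of
# silently falling back to sdpa/eager.
model = AutoModel.from_pretrained(
    "hf-internal-testing/tiny-random-GPTBigCodeModel",
    attn_implementation="flash_attention_2",
)

# Case 2: a hub kernel is requested explicitly. Without a compatible `kernels` install
# this now surfaces the new ImportError from hub_kernels.py rather than degrading to the
# default attention implementation.
try:
    model = AutoModel.from_pretrained(
        "hf-internal-testing/tiny-random-GPTBigCodeModel",
        attn_implementation="kernels-community/flash-attn",
    )
except ImportError as e:
    print(e)  # "`kernels` is either not installed or uses an incompatible version. ..."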