5 changes: 4 additions & 1 deletion src/transformers/integrations/hub_kernels.py
@@ -152,7 +152,10 @@ def load_and_register_kernel(attn_implementation: str) -> None:
     if not is_kernel(attn_implementation):
         return
     if not _kernels_available:
-        raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.")
+        raise ImportError(
+            "`kernels` is either not installed or uses an incompatible version. "
+            "Please install the latest version with `pip install -U kernels`."
+        )
 
     # Need to be imported here as otherwise we have a circular import in `modeling_utils`
     from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
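For context, a minimal sketch (not part of the diff) of how the reworded error is expected to surface, assuming the `kernels` package is missing or too old when a hub kernel is requested:

# Illustrative only: trigger the updated ImportError path, assuming `kernels`
# is not installed (or is an incompatible version) in the current environment.
from transformers.integrations.hub_kernels import load_and_register_kernel

try:
    # Any "org/repo"-style attention implementation is treated as a hub kernel.
    load_and_register_kernel("kernels-community/flash-attn")
except ImportError as err:
    # New message: "`kernels` is either not installed or uses an incompatible
    # version. Please install the latest version with `pip install -U kernels`."
    print(err)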
38 changes: 21 additions & 17 deletions src/transformers/modeling_utils.py
@@ -2675,35 +2675,39 @@ def _check_and_adjust_attn_implementation(
         None to sdpa (to potentially eager).
         """
         applicable_attn_implementation = attn_implementation
+
         # If FA not installed, do not fail but use kernels instead
         if (
-            applicable_attn_implementation == "flash_attention_2"
+            attn_implementation is not None
+            and attn_implementation.startswith("flash_attention")
             and self._supports_flash_attn
-            and not is_flash_attn_2_available()
+            and not (is_flash_attn_2_available() or is_flash_attn_3_available())
             and is_kernels_available()
         ):
-            applicable_attn_implementation = "kernels-community/flash-attn"
+            if attn_implementation.endswith("2"):
+                applicable_attn_implementation = "kernels-community/flash-attn"
+            else:
+                applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
+
         if is_kernel(applicable_attn_implementation):
             try:
                 load_and_register_kernel(applicable_attn_implementation)
                 # log that we used kernel fallback if successful
-                if attn_implementation == "flash_attention_2":
+                if attn_implementation.startswith("flash_attention"):
                     logger.warning_once(
-                        "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` "
-                        "library instead!"
+                        f"You do not have `flash_attn` installed, using `{applicable_attn_implementation}` "
+                        "from the `kernels` library instead!"
                     )
             except Exception as e:
-                if attn_implementation == "flash_attention_2":
-                    self._flash_attn_2_can_dispatch()  # will fail as fa2 is not available but raise the proper exception
-                logger.warning_once(
-                    f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the "
-                    f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)."
-                )
-                try:
-                    self._sdpa_can_dispatch(is_init_check)
-                    applicable_attn_implementation = "sdpa"
-                except (ValueError, ImportError):
-                    applicable_attn_implementation = "eager"
+                # raise the proper exception for requested flash attention
+                if attn_implementation.startswith("flash_attention"):
+                    if attn_implementation.endswith("2"):
+                        self._flash_attn_2_can_dispatch()
+                    else:
+                        self._flash_attn_3_can_dispatch()
+
+                # error properly out if a kernel was specifically requested
+                raise e
         else:
             applicable_attn_implementation = self.get_correct_attn_implementation(
                 applicable_attn_implementation, is_init_check
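As a reading aid rather than code from the PR, here is a self-contained sketch of the selection rule the new branch applies when neither flash-attn 2 nor flash-attn 3 is installed but `kernels` is available; the helper name is invented for illustration:

# Hypothetical helper mirroring the new fallback choice in
# _check_and_adjust_attn_implementation: a "flash_attention_2" request maps to
# the flash-attn hub kernel, any other flash_attention* request maps to the
# vLLM flash-attn 3 hub kernel.
def pick_kernels_fallback(attn_implementation: str) -> str:
    if attn_implementation.endswith("2"):
        return "kernels-community/flash-attn"
    return "kernels-community/vllm-flash-attn3"


assert pick_kernels_fallback("flash_attention_2") == "kernels-community/flash-attn"
assert pick_kernels_fallback("flash_attention_3") == "kernels-community/vllm-flash-attn3"

Note that on a kernel load failure the rewritten `except` block now runs the matching `_flash_attn_2_can_dispatch` / `_flash_attn_3_can_dispatch` check and re-raises, instead of silently falling back to sdpa or eager as before.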
42 changes: 42 additions & 0 deletions tests/utils/test_modeling_utils.py
@@ -88,6 +88,7 @@
 from transformers.utils.import_utils import (
     is_flash_attn_2_available,
     is_flash_attn_3_available,
+    is_kernels_available,
     is_torch_npu_available,
 )
 
@@ -2849,6 +2850,9 @@ def test_not_available_flash(self):
                 reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
             )
 
+        if is_kernels_available():
+            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash`")
+
         with self.assertRaises(ImportError) as cm:
             _ = AutoModel.from_pretrained(
                 "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
@@ -2864,6 +2868,9 @@ def test_not_available_flash_with_config(self):
                 reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
             )
 
+        if is_kernels_available():
+            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash_with_config`")
+
Comment on lines +2871 to +2872 (Contributor Author): Wasn't considered before; `kernels` now works as a fallback, which lets this test fail if it is installed.

         config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
 
         with self.assertRaises(ImportError) as cm:
@@ -2875,6 +2882,41 @@
 
         self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
 
+    def test_kernels_fallback(self):
+        if not is_kernels_available():
+            self.skipTest(reason="Please install `kernels` package to run `test_kernels_fallback`")
+
+        if is_flash_attn_2_available():
+            self.skipTest(reason="Please uninstall flash-attn package to run test_kernels_fallback")
+
+        if is_torch_npu_available():
+            self.skipTest(
+                reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
+            )
+
+        logger = logging.get_logger("transformers.modeling_utils")
+        with LoggingLevel(logging.WARNING):
+            with CaptureLogger(logger) as cl:
+                _ = AutoModel.from_pretrained(
+                    "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
+                )
+
+        self.assertTrue(
+            "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!"
+            in cl.out
+        )
+
+    def test_not_available_kernels(self):
+        if is_kernels_available():
+            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`")
+
+        with self.assertRaises(ImportError) as cm:
+            _ = AutoModel.from_pretrained(
+                "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn"
+            )
+
+        self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception))
+
 
 @require_torch
 class TestTensorSharing(TestCasePlus):
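The new tests pin down the expected user-facing behaviour; roughly, as a sketch that assumes `kernels` is installed and `flash_attn` is not:

# Sketch of the fallback a user should now see (mirrors test_kernels_fallback):
# requesting flash attention without flash-attn installed no longer raises when
# the `kernels` package is available; it falls back to a hub kernel and warns.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "hf-internal-testing/tiny-random-GPTBigCodeModel",
    attn_implementation="flash_attention_2",
)
# Expected warning: "You do not have `flash_attn` installed, using
# `kernels-community/flash-attn` from the `kernels` library instead!"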