8787from transformers .utils .import_utils import (
8888 is_flash_attn_2_available ,
8989 is_flash_attn_3_available ,
90+ is_kernels_available ,
9091 is_torch_npu_available ,
9192)
9293
@@ -2846,6 +2847,9 @@ def test_not_available_flash(self):
28462847 reason = "FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
28472848 )
28482849
2850+ if is_kernels_available ():
2851+ self .skipTest (reason = "Please uninstall `kernels` package to run `test_not_available_flash`" )
2852+
28492853 with self .assertRaises (ImportError ) as cm :
28502854 _ = AutoModel .from_pretrained (
28512855 "hf-internal-testing/tiny-random-GPTBigCodeModel" , attn_implementation = "flash_attention_2"
@@ -2861,6 +2865,9 @@ def test_not_available_flash_with_config(self):
28612865 reason = "FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
28622866 )
28632867
2868+ if is_kernels_available ():
2869+ self .skipTest (reason = "Please uninstall `kernels` package to run `test_not_available_flash_with_config`" )
2870+
28642871 config = AutoConfig .from_pretrained ("hf-internal-testing/tiny-random-GPTBigCodeModel" )
28652872
28662873 with self .assertRaises (ImportError ) as cm :
@@ -2872,6 +2879,41 @@ def test_not_available_flash_with_config(self):
28722879
28732880 self .assertTrue ("the package flash_attn seems to be not installed" in str (cm .exception ))
28742881
2882+ def test_kernels_fallback (self ):
2883+ if not is_kernels_available ():
2884+ self .skipTest (reason = "Please install `kernels` package to run `test_kernels_fallback`" )
2885+
2886+ if is_flash_attn_2_available ():
2887+ self .skipTest (reason = "Please uninstall flash-attn package to run test_kernels_fallback" )
2888+
2889+ if is_torch_npu_available ():
2890+ self .skipTest (
2891+ reason = "FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
2892+ )
2893+
2894+ logger = logging .get_logger ("transformers.modeling_utils" )
2895+ with LoggingLevel (logging .WARNING ):
2896+ with CaptureLogger (logger ) as cl :
2897+ _ = AutoModel .from_pretrained (
2898+ "hf-internal-testing/tiny-random-GPTBigCodeModel" , attn_implementation = "flash_attention_2"
2899+ )
2900+
2901+ self .assertTrue (
2902+ "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!"
2903+ in cl .out
2904+ )
2905+
2906+ def test_not_available_kernels (self ):
2907+ if is_kernels_available ():
2908+ self .skipTest (reason = "Please uninstall `kernels` package to run `test_not_available_kernels`" )
2909+
2910+ with self .assertRaises (ImportError ) as cm :
2911+ _ = AutoModel .from_pretrained (
2912+ "hf-tiny-model-private/tiny-random-MCTCTModel" , attn_implementation = "kernels-community/flash-attn"
2913+ )
2914+
2915+ self .assertTrue ("`kernels` is either not installed or uses an incompatible version." in str (cm .exception ))
2916+
28752917
28762918@require_torch
28772919class TestTensorSharing (TestCasePlus ):
0 commit comments