8888from transformers .utils .import_utils import (
8989 is_flash_attn_2_available ,
9090 is_flash_attn_3_available ,
91+ is_kernels_available ,
9192 is_torch_npu_available ,
9293)
9394
@@ -2849,6 +2850,9 @@ def test_not_available_flash(self):
28492850 reason = "FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
28502851 )
28512852
2853+ if is_kernels_available ():
2854+ self .skipTest (reason = "Please uninstall `kernels` package to run `test_not_available_flash`" )
2855+
28522856 with self .assertRaises (ImportError ) as cm :
28532857 _ = AutoModel .from_pretrained (
28542858 "hf-internal-testing/tiny-random-GPTBigCodeModel" , attn_implementation = "flash_attention_2"
@@ -2864,6 +2868,9 @@ def test_not_available_flash_with_config(self):
28642868 reason = "FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
28652869 )
28662870
2871+ if is_kernels_available ():
2872+ self .skipTest (reason = "Please uninstall `kernels` package to run `test_not_available_flash_with_config`" )
2873+
28672874 config = AutoConfig .from_pretrained ("hf-internal-testing/tiny-random-GPTBigCodeModel" )
28682875
28692876 with self .assertRaises (ImportError ) as cm :
@@ -2875,6 +2882,41 @@ def test_not_available_flash_with_config(self):
28752882
28762883 self .assertTrue ("the package flash_attn seems to be not installed" in str (cm .exception ))
28772884
def test_kernels_fallback(self):
    """Check the warning logged when `flash_attention_2` is requested without `flash-attn`.

    The test only runs when the `kernels` package is available, `flash-attn` is
    not, and the device is not an Ascend NPU; otherwise it is skipped. It loads
    a tiny model with `attn_implementation="flash_attention_2"` while capturing
    the `transformers.modeling_utils` logger at WARNING level, and asserts that
    the fallback message about `kernels-community/flash-attn` was emitted.
    """
    if not is_kernels_available():
        self.skipTest(reason="Please install `kernels` package to run `test_kernels_fallback`")

    if is_flash_attn_2_available():
        self.skipTest(reason="Please uninstall flash-attn package to run test_kernels_fallback")

    if is_torch_npu_available():
        self.skipTest(
            reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
        )

    logger = logging.get_logger("transformers.modeling_utils")
    expected_warning = (
        "You do not have `flash_attn` installed, using `kernels-community/flash-attn` "
        "from the `kernels` library instead!"
    )

    # Capture everything the modeling logger emits at WARNING level while loading.
    with LoggingLevel(logging.WARNING), CaptureLogger(logger) as cl:
        _ = AutoModel.from_pretrained(
            "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
        )

    # assertIn reports both the expected message and the captured output on
    # failure, instead of the opaque "False is not true" from assertTrue.
    self.assertIn(expected_warning, cl.out)
2908+
def test_not_available_kernels(self):
    """Check that requesting a `kernels` attention implementation raises without `kernels`.

    Skipped when the `kernels` package is available. Otherwise, loading a model
    with `attn_implementation="kernels-community/flash-attn"` must raise an
    `ImportError` whose message says `kernels` is missing or incompatible.
    """
    if is_kernels_available():
        self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`")

    with self.assertRaises(ImportError) as cm:
        _ = AutoModel.from_pretrained(
            "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn"
        )

    # assertIn shows the actual exception text when the expected fragment is absent.
    self.assertIn("`kernels` is either not installed or uses an incompatible version.", str(cm.exception))
2919+
28782920
28792921@require_torch
28802922class TestTensorSharing (TestCasePlus ):
0 commit comments