|
24 | 24 | from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available |
25 | 25 | from transformers.testing_utils import backend_device_count, torch_device |
26 | 26 | from transformers.utils import ( |
27 | | - is_flash_attn_2_available, |
28 | 27 | is_kernels_available, |
29 | 28 | is_peft_available, |
30 | 29 | is_rich_available, |
|
45 | 44 |
|
46 | 45 | require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes") |
47 | 46 | require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml") |
| 47 | +require_kernels = pytest.mark.skipif(not is_kernels_available(), reason="test requires kernels") |
48 | 48 | require_liger_kernel = pytest.mark.skipif(not is_liger_kernel_available(), reason="test requires liger-kernel") |
49 | 49 | require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") |
50 | 50 | require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify") |
@@ -85,21 +85,16 @@ def is_bitsandbytes_multi_backend_available() -> bool: |
85 | 85 | ) |
86 | 86 |
|
87 | 87 |
|
88 | | -def is_flash_attn_available(): |
89 | | - flash_attn_available = is_flash_attn_2_available() |
90 | | - kernels_available = is_kernels_available() |
91 | | - try: |
92 | | - from kernels import get_kernel |
93 | | - |
94 | | - get_kernel("kernels-community/flash-attn") |
95 | | - except Exception: |
96 | | - kernels_available = False |
def is_ampere_or_newer(device_index=0):
    """Return True if the CUDA device at *device_index* is Ampere-generation or newer.

    Ampere starts at compute capability 8.0 (e.g., A100 = 8.0, RTX 30xx = 8.6).
    Returns False when no CUDA device is available at all.
    """
    if not torch.cuda.is_available():
        return False

    capability = torch.cuda.get_device_capability(device_index)
    return capability >= (8, 0)
99 | 95 |
|
100 | 96 |
|
101 | | -# Function ported from transformers.testing_utils |
102 | | -require_flash_attn = pytest.mark.skipif(not is_flash_attn_available(), reason="test requires Flash Attention") |
| 97 | +require_ampere_or_newer = pytest.mark.skipif(not is_ampere_or_newer(), reason="test requires Ampere or newer GPU") |
103 | 98 |
|
104 | 99 |
|
105 | 100 | class RandomBinaryJudge(BaseBinaryJudge): |
|
0 commit comments