File tree Expand file tree Collapse file tree 4 files changed +20
-3
lines changed
tests/models/quantization Expand file tree Collapse file tree 4 files changed +20
-3
lines changed Original file line number Diff line number Diff line change 22import pytest
33
44from tests .quantization .utils import is_quant_method_supported
5+ from vllm .platforms import current_platform
56
67# These ground truth generations were generated using `transformers==4.38.1
78# aqlm==1.1.0 torch==2.2.0`
3435]
3536
3637
37- @pytest .mark .skipif (not is_quant_method_supported ("aqlm" ),
38+ @pytest .mark .skipif (not is_quant_method_supported ("aqlm" )
39+ or current_platform .is_rocm ()
40+ or not current_platform .is_cuda (),
3841 reason = "AQLM is not supported on this GPU type." )
3942@pytest .mark .parametrize ("model" , ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf" ])
4043@pytest .mark .parametrize ("dtype" , ["half" ])
Original file line number Diff line number Diff line change @@ -55,6 +55,14 @@ def test_models(
5555 Only checks log probs match to cover the discrepancy in
5656 numerical sensitive kernels.
5757 """
58+
59+ if backend == "FLASHINFER" and current_platform .is_rocm ():
60+ pytest .skip ("Flashinfer does not support ROCm/HIP." )
61+
62+ if kv_cache_dtype == "fp8_e5m2" and current_platform .is_rocm ():
63+ pytest .skip (
64+ f"{ kv_cache_dtype } is currently not supported on ROCm/HIP." )
65+
5866 with monkeypatch .context () as m :
5967 m .setenv ("TOKENIZERS_PARALLELISM" , 'true' )
6068 m .setenv (STR_BACKEND_ENV_VAR , backend )
Original file line number Diff line number Diff line change 1414
1515from tests .quantization .utils import is_quant_method_supported
1616from vllm .model_executor .layers .rotary_embedding import _ROPE_DICT
17+ from vllm .platforms import current_platform
1718
1819from ..utils import check_logprobs_close
1920
3435
3536
3637@pytest .mark .flaky (reruns = 3 )
37- @pytest .mark .skipif (not is_quant_method_supported ("gptq_marlin" ),
38+ @pytest .mark .skipif (not is_quant_method_supported ("gptq_marlin" )
39+ or current_platform .is_rocm ()
40+ or not current_platform .is_cuda (),
3841 reason = "gptq_marlin is not supported on this GPU type." )
3942@pytest .mark .parametrize ("model" , MODELS )
4043@pytest .mark .parametrize ("dtype" , ["half" , "bfloat16" ])
Original file line number Diff line number Diff line change 1010import pytest
1111
1212from tests .quantization .utils import is_quant_method_supported
13+ from vllm .platforms import current_platform
1314
1415from ..utils import check_logprobs_close
1516
@@ -38,7 +39,9 @@ class ModelPair:
3839
3940
4041@pytest .mark .flaky (reruns = 2 )
41- @pytest .mark .skipif (not is_quant_method_supported ("gptq_marlin_24" ),
42+ @pytest .mark .skipif (not is_quant_method_supported ("gptq_marlin_24" )
43+ or current_platform .is_rocm ()
44+ or not current_platform .is_cuda (),
4245 reason = "Marlin24 is not supported on this GPU type." )
4346@pytest .mark .parametrize ("model_pair" , model_pairs )
4447@pytest .mark .parametrize ("dtype" , ["half" ])
You can’t perform that action at this time.
0 commit comments