 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import tempfile
+from pathlib import Path
 from typing import Any

 import pytest
@@ -21,27 +22,21 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
2122 ("facebook/opt-125m" , {}),
2223 (
2324 "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic" ,
24- {
25- "dtype" : torch .float16 ,
26- },
25+ {"dtype" : torch .float16 },
2726 ),
2827 ("meta-llama/Llama-3.2-1B-Instruct" , {}),
2928 ]
3029
3130 if all :
32- if not current_platform .has_device_capability ((10 , 0 )):
33- # int8 removed on Blackwell
34- TEST_MODELS .extend (
35- [
36- ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8" , {}),
37- (
38- "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change" ,
39- {
40- "dtype" : torch .float16 ,
41- },
42- ),
43- ]
44- )
31+ TEST_MODELS .extend (
32+ [
33+ ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8" , {}),
34+ (
35+ "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change" ,
36+ {"dtype" : torch .float16 },
37+ ),
38+ ]
39+ )
4540
4641 # TODO: figure out why this fails.
4742 if False and is_quant_method_supported ("gguf" ): # noqa: SIM223
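With this hunk, models_list() always returns the int8 (w8a8/w8w8) entries; the Blackwell capability check now happens inside the tests as a runtime skip (see the next hunks). A rough usage sketch of the helper, assuming keywords selects entries by model-name substring:

# Sketch only; names are taken from this file, the exact filtering
# semantics of `keywords` are assumed rather than shown in this hunk.
base = models_list(all=False)    # the three base (model, kwargs) pairs
full = models_list(all=True)     # now always includes the w8a8/w8w8 int8 models
quant = models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
for model, model_kwargs in quant:
    print(model, model_kwargs)   # each entry unpacks to (name, kwargs)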
@@ -95,6 +90,14 @@ def test_full_graph(
     model_kwargs: dict[str, Any],
     compilation_mode: int,
 ):
+    if (
+        ("w8a8" in model or "w8w8" in model)
+        and current_platform.has_device_capability((10, 0))
+    ):
+        # int8 removed on Blackwell: skip these checkpoints here instead
+        # of filtering them out of models_list()
+        pytest.skip("int8 support removed on Blackwell")
+
     with monkeypatch.context():
         print(f"MODEL={model}")

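The same guard appears again in test_custom_compile_config below. If more tests need it, one possible consolidation is a small helper along these lines (skip_int8_on_blackwell is a hypothetical name, not part of this change; current_platform and pytest are the imports this file already uses):

def skip_int8_on_blackwell(model: str) -> None:
    # int8 kernels were removed on Blackwell (compute capability 10.0), so
    # w8a8/w8w8 checkpoints are skipped at runtime instead of being dropped
    # from models_list().
    is_int8 = "w8a8" in model or "w8w8" in model
    if is_int8 and current_platform.has_device_capability((10, 0)):
        pytest.skip("int8 support removed on Blackwell")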
@@ -103,14 +106,14 @@ def test_full_graph(

 # TODO(luka) add other supported compilation config scenarios here
 @pytest.mark.parametrize(
-    "compilation_config, model_info",
+    "compilation_config, model, model_kwargs",
     [
         # additional compile sizes, only some of the models
         (
             CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
-            model,
+            *model_info,
         )
-        for model in models_list(all=False)
+        for model_info in models_list(all=False)
     ]
     + [
         # RMSNorm + quant fusion, only 8-bit quant models
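The parametrize string now declares three parameters instead of two; each case splats the (model, kwargs) pair returned by models_list() into the surrounding tuple, so pytest receives model and model_kwargs as separate arguments. A minimal illustration of the equivalence:

cfg = CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2])
model_info = ("facebook/opt-125m", {})
old_case = (cfg, model_info)    # two parameters: compilation_config, model_info
new_case = (cfg, *model_info)   # three parameters: compilation_config, model, model_kwargs
assert new_case == (cfg, "facebook/opt-125m", {})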
@@ -120,18 +123,19 @@ def test_full_graph(
                 custom_ops=["+rms_norm"],
                 pass_config=PassConfig(enable_fusion=True, enable_noop=True),
             ),
-            model,
+            *model_info,
         )
-        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
+        for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
     ]
     + [
         # Test depyf integration works
         (
             CompilationConfig(
                 mode=CompilationMode.VLLM_COMPILE,
-                debug_dump_path=tempfile.gettempdir(),
+                debug_dump_path=Path(tempfile.gettempdir()),
             ),
-            ("facebook/opt-125m", {}),
+            "facebook/opt-125m",
+            {},
         ),
     ]
     + [
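Two related adjustments in the hunk above: the depyf case is flattened into separate model and model_kwargs entries to match the new signature, and debug_dump_path is wrapped in pathlib.Path (hence the new from pathlib import Path in the first hunk), presumably because the config field expects a Path rather than a str. Roughly:

dump_dir = Path(tempfile.gettempdir())   # e.g. PosixPath('/tmp')
depyf_case = (
    CompilationConfig(mode=CompilationMode.VLLM_COMPILE, debug_dump_path=dump_dir),
    "facebook/opt-125m",
    {},
)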
@@ -145,24 +149,32 @@ def test_full_graph(
                 cudagraph_mode=CUDAGraphMode.PIECEWISE,
                 compile_sizes=[1, 2],
             ),
-            model,
+            *model_info,
         )
-        for model in models_list(all=False)
+        for model_info in models_list(all=False)
         if is_torch_equal_or_newer("2.9.0.dev")
     ],
 )
 # only test some of the models
 @create_new_process_for_each_test()
 def test_custom_compile_config(
     compilation_config: CompilationConfig,
-    model_info: tuple[str, dict[str, Any]],
+    model: str,
+    model_kwargs: dict[str, Any],
 ):
+    if (
+        ("w8a8" in model or "w8w8" in model)
+        and current_platform.has_device_capability((10, 0))
+    ):
+        # int8 removed on Blackwell: skip these checkpoints here instead
+        # of filtering them out of models_list()
+        pytest.skip("int8 support removed on Blackwell")
+
     if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
         "2.9.0.dev"
     ):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

-    model, model_kwargs = model_info
     print(f"MODEL={model}")
     run_model(compilation_config, model, **model_kwargs)

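With the signature change, pytest injects model and model_kwargs directly and the manual model, model_kwargs = model_info unpacking is gone. One generated invocation would look roughly like this, reusing the depyf entry above:

test_custom_compile_config(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        debug_dump_path=Path(tempfile.gettempdir()),
    ),
    model="facebook/opt-125m",
    model_kwargs={},
)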