
Commit b5f89e5

Cleanup test_full_graph.py
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent 97b3ff2 commit b5f89e5

File tree

1 file changed: +39 -27 lines


tests/compile/test_full_graph.py

Lines changed: 39 additions & 27 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import tempfile
+from pathlib import Path
 from typing import Any
 
 import pytest
@@ -21,27 +22,21 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
         ("facebook/opt-125m", {}),
         (
             "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
-            {
-                "dtype": torch.float16,
-            },
+            {"dtype": torch.float16},
         ),
         ("meta-llama/Llama-3.2-1B-Instruct", {}),
     ]
 
     if all:
-        if not current_platform.has_device_capability((10, 0)):
-            # int8 removed on Blackwell
-            TEST_MODELS.extend(
-                [
-                    ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
-                    (
-                        "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
-                        {
-                            "dtype": torch.float16,
-                        },
-                    ),
-                ]
-            )
+        TEST_MODELS.extend(
+            [
+                ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
+                (
+                    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
+                    {"dtype": torch.float16},
+                ),
+            ]
+        )
 
         # TODO: figure out why this fails.
         if False and is_quant_method_supported("gguf"):  # noqa: SIM223
@@ -95,6 +90,14 @@ def test_full_graph(
     model_kwargs: dict[str, Any],
     compilation_mode: int,
 ):
+    if (
+        ("w8a8" in model
+         or "w8w8" in model)
+        and current_platform.has_device_capability((10, 0))
+    ):
+        # int8 removed on Blackwell
+        pytest.skip("int8 support removed on Blackwell")
+
     with monkeypatch.context():
         print(f"MODEL={model}")
 
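A note on the grouping in the skip condition above: in Python, "and" binds tighter than "or", so the ungrouped form "w8a8" in model or "w8w8" in model and <capability> parses as "w8a8" in model or ("w8w8" in model and <capability>), which would skip every w8a8 model on every platform. The membership tests are therefore parenthesized before the capability check is applied. A standalone illustration (plain Python, no vLLM imports; on_blackwell is a stand-in for the capability call):

    # "and" binds tighter than "or": without explicit grouping, the
    # Blackwell check would only guard the "w8w8" branch.
    model = "neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8"
    on_blackwell = False  # stand-in for current_platform.has_device_capability((10, 0))

    ungrouped = "w8a8" in model or "w8w8" in model and on_blackwell
    grouped = ("w8a8" in model or "w8w8" in model) and on_blackwell

    assert ungrouped is True   # would skip even off Blackwell
    assert grouped is False    # skips only on Blackwell, as the comment intends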
@@ -103,14 +106,14 @@ def test_full_graph(
 
 # TODO(luka) add other supported compilation config scenarios here
 @pytest.mark.parametrize(
-    "compilation_config, model_info",
+    "compilation_config, model, model_kwargs",
     [
         # additional compile sizes, only some of the models
         (
             CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
-            model,
+            *model_info,
         )
-        for model in models_list(all=False)
+        for model_info in models_list(all=False)
     ]
     + [
         # RMSNorm + quant fusion, only 8-bit quant models
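For context on the *model_info splat introduced above: models_list() returns (model, kwargs) tuples, and unpacking each tuple inline flattens every parametrize case into the three separate argnames. A minimal standalone sketch of the pattern (the _MODELS list, the "cfg" placeholder, and the test name are illustrative, not from the file):

    import pytest

    # Stand-in for models_list(): (model_id, model_kwargs) tuples.
    _MODELS = [("facebook/opt-125m", {}), ("meta-llama/Llama-3.2-1B-Instruct", {})]

    @pytest.mark.parametrize(
        "config, model, model_kwargs",
        # ("cfg", *("m", {})) flattens to ("cfg", "m", {}), matching the argnames.
        [("cfg", *model_info) for model_info in _MODELS],
    )
    def test_sketch(config: str, model: str, model_kwargs: dict):
        assert isinstance(model_kwargs, dict)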
@@ -120,18 +123,19 @@ def test_full_graph(
                 custom_ops=["+rms_norm"],
                 pass_config=PassConfig(enable_fusion=True, enable_noop=True),
             ),
-            model,
+            *model_info,
         )
-        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
+        for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
     ]
     + [
         # Test depyf integration works
         (
             CompilationConfig(
                 mode=CompilationMode.VLLM_COMPILE,
-                debug_dump_path=tempfile.gettempdir(),
+                debug_dump_path=Path(tempfile.gettempdir()),
             ),
-            ("facebook/opt-125m", {}),
+            "facebook/opt-125m",
+            {},
         ),
     ]
     + [
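The debug_dump_path change above wraps the str returned by tempfile.gettempdir() in a pathlib.Path, presumably to match the type CompilationConfig expects for that field. A standalone illustration of the difference:

    import tempfile
    from pathlib import Path

    dump_dir = Path(tempfile.gettempdir())  # e.g. Path("/tmp") on Linux
    assert isinstance(dump_dir, Path)
    # A Path composes with "/", which a plain str does not:
    print(dump_dir / "depyf")  # "depyf" is an illustrative subdirectory name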
@@ -145,24 +149,32 @@ def test_full_graph(
                 cudagraph_mode=CUDAGraphMode.PIECEWISE,
                 compile_sizes=[1, 2],
             ),
-            model,
+            *model_info,
         )
-        for model in models_list(all=False)
+        for model_info in models_list(all=False)
         if is_torch_equal_or_newer("2.9.0.dev")
     ],
 )
 # only test some of the models
 @create_new_process_for_each_test()
 def test_custom_compile_config(
     compilation_config: CompilationConfig,
-    model_info: tuple[str, dict[str, Any]],
+    model: str,
+    model_kwargs: dict[str, Any],
 ):
+    if (
+        ("w8a8" in model
+         or "w8w8" in model)
+        and current_platform.has_device_capability((10, 0))
+    ):
+        # int8 removed on Blackwell
+        pytest.skip("int8 support removed on Blackwell")
+
     if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
         "2.9.0.dev"
     ):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
 
-    model, model_kwargs = model_info
     print(f"MODEL={model}")
     run_model(compilation_config, model, **model_kwargs)
 
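Net effect of the two added skip blocks: the int8 models now stay in the parametrized set on every platform and are skipped at runtime on Blackwell, rather than being filtered out of TEST_MODELS at collection time as before. A runnable sketch of the runtime-skip pattern (the _on_blackwell helper and its hardcoded return value are stand-ins for the platform check):

    import pytest

    def _on_blackwell() -> bool:
        # Stand-in for current_platform.has_device_capability((10, 0)),
        # hardcoded so the sketch runs without a GPU.
        return False

    @pytest.mark.parametrize(
        "model", ["neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8"]
    )
    def test_sketch(model: str):
        if ("w8a8" in model or "w8w8" in model) and _on_blackwell():
            pytest.skip("int8 support removed on Blackwell")
        # ...a real test body would compile and run the model here...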
