3 | 3 |
4 | 4 | from __future__ import annotations |
5 | 5 |
6 | | -import itertools |
7 | | -import logging |
8 | 6 | import tempfile |
9 | | -from collections.abc import Iterable |
10 | | -from typing import Any, Optional, Union |
| 7 | +from typing import Any, Union |
11 | 8 |
12 | 9 | import pytest |
13 | | -import regex as re |
14 | 10 | import torch |
15 | 11 |
16 | 12 | from tests.quantization.utils import is_quant_method_supported |
17 | 13 | from vllm import LLM, SamplingParams |
18 | | -from vllm.attention.backends.registry import _Backend |
19 | 14 | from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig |
20 | 15 | from vllm.platforms import current_platform |
21 | 16 | from vllm.utils import is_torch_equal_or_newer |
22 | | -from vllm.utils.flashinfer import has_flashinfer |
23 | 17 |
24 | | -from ..utils import create_new_process_for_each_test, flat_product, multi_gpu_test |
| 18 | +from ..utils import create_new_process_for_each_test |
25 | 19 |
26 | 20 |
27 | 21 | def models_list(*, all: bool = True, keywords: list[str] | None = None): |
@@ -189,194 +183,6 @@ def test_fp8_kv_scale_compile(optimization_level: int): |
189 | 183 | run_model(optimization_level, model, **model_kwargs) |
190 | 184 |
191 | 185 |
192 | | -MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] |
193 | | -MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] |
194 | | -MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only |
195 | | - |
196 | | -if current_platform.is_cuda(): |
197 | | - MODELS_FP8 += [ |
198 | | - ( |
199 | | - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", |
200 | | - {"max_model_len": 1024}, |
201 | | - _Backend.TRITON_ATTN, |
202 | | - ) |
203 | | - ] |
204 | | - |
205 | | - if current_platform.is_device_capability((10, 0)): |
206 | | - MODELS_FP8 += [ |
207 | | - ( |
208 | | - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", |
209 | | - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, |
210 | | - _Backend.FLASHINFER, |
211 | | - ) |
212 | | - ] |
213 | | - |
214 | | - MODELS_FP4 += [ |
215 | | - ( |
216 | | - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", |
217 | | - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, |
218 | | - _Backend.FLASHINFER, |
219 | | - ) |
220 | | - ] |
221 | | - |
222 | | - MODELS += [ |
223 | | - ( |
224 | | - "meta-llama/Llama-3.1-8B-Instruct", |
225 | | - {"max_model_len": 1024}, |
226 | | - _Backend.FLASHINFER, |
227 | | - ) |
228 | | - ] |
229 | | - |
230 | | -elif current_platform.is_rocm(): |
231 | | - MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] |
232 | | - |
233 | | -INDUCTOR_GRAPH_PARTITION = ( |
234 | | - [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] |
235 | | -) |
236 | | - |
237 | | -# TODO(luka) test both in nightly |
238 | | -CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] |
239 | | - |
240 | | - |
241 | | -@pytest.mark.parametrize( |
242 | | - "model_name, model_kwargs, backend, custom_ops", |
243 | | - # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 |
244 | | - list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) |
245 | | - # quant_fp4 only has the custom impl |
246 | | - + list(flat_product(MODELS_FP4, [""])), |
247 | | -) |
248 | | -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) |
249 | | -def test_e2e_fusion_attn_quant( |
250 | | - model_name: str, |
251 | | - model_kwargs: dict[str, Any], |
252 | | - backend: _Backend, |
253 | | - custom_ops: str, |
254 | | - inductor_graph_partition: bool, |
255 | | - caplog_mp_spawn, |
256 | | - monkeypatch, |
257 | | -): |
258 | | - custom_ops_list = custom_ops.split(",") if custom_ops else [] |
259 | | - |
260 | | - if inductor_graph_partition: |
261 | | - mode = CUDAGraphMode.FULL_AND_PIECEWISE |
262 | | - splitting_ops: Optional[list[str]] = None |
263 | | - else: |
264 | | - mode = CUDAGraphMode.FULL_DECODE_ONLY |
265 | | - splitting_ops = [] |
266 | | - |
267 | | - # Disable, compile cache to make sure custom passes run. |
268 | | - # Otherwise, we can't verify fusion happened through the logs. |
269 | | - # Log capture also doesn't work with multiprocessing yet. |
270 | | - monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") |
271 | | - |
272 | | - # To capture subprocess logs, we need to know whether spawn or fork is used. |
273 | | - # Force spawn as it is more general. |
274 | | - monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") |
275 | | - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) |
276 | | - |
277 | | - compilation_config = CompilationConfig( |
278 | | - # Testing properties |
279 | | - custom_ops=custom_ops_list, |
280 | | - use_inductor_graph_partition=inductor_graph_partition, |
281 | | - cudagraph_mode=mode, |
282 | | - splitting_ops=splitting_ops, |
283 | | - # Common |
284 | | - level=CompilationLevel.PIECEWISE, |
285 | | - pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), |
286 | | - # Inductor caches custom passes by default as well via uuid |
287 | | - inductor_compile_config={"force_disable_caches": True}, |
288 | | - ) |
289 | | - |
290 | | - with caplog_mp_spawn(logging.DEBUG) as log_holder: |
291 | | - run_model(compilation_config, model_name, **model_kwargs) |
292 | | - |
293 | | - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text |
294 | | - |
295 | | - |
296 | | -# TODO(luka) test both in nightly |
297 | | -# TODO(luka) change to - |
298 | | -CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"] |
299 | | - |
300 | | - |
301 | | -def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: |
302 | | - for op_list in itertools.product(*custom_ops_lists): |
303 | | - yield ",".join(op_list) |
304 | | - |
305 | | - |
306 | | -@multi_gpu_test(num_gpus=2) |
307 | | -@pytest.mark.parametrize( |
308 | | - "model_name, model_kwargs, backend, custom_ops", |
309 | | - # Toggle RMSNorm and QuantFP8 for FP8 models |
310 | | - list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) |
311 | | - # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO |
312 | | - # Toggle RMSNorm for FP4 models and unquant models |
313 | | - + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), |
314 | | -) |
315 | | -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) |
316 | | -@pytest.mark.skipif( |
317 | | - not current_platform.is_cuda() |
318 | | - or not has_flashinfer() |
319 | | - or not current_platform.has_device_capability(90), |
320 | | - reason="allreduce+rmsnorm fusion requires flashinfer", |
321 | | -) |
322 | | -def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( |
323 | | - model_name, |
324 | | - model_kwargs, |
325 | | - backend, |
326 | | - custom_ops: str, |
327 | | - inductor_graph_partition: bool, |
328 | | - caplog_mp_spawn, |
329 | | - monkeypatch, |
330 | | -): |
331 | | - custom_ops_list = custom_ops.split(",") if custom_ops else [] |
332 | | - |
333 | | - if inductor_graph_partition: |
334 | | - mode = CUDAGraphMode.FULL_AND_PIECEWISE |
335 | | - splitting_ops: Optional[list[str]] = None |
336 | | - else: |
337 | | - mode = CUDAGraphMode.FULL_DECODE_ONLY |
338 | | - splitting_ops = [] |
339 | | - |
340 | | - # Disable, compile cache to make sure custom passes run. |
341 | | - # Otherwise, we can't verify fusion happened through the logs. |
342 | | - # Log capture also doesn't work with multiprocessing yet. |
343 | | - monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") |
344 | | - |
345 | | - # To capture subprocess logs, we need to know whether spawn or fork is used. |
346 | | - # Force spawn as it is more general. |
347 | | - monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") |
348 | | - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) |
349 | | - |
350 | | - compilation_config = CompilationConfig( |
351 | | - # Testing properties |
352 | | - use_inductor_graph_partition=inductor_graph_partition, |
353 | | - cudagraph_mode=mode, |
354 | | - custom_ops=custom_ops_list, |
355 | | - splitting_ops=splitting_ops, |
356 | | - # Common |
357 | | - level=CompilationLevel.PIECEWISE, |
358 | | - pass_config=PassConfig( |
359 | | - enable_attn_fusion=True, |
360 | | - enable_noop=True, |
361 | | - enable_fi_allreduce_fusion=True, |
362 | | - ), |
363 | | - # Inductor caches custom passes by default as well via uuid |
364 | | - inductor_compile_config={"force_disable_caches": True}, |
365 | | - ) |
366 | | - |
367 | | - with caplog_mp_spawn(logging.DEBUG) as log_holder: |
368 | | - run_model( |
369 | | - compilation_config, model_name, tensor_parallel_size=2, **model_kwargs |
370 | | - ) |
371 | | - |
372 | | - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text |
373 | | - |
374 | | - matches = re.findall( |
375 | | - r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text |
376 | | - ) |
377 | | - assert len(matches) == 2, log_holder.text |
378 | | - |
379 | | - |
380 | 186 | def run_model( |
381 | 187 | compile_config: Union[int, CompilationConfig], model: str, **model_kwargs |
382 | 188 | ): |