1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -1221,6 +1221,7 @@ steps:
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
- pytest -v -s tests/distributed/test_context_parallel.py
- pytest -v -s tests/distributed/test_sequence_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py

183 changes: 151 additions & 32 deletions tests/compile/test_fusions_e2e.py
@@ -21,12 +21,18 @@
from ..utils import flat_product, multi_gpu_test


class Matches(NamedTuple):
attention_fusion: int = 0
allreduce_fusion: int = 0
sequence_parallel: int = 0
async_tp: int = 0


class ModelBackendTestCase(NamedTuple):
model_name: str
model_kwargs: dict[str, Any]
backend: _Backend
attention_fusions: int
allreduce_fusions: int | None = None
matches: Matches


MODELS_FP8: list[ModelBackendTestCase] = []
@@ -40,15 +46,23 @@ class ModelBackendTestCase(NamedTuple):
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
allreduce_fusions=65,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
matches=Matches(
attention_fusion=48,
allreduce_fusion=96,
sequence_parallel=96,
async_tp=190,
),
),
]

@@ -57,8 +71,12 @@ class ModelBackendTestCase(NamedTuple):
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
matches=Matches(
attention_fusion=48,
allreduce_fusion=96,
sequence_parallel=96,
async_tp=190,
),
),
]

@@ -68,8 +86,12 @@ class ModelBackendTestCase(NamedTuple):
model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=65,
matches=Matches(
attention_fusion=0,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
]

@@ -79,19 +101,19 @@ class ModelBackendTestCase(NamedTuple):
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_ATTN,
attention_fusions=32,
matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32,
matches=Matches(attention_fusion=32),
),
]

@@ -100,8 +122,7 @@ class ModelBackendTestCase(NamedTuple):


@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
"model_name, model_kwargs, backend, matches, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
# quant_fp4 only has the custom impl
@@ -112,8 +133,7 @@ def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
@@ -163,12 +183,12 @@ def test_attn_quant(
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs)

matches = re.findall(
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 1, log_holder.text
assert int(matches[0]) == attention_fusions
assert len(log_matches) == 1, log_holder.text
assert int(log_matches[0]) == matches.attention_fusion
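
# Illustrative sketch (hypothetical helper, not part of the test above): the
# check parses the captured worker log with re.findall and expects exactly one
# "Fused quant onto N attention nodes" line, where N equals
# Matches.attention_fusion (e.g. 32 for the Llama-3.1-8B FP8 case).
def _example_parse_attn_fusion_counts(log_text: str) -> list[int]:
    return [
        int(n)
        for n in re.findall(
            r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_text
        )
    ]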


# TODO(luka) test both in nightly
@@ -182,8 +202,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
"model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
@@ -204,8 +223,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str,
model_kwargs: dict,
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
@@ -253,23 +271,124 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
matches = re.findall(
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert len(log_matches) == 2, log_holder.text

assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion

log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text

assert int(log_matches[0]) == matches.allreduce_fusion
assert int(log_matches[1]) == matches.allreduce_fusion


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
)
)
# Toggle RMSNorm for FP4 models and unquant models
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="sequence parallel only tested on CUDA",
)
def test_tp2_attn_quant_async_tp(
model_name: str,
model_kwargs: dict,
backend: _Backend,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if current_platform.is_device_capability((10, 0)):
# TODO: https://github.com/vllm-project/vllm/issues/27893
pytest.skip("Blackwell is not supported for AsyncTP pass")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")

custom_ops_list = custom_ops.split(",") if custom_ops else []

if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []

# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

compilation_config = CompilationConfig(
# Testing properties
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
custom_ops=custom_ops_list,
splitting_ops=splitting_ops,
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(
enable_attn_fusion=True,
enable_noop=True,
enable_sequence_parallelism=True,
enable_async_tp=True,
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)

with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text

assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion

log_matches = re.findall(
r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text

assert int(matches[0]) == attention_fusions
assert int(matches[1]) == attention_fusions
assert int(log_matches[0]) == matches.sequence_parallel
assert int(log_matches[1]) == matches.sequence_parallel

matches = re.findall(
log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert len(log_matches) == 2, log_holder.text

assert int(matches[0]) == allreduce_fusions
assert int(matches[1]) == allreduce_fusions
assert int(log_matches[0]) == matches.async_tp
assert int(log_matches[1]) == matches.async_tp
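
# Illustrative sketch with assumed log text (hypothetical helper): under
# tensor_parallel_size=2 each pass logs once per rank, so every pattern is
# expected to match twice, and the expected counts come from the Matches
# entries (e.g. sequence_parallel=65, async_tp=128 for the Llama-3.1-8B FP8
# case).
def _example_check_tp2_pass_counts(log_text: str, expected: Matches) -> None:
    for pattern, want in (
        (
            r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
            expected.sequence_parallel,
        ),
        (r"collective_fusion.py:\d+] Replaced (\d+) patterns", expected.async_tp),
    ):
        counts = [int(n) for n in re.findall(pattern, log_text)]
        assert counts == [want, want], log_text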


def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):