Commit c0c4aae

[CI] Improve test_aclgraph.py and make sure replay is called as expected during inference
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 89a388b commit c0c4aae
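
The counting hook in this change works by swapping the method under test for a wrapper that bumps a shared counter and then delegates to the original. A minimal, self-contained sketch of that pattern follows; the Graph class and run_inference helper are hypothetical stand-ins for torch.npu.NPUGraph and the vLLM generate loop, so the snippet runs without an NPU.

import multiprocessing
from unittest.mock import patch


class Graph:
    """Stand-in for torch.npu.NPUGraph (hypothetical, illustration only)."""

    def replay(self):
        return "replayed"


original_replay = Graph.replay
replay_counter = multiprocessing.Value("i", 0)  # shared counter, lock-protected


def replay_wrapper(self):
    # Count the call, then defer to the untouched original method.
    with replay_counter.get_lock():
        replay_counter.value += 1
    return original_replay(self)


def run_inference(graph, steps):
    # Hypothetical driver: each step triggers one replay.
    for _ in range(steps):
        graph.replay()


with patch.object(Graph, "replay", replay_wrapper):
    run_inference(Graph(), steps=3)

assert replay_counter.value == 3

The test below applies the same wrapper to torch.npu.NPUGraph.replay and compares the counter against the expected number of graph replays.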

File tree

1 file changed: +41 -7 lines changed


tests/singlecard/test_aclgraph.py

Lines changed: 41 additions & 7 deletions
@@ -20,7 +20,9 @@
 Run `pytest tests/compile/test_aclgraph.py`.
 """

+import multiprocessing
 import os
+from unittest.mock import patch

 import pytest
 import torch
@@ -30,32 +32,64 @@
 from tests.model_utils import check_outputs_equal

 MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
+PROMPTS = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+original_replay = torch.npu.NPUGraph.replay
+replay_counter = multiprocessing.Value("i", 0)
+
+
+def replay_wrapper(self):
+    with replay_counter.get_lock():
+        replay_counter.value += 1
+    return original_replay(self)


 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("prompts", [PROMPTS, PROMPTS[:3]])
 def test_models(
     model: str,
     max_tokens: int,
+    prompts: list[str],
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     with monkeypatch.context() as m:
-        prompts = [
-            "Hello, my name is", "The president of the United States is",
-            "The capital of France is", "The future of AI is"
-        ]
-
         # aclgraph only support on v1
         m.setenv("VLLM_USE_V1", "1")

         sampling_params = SamplingParams(max_tokens=max_tokens,
                                          temperature=0.0)
+
         # TODO: change to use vllmrunner when the registry of custom op is solved
         # while running pytest
-        vllm_model = LLM(model)
-        vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
+        with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper):
+            vllm_model = LLM(model)
+            vllm_aclgraph_outputs = vllm_model.generate(
+                prompts, sampling_params)
+
+        num_hidden_layers = vllm_model.llm_engine.model_config.hf_config.num_hidden_layers
+
+        # Calculate expected replay call count
+        # Number of ACL graphs = hidden layers + 1 (only for piecewise scenario)
+        num_acl_graphs = num_hidden_layers + 1
+        # Number of inference steps (first step only includes one prompt, hence +1)
+        num_inference_steps = max_tokens + 1
+        expected_replay_calls = num_acl_graphs * num_inference_steps
+
+        # Verify replay call count
+        actual_replay_calls = replay_counter.value
+        assert actual_replay_calls == expected_replay_calls, (
+            f"NPUGraph.replay call count mismatch. "
+            f"Expected: {expected_replay_calls}, Actual: {actual_replay_calls}"
+        )
+
         del vllm_model
         torch.npu.empty_cache()

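With the default parametrization, the expected count from the new assertion can be worked out by hand. A small sketch, assuming Qwen2.5-0.5B-Instruct reports 24 hidden layers (an assumption; the test reads the real value from hf_config.num_hidden_layers at runtime):

num_hidden_layers = 24     # assumed for Qwen2.5-0.5B-Instruct; the test queries hf_config
max_tokens = 32            # from @pytest.mark.parametrize("max_tokens", [32])

num_acl_graphs = num_hidden_layers + 1    # one graph per layer plus one (piecewise scenario)
num_inference_steps = max_tokens + 1      # first step only covers the prompt, hence the +1

expected_replay_calls = num_acl_graphs * num_inference_steps
print(expected_replay_calls)              # 25 * 33 = 825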