Commit cf2bccc

Revert "[CI] improve test_aclgraph.py and make sure replay is called as expected when inferencing"
This reverts commit 36519ff.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>

1 parent fd97314 · commit cf2bccc

File tree

1 file changed, +7 -41 lines

tests/singlecard/test_aclgraph.py

Lines changed: 7 additions & 41 deletions
@@ -20,9 +20,7 @@
 Run `pytest tests/compile/test_aclgraph.py`.
 """
 
-import multiprocessing
 import os
-from unittest.mock import patch
 
 import pytest
 import torch
@@ -32,64 +30,32 @@
 from tests.model_utils import check_outputs_equal
 
 MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
-PROMPTS = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
-]
-
-original_replay = torch.npu.NPUGraph.replay
-replay_counter = multiprocessing.Value("i", 0)
-
-
-def replay_wrapper(self):
-    with replay_counter.get_lock():
-        replay_counter.value += 1
-    return original_replay(self)
 
 
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
-@pytest.mark.parametrize("prompts", [PROMPTS, PROMPTS[:3]])
 def test_models(
     model: str,
     max_tokens: int,
-    prompts: list[str],
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     with monkeypatch.context() as m:
+        prompts = [
+            "Hello, my name is", "The president of the United States is",
+            "The capital of France is", "The future of AI is"
+        ]
+
         # aclgraph only support on v1
         m.setenv("VLLM_USE_V1", "1")
 
         sampling_params = SamplingParams(max_tokens=max_tokens,
                                          temperature=0.0)
-
         # TODO: change to use vllmrunner when the registry of custom op is solved
         # while running pytest
-        with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper):
-            vllm_model = LLM(model)
-            vllm_aclgraph_outputs = vllm_model.generate(
-                prompts, sampling_params)
-
-        num_hidden_layers = vllm_model.llm_engine.model_config.hf_config.num_hidden_layers
-
-        # Calculate expected replay call count
-        # Number of ACL graphs = hidden layers + 1 (only for piecewise scenario)
-        num_acl_graphs = num_hidden_layers + 1
-        # Number of inference steps (first step only includes one prompt, hence +1)
-        num_inference_steps = max_tokens + 1
-        expected_replay_calls = num_acl_graphs * num_inference_steps
-
-        # Verify replay call count
-        actual_replay_calls = replay_counter.value
-        assert actual_replay_calls == expected_replay_calls, (
-            f"NPUGraph.replay call count mismatch. "
-            f"Expected: {expected_replay_calls}, Actual: {actual_replay_calls}"
-        )
-
+        vllm_model = LLM(model)
+        vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
         del vllm_model
         torch.npu.empty_cache()
0 commit comments

Comments
 (0)
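For reference, the reverted test derived its expected replay count from the model's layer count and the sampling limit. Below is a minimal worked example of that arithmetic; the concrete value num_hidden_layers = 24 is an assumption based on Qwen2.5-0.5B-Instruct's published config, not something stated in this commit, while max_tokens = 32 comes from the test's parametrization.

# Worked example of the reverted test's expected-replay arithmetic.
# Assumed: num_hidden_layers = 24 (Qwen2.5-0.5B-Instruct); max_tokens = 32.
num_hidden_layers = 24
max_tokens = 32

# One ACL graph per hidden layer, plus one (piecewise scenario only).
num_acl_graphs = num_hidden_layers + 1        # 25

# max_tokens generation steps, plus one extra step (see the test's comment).
num_inference_steps = max_tokens + 1          # 33

expected_replay_calls = num_acl_graphs * num_inference_steps
print(expected_replay_calls)                  # 25 * 33 = 825

Any divergence of the replay counter from this product was what the reverted assertion was meant to flag, e.g. a graph being skipped or replayed more often than expected.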