|
20 | 20 | Run `pytest tests/compile/test_aclgraph.py`.
21 | 21 | """ |
22 | 22 |
|
| 23 | +import multiprocessing |
23 | 24 | import os |
| 25 | +from unittest.mock import patch |
24 | 26 |
|
25 | 27 | import pytest |
26 | 28 | import torch |
|
30 | 32 | from tests.model_utils import check_outputs_equal |
31 | 33 |
|
32 | 34 | MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] |
| 35 | +PROMPTS = [ |
| 36 | + "Hello, my name is", |
| 37 | + "The president of the United States is", |
| 38 | + "The capital of France is", |
| 39 | + "The future of AI is", |
| 40 | +] |
| 41 | + |
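|    | +# Keep the original NPUGraph.replay and a process-shared counter so the test
|    | +# can see replay counts even if replays happen in a forked worker process.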
| 42 | +original_replay = torch.npu.NPUGraph.replay |
| 43 | +replay_counter = multiprocessing.Value("i", 0) |
| 44 | + |
| 45 | + |
| 46 | +def replay_wrapper(self): |
| 47 | + with replay_counter.get_lock(): |
| 48 | + replay_counter.value += 1 |
| 49 | + return original_replay(self) |
33 | 50 |
|
34 | 51 |
|
35 | 52 | @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", |
36 | 53 | reason="aclgraph only support on v1") |
37 | 54 | @pytest.mark.parametrize("model", MODELS) |
38 | 55 | @pytest.mark.parametrize("max_tokens", [32]) |
| 56 | +@pytest.mark.parametrize("prompts", [PROMPTS, PROMPTS[:3]]) |
39 | 57 | def test_models( |
40 | 58 | model: str, |
41 | 59 | max_tokens: int, |
| 60 | + prompts: list[str], |
42 | 61 | monkeypatch: pytest.MonkeyPatch, |
43 | 62 | ) -> None: |
44 | 63 | with monkeypatch.context() as m: |
45 | | - prompts = [ |
46 | | - "Hello, my name is", "The president of the United States is", |
47 | | - "The capital of France is", "The future of AI is" |
48 | | - ] |
49 | | - |
50 | 64 | # aclgraph only support on v1 |
51 | 65 | m.setenv("VLLM_USE_V1", "1") |
52 | 66 |
|
53 | 67 | sampling_params = SamplingParams(max_tokens=max_tokens, |
54 | 68 | temperature=0.0) |
| 69 | +        replay_counter.value = 0  # start this parametrized run from zero
55 | 70 | # TODO: change to use vllmrunner when the registry of custom op is solved |
56 | 71 | # while running pytest |
57 | | - vllm_model = LLM(model) |
58 | | - vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params) |
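|    | +        # Patch replay before the engine is built so the counting wrapper is in
|    | +        # place when the ACL graphs are captured and later replayed.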
| 72 | + with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper): |
| 73 | + vllm_model = LLM(model) |
| 74 | + vllm_aclgraph_outputs = vllm_model.generate( |
| 75 | + prompts, sampling_params) |
| 76 | + |
| 77 | + num_hidden_layers = vllm_model.llm_engine.model_config.hf_config.num_hidden_layers |
| 78 | + |
| 79 | + # Calculate expected replay call count |
| 80 | +            # Number of ACL graphs = num_hidden_layers + 1 (piecewise scenario only)
| 81 | +            num_acl_graphs = num_hidden_layers + 1
| 82 | +            # Number of inference steps (the first step only includes one prompt, hence the +1)
| 83 | + num_inference_steps = max_tokens + 1 |
| 84 | + expected_replay_calls = num_acl_graphs * num_inference_steps |
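|    | +            # e.g. a model with 24 hidden layers and max_tokens=32 expects
|    | +            # (24 + 1) * (32 + 1) = 825 replay calls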
| 85 | + |
| 86 | + # Verify replay call count |
| 87 | + actual_replay_calls = replay_counter.value |
| 88 | + assert actual_replay_calls == expected_replay_calls, ( |
| 89 | + f"NPUGraph.replay call count mismatch. " |
| 90 | + f"Expected: {expected_replay_calls}, Actual: {actual_replay_calls}" |
| 91 | + ) |
| 92 | + |
59 | 93 | del vllm_model |
60 | 94 | torch.npu.empty_cache() |
61 | 95 |
|
|