|
20 | 20 | Run `pytest tests/compile/test_aclgraph.py`.
21 | 21 | """
22 | 22 |
|
23 | | -import multiprocessing |
24 | 23 | import os |
25 | | -from unittest.mock import patch |
26 | 24 |
|
27 | 25 | import pytest |
28 | 26 | import torch |
|
32 | 30 | from tests.model_utils import check_outputs_equal |
33 | 31 |
|
34 | 32 | MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] |
35 | | -PROMPTS = [ |
36 | | - "Hello, my name is", |
37 | | - "The president of the United States is", |
38 | | - "The capital of France is", |
39 | | - "The future of AI is", |
40 | | -] |
41 | | - |
42 | | -original_replay = torch.npu.NPUGraph.replay |
43 | | -replay_counter = multiprocessing.Value("i", 0) |
44 | | - |
45 | | - |
46 | | -def replay_wrapper(self): |
47 | | - with replay_counter.get_lock(): |
48 | | - replay_counter.value += 1 |
49 | | - return original_replay(self) |
50 | 33 |
|
51 | 34 |
|
52 | 35 | @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", |
53 | 36 | reason="aclgraph is only supported on v1")
54 | 37 | @pytest.mark.parametrize("model", MODELS) |
55 | 38 | @pytest.mark.parametrize("max_tokens", [32]) |
56 | | -@pytest.mark.parametrize("prompts", [PROMPTS, PROMPTS[:3]]) |
57 | 39 | def test_models( |
58 | 40 | model: str, |
59 | 41 | max_tokens: int, |
60 | | - prompts: list[str], |
61 | 42 | monkeypatch: pytest.MonkeyPatch, |
62 | 43 | ) -> None: |
63 | 44 | with monkeypatch.context() as m: |
| 45 | + prompts = [ |
| 46 | + "Hello, my name is", "The president of the United States is", |
| 47 | + "The capital of France is", "The future of AI is" |
| 48 | + ] |
| 49 | + |
64 | 50 | # aclgraph is only supported on v1
65 | 51 | m.setenv("VLLM_USE_V1", "1") |
66 | 52 |
|
67 | 53 | sampling_params = SamplingParams(max_tokens=max_tokens, |
68 | 54 | temperature=0.0) |
69 | | - |
70 | 55 | # TODO: change to use vllmrunner once custom op registration is resolved
71 | 56 | # when running pytest
72 | | - with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper): |
73 | | - vllm_model = LLM(model) |
74 | | - vllm_aclgraph_outputs = vllm_model.generate( |
75 | | - prompts, sampling_params) |
76 | | - |
77 | | - num_hidden_layers = vllm_model.llm_engine.model_config.hf_config.num_hidden_layers |
78 | | - |
79 | | - # Calculate expected replay call count |
80 | | - # Number of ACL graphs = hidden layers + 1 (only for piecewise scenario) |
81 | | - num_acl_graphs = num_hidden_layers + 1 |
82 | | - # Number of inference steps (first step only includes one prompt, hence +1) |
83 | | - num_inference_steps = max_tokens + 1 |
84 | | - expected_replay_calls = num_acl_graphs * num_inference_steps |
85 | | - |
86 | | - # Verify replay call count |
87 | | - actual_replay_calls = replay_counter.value |
88 | | - assert actual_replay_calls == expected_replay_calls, ( |
89 | | - f"NPUGraph.replay call count mismatch. " |
90 | | - f"Expected: {expected_replay_calls}, Actual: {actual_replay_calls}" |
91 | | - ) |
92 | | - |
| 57 | + vllm_model = LLM(model) |
| 58 | + vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params) |
93 | 59 | del vllm_model |
94 | 60 | torch.npu.empty_cache() |
95 | 61 |