Skip to content

Commit 2dbe8c0

Browse files
authored
[Perf] API-server scaleout with many-to-many server-engine comms (#17546)
1 parent 84ec470 commit 2dbe8c0

26 files changed

+1837
-445
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,9 +618,11 @@ steps:
618618
- vllm/worker/model_runner.py
619619
- entrypoints/llm/test_collective_rpc.py
620620
- tests/v1/test_async_llm_dp.py
621+
- tests/v1/entrypoints/openai/test_multi_api_servers.py
621622
- vllm/v1/engine/
622623
commands:
623624
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
625+
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
624626
- pytest -v -s entrypoints/llm/test_collective_rpc.py
625627
- pytest -v -s ./compile/test_basic_correctness.py
626628
- pytest -v -s ./compile/test_wrapper.py
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import multiprocessing
4+
import socket
5+
import threading
6+
import time
7+
from typing import Optional
8+
from unittest.mock import patch
9+
10+
import pytest
11+
12+
from vllm.v1.utils import (APIServerProcessManager,
13+
wait_for_completion_or_failure)
14+
15+
# Global variables to control worker behavior
16+
WORKER_RUNTIME_SECONDS = 0.5
17+
18+
19+
# Mock implementation of run_api_server_worker
20+
def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
21+
"""Mock run_api_server_worker that runs for a specific time."""
22+
print(f"Mock worker started with client_config: {client_config}")
23+
time.sleep(WORKER_RUNTIME_SECONDS)
24+
print("Mock worker completed successfully")
25+
26+
27+
@pytest.fixture
28+
def api_server_args():
29+
"""Fixture to provide arguments for APIServerProcessManager."""
30+
sock = socket.socket()
31+
return {
32+
"target_server_fn":
33+
mock_run_api_server_worker,
34+
"listen_address":
35+
"localhost:8000",
36+
"sock":
37+
sock,
38+
"args":
39+
"test_args", # Simple string to avoid pickling issues
40+
"num_servers":
41+
3,
42+
"input_addresses": [
43+
"tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002",
44+
"tcp://127.0.0.1:5003"
45+
],
46+
"output_addresses": [
47+
"tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002",
48+
"tcp://127.0.0.1:6003"
49+
],
50+
"stats_update_address":
51+
"tcp://127.0.0.1:7000",
52+
}
53+
54+
55+
@pytest.mark.parametrize("with_stats_update", [True, False])
56+
def test_api_server_process_manager_init(api_server_args, with_stats_update):
57+
"""Test initializing the APIServerProcessManager."""
58+
# Set the worker runtime to ensure tests complete in reasonable time
59+
global WORKER_RUNTIME_SECONDS
60+
WORKER_RUNTIME_SECONDS = 0.5
61+
62+
# Copy the args to avoid mutating the
63+
args = api_server_args.copy()
64+
65+
if not with_stats_update:
66+
args.pop("stats_update_address")
67+
manager = APIServerProcessManager(**args)
68+
69+
try:
70+
# Verify the manager was initialized correctly
71+
assert len(manager.processes) == 3
72+
73+
# Verify all processes are running
74+
for proc in manager.processes:
75+
assert proc.is_alive()
76+
77+
print("Waiting for processes to run...")
78+
time.sleep(WORKER_RUNTIME_SECONDS / 2)
79+
80+
# They should still be alive at this point
81+
for proc in manager.processes:
82+
assert proc.is_alive()
83+
84+
finally:
85+
# Always clean up the processes
86+
print("Cleaning up processes...")
87+
manager.close()
88+
89+
# Give processes time to terminate
90+
time.sleep(0.2)
91+
92+
# Verify all processes were terminated
93+
for proc in manager.processes:
94+
assert not proc.is_alive()
95+
96+
97+
@patch("vllm.entrypoints.cli.serve.run_api_server_worker",
98+
mock_run_api_server_worker)
99+
def test_wait_for_completion_or_failure(api_server_args):
100+
"""Test that wait_for_completion_or_failure works with failures."""
101+
global WORKER_RUNTIME_SECONDS
102+
WORKER_RUNTIME_SECONDS = 1.0
103+
104+
# Create the manager
105+
manager = APIServerProcessManager(**api_server_args)
106+
107+
try:
108+
assert len(manager.processes) == 3
109+
110+
# Create a result capture for the thread
111+
result: dict[str, Optional[Exception]] = {"exception": None}
112+
113+
def run_with_exception_capture():
114+
try:
115+
wait_for_completion_or_failure(api_server_manager=manager)
116+
except Exception as e:
117+
result["exception"] = e
118+
119+
# Start a thread to run wait_for_completion_or_failure
120+
wait_thread = threading.Thread(target=run_with_exception_capture,
121+
daemon=True)
122+
wait_thread.start()
123+
124+
# Let all processes run for a short time
125+
time.sleep(0.2)
126+
127+
# All processes should still be running
128+
assert all(proc.is_alive() for proc in manager.processes)
129+
130+
# Now simulate a process failure
131+
print("Simulating process failure...")
132+
manager.processes[0].terminate()
133+
134+
# Wait for the wait_for_completion_or_failure
135+
# to detect and handle the failure
136+
# This should trigger it to terminate all other processes
137+
wait_thread.join(timeout=1.0)
138+
139+
# The wait thread should have exited
140+
assert not wait_thread.is_alive()
141+
142+
# Verify that an exception was raised with appropriate error message
143+
assert result["exception"] is not None
144+
assert "died with exit code" in str(result["exception"])
145+
146+
# All processes should now be terminated
147+
for i, proc in enumerate(manager.processes):
148+
assert not proc.is_alive(), f"Process {i} should not be alive"
149+
150+
finally:
151+
manager.close()
152+
time.sleep(0.2)
153+
154+
155+
@pytest.mark.timeout(30)
156+
def test_normal_completion(api_server_args):
157+
"""Test that wait_for_completion_or_failure works in normal completion."""
158+
global WORKER_RUNTIME_SECONDS
159+
WORKER_RUNTIME_SECONDS = 0.1
160+
161+
# Create the manager
162+
manager = APIServerProcessManager(**api_server_args)
163+
164+
try:
165+
# Give processes time to terminate
166+
# wait for processes to complete
167+
remaining_processes = manager.processes.copy()
168+
while remaining_processes:
169+
for proc in remaining_processes:
170+
if not proc.is_alive():
171+
remaining_processes.remove(proc)
172+
time.sleep(0.1)
173+
174+
# Verify all processes have terminated
175+
for i, proc in enumerate(manager.processes):
176+
assert not proc.is_alive(
177+
), f"Process {i} still alive after terminate()"
178+
179+
# Now call wait_for_completion_or_failure
180+
# since all processes have already
181+
# terminated, it should return immediately
182+
# with no error
183+
wait_for_completion_or_failure(api_server_manager=manager)
184+
185+
finally:
186+
# Clean up just in case
187+
manager.close()
188+
time.sleep(0.2)
189+
190+
191+
@pytest.mark.timeout(30)
192+
def test_external_process_monitoring(api_server_args):
193+
"""Test that wait_for_completion_or_failure handles additional processes."""
194+
global WORKER_RUNTIME_SECONDS
195+
WORKER_RUNTIME_SECONDS = 100
196+
197+
# Create and start the external process
198+
# (simulates local_engine_manager or coordinator)
199+
spawn_context = multiprocessing.get_context("spawn")
200+
external_proc = spawn_context.Process(target=mock_run_api_server_worker,
201+
name="MockExternalProcess")
202+
external_proc.start()
203+
204+
# Create the class to simulate a coordinator
205+
class MockCoordinator:
206+
207+
def __init__(self, proc):
208+
self.proc = proc
209+
210+
def close(self):
211+
if self.proc.is_alive():
212+
self.proc.terminate()
213+
self.proc.join(timeout=0.5)
214+
215+
# Create a mock coordinator with the external process
216+
mock_coordinator = MockCoordinator(external_proc)
217+
218+
# Create the API server manager
219+
manager = APIServerProcessManager(**api_server_args)
220+
221+
try:
222+
# Verify manager initialization
223+
assert len(manager.processes) == 3
224+
225+
# Create a result capture for the thread
226+
result: dict[str, Optional[Exception]] = {"exception": None}
227+
228+
def run_with_exception_capture():
229+
try:
230+
wait_for_completion_or_failure(api_server_manager=manager,
231+
coordinator=mock_coordinator)
232+
except Exception as e:
233+
result["exception"] = e
234+
235+
# Start a thread to run wait_for_completion_or_failure
236+
wait_thread = threading.Thread(target=run_with_exception_capture,
237+
daemon=True)
238+
wait_thread.start()
239+
240+
# Terminate the external process to trigger a failure
241+
time.sleep(0.2)
242+
external_proc.terminate()
243+
244+
# Wait for the thread to detect the failure
245+
wait_thread.join(timeout=1.0)
246+
247+
# The wait thread should have completed
248+
assert not wait_thread.is_alive(
249+
), "wait_for_completion_or_failure thread still running"
250+
251+
# Verify that an exception was raised with appropriate error message
252+
assert result["exception"] is not None, "No exception was raised"
253+
error_message = str(result["exception"])
254+
assert "died with exit code" in error_message, \
255+
f"Unexpected error message: {error_message}"
256+
assert "MockExternalProcess" in error_message, \
257+
f"Error doesn't mention external process: {error_message}"
258+
259+
# Verify that all API server processes were terminated as a result
260+
for i, proc in enumerate(manager.processes):
261+
assert not proc.is_alive(
262+
), f"API server process {i} was not terminated"
263+
264+
finally:
265+
# Clean up
266+
manager.close()
267+
mock_coordinator.close()
268+
time.sleep(0.2)

tests/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from vllm.distributed import (ensure_model_parallel_initialized,
2929
init_distributed_environment)
3030
from vllm.engine.arg_utils import AsyncEngineArgs
31-
from vllm.entrypoints.openai.cli_args import make_arg_parser
31+
from vllm.entrypoints.cli.serve import ServeSubcommand
3232
from vllm.model_executor.model_loader import get_model_loader
3333
from vllm.platforms import current_platform
3434
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -99,7 +99,8 @@ def __init__(self,
9999

100100
parser = FlexibleArgumentParser(
101101
description="vLLM's remote OpenAI server.")
102-
parser = make_arg_parser(parser)
102+
subparsers = parser.add_subparsers(required=False, dest="subparser")
103+
parser = ServeSubcommand().subparser_init(subparsers)
103104
args = parser.parse_args(["--model", model, *vllm_serve_args])
104105
self.host = str(args.host or 'localhost')
105106
self.port = int(args.port)

tests/v1/core/test_kv_cache_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def make_request(request_id,
4545
multi_modal_placeholders=mm_positions,
4646
sampling_params=SamplingParams(max_tokens=17),
4747
eos_token_id=100,
48-
arrival_time=0,
4948
lora_request=None,
5049
cache_salt=cache_salt,
5150
)

tests/v1/core/test_prefix_caching.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def make_request(request_id,
3838
sampling_params=SamplingParams(max_tokens=17,
3939
prompt_logprobs=prompt_logprobs),
4040
eos_token_id=100,
41-
arrival_time=0,
4241
lora_request=None,
4342
cache_salt=cache_salt,
4443
)

tests/v1/core/test_scheduler.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ def create_requests(num_requests: int,
138138
multi_modal_placeholders=mm_position,
139139
multi_modal_hashes=None,
140140
eos_token_id=EOS_TOKEN_ID,
141-
arrival_time=0,
142141
)
143142
requests.append(request)
144143
return requests
@@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
744743
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
745744

746745
# No draft or accepted tokens counted yet
747-
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None
746+
assert not engine_core_outputs or (
747+
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)
748748

749749
# Schedule the speculated tokens for validation
750750
output = scheduler.schedule()
@@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
772772
engine_core_outputs = scheduler.update_from_output(output,
773773
model_runner_output)
774774

775-
scheduler_stats = engine_core_outputs.scheduler_stats
775+
scheduler_stats = engine_core_outputs[0].scheduler_stats \
776+
if engine_core_outputs else None
776777
if expected[0] == 0:
777778
assert scheduler_stats.spec_decoding_stats is None
778779
else:
@@ -843,7 +844,7 @@ def _step_until_done(
843844
# We should be in the decode phase now.
844845
assert num_scheduled_tokens == 1
845846
assert len(output.kv_connector_metadata.requests) == 0
846-
ecos = scheduler.update_from_output(output, model_runner_output)
847+
ecos = scheduler.update_from_output(output, model_runner_output)[0]
847848
all_done = True
848849
for eco in ecos.outputs:
849850
if eco.finish_reason is None:

0 commit comments

Comments
 (0)