Skip to content

Commit c76e8e5

Browse files
committed
[Perf] API-server scaleout with all-to-all server-engine comms
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 24b2e1e commit c76e8e5

File tree

16 files changed

+1018
-383
lines changed

16 files changed

+1018
-383
lines changed

tests/v1/core/test_kv_cache_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def make_request(request_id,
4444
multi_modal_placeholders=mm_positions,
4545
sampling_params=SamplingParams(max_tokens=17),
4646
eos_token_id=100,
47-
arrival_time=0,
4847
lora_request=None,
4948
cache_salt=cache_salt,
5049
)

tests/v1/core/test_prefix_caching.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def make_request(request_id,
3838
sampling_params=SamplingParams(max_tokens=17,
3939
prompt_logprobs=prompt_logprobs),
4040
eos_token_id=100,
41-
arrival_time=0,
4241
lora_request=None,
4342
cache_salt=cache_salt,
4443
)

tests/v1/core/test_scheduler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ def create_requests(num_requests: int,
138138
multi_modal_placeholders=mm_position,
139139
multi_modal_hashes=None,
140140
eos_token_id=EOS_TOKEN_ID,
141-
arrival_time=0,
142141
)
143142
requests.append(request)
144143
return requests
@@ -732,7 +731,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
732731
prompt_logprobs_dict={},
733732
)
734733
engine_core_outputs = scheduler.update_from_output(output,
735-
model_runner_output)
734+
model_runner_output)[0]
736735

737736
for i in range(len(requests)):
738737
running_req = scheduler.running[i]

vllm/entrypoints/cli/serve.py

Lines changed: 153 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,33 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import argparse
4+
import multiprocessing
5+
import os
46
import signal
7+
import sys
8+
from multiprocessing.context import SpawnProcess
9+
from typing import Any
510

611
import uvloop
12+
import zmq
713

814
import vllm.envs as envs
915
from vllm import AsyncEngineArgs
1016
from vllm.entrypoints.cli.types import CLISubcommand
11-
from vllm.entrypoints.openai.api_server import run_server
17+
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
18+
setup_server)
1219
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
1320
validate_parsed_serve_args)
21+
from vllm.executor.multiproc_worker_utils import _add_prefix
1422
from vllm.logger import init_logger
1523
from vllm.usage.usage_lib import UsageContext
16-
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
24+
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
25+
from vllm.v1.engine.coordinator import DPCoordinator
1726
from vllm.v1.engine.core import EngineCoreProc
1827
from vllm.v1.engine.core_client import CoreEngineProcManager
1928
from vllm.v1.executor.abstract import Executor
29+
from vllm.v1.utils import (CoreEngine, get_engine_client_zmq_addr,
30+
wait_for_engine_startup)
2031

2132
logger = init_logger(__name__)
2233

@@ -34,9 +45,12 @@ def cmd(args: argparse.Namespace) -> None:
3445
if hasattr(args, 'model_tag') and args.model_tag is not None:
3546
args.model = args.model_tag
3647

37-
if args.headless:
48+
if args.headless or args.api_server_count < 1:
3849
run_headless(args)
50+
elif args.api_server_count > 1:
51+
run_multi_api_server(args)
3952
else:
53+
# Single API server (this process).
4054
uvloop.run(run_server(args))
4155

4256
def validate(self, args: argparse.Namespace) -> None:
@@ -67,6 +81,11 @@ def subparser_init(
6781
type=int,
6882
default=0,
6983
help='Starting data parallel rank for secondary nodes.')
84+
serve_parser.add_argument('--api-server-count',
85+
'-asc',
86+
type=int,
87+
default=1,
88+
help='How many API server processes to run.')
7089
serve_parser.add_argument(
7190
"--config",
7291
type=str,
@@ -86,6 +105,9 @@ def cmd_init() -> list[CLISubcommand]:
86105

87106
def run_headless(args: argparse.Namespace):
88107

108+
if args.api_server_count > 1:
109+
raise RuntimeError("api_server_count can't be set in headless mode")
110+
89111
# Create the EngineConfig.
90112
engine_args = AsyncEngineArgs.from_cli_args(args)
91113
usage_context = UsageContext.OPENAI_API_SERVER
@@ -98,7 +120,7 @@ def run_headless(args: argparse.Namespace):
98120
local_engine_count = parallel_config.data_parallel_size_local
99121
host = parallel_config.data_parallel_master_ip
100122
port = engine_args.data_parallel_rpc_port # add to config too
101-
input_address = get_tcp_uri(host, port)
123+
handshake_address = get_tcp_uri(host, port)
102124

103125
if local_engine_count <= 0:
104126
raise RuntimeError("data_parallel_size_local must be > 0 in "
@@ -114,7 +136,7 @@ def signal_handler(signum, frame):
114136

115137
logger.info(
116138
"Launching %d data parallel engine(s) in headless mode, "
117-
"with head node address %s.", local_engine_count, input_address)
139+
"with head node address %s.", local_engine_count, handshake_address)
118140

119141
# Create the engines.
120142
engine_manager = CoreEngineProcManager(
@@ -124,7 +146,7 @@ def signal_handler(signum, frame):
124146
local_start_index=0,
125147
vllm_config=vllm_config,
126148
on_head_node=False,
127-
input_address=input_address,
149+
handshake_address=handshake_address,
128150
executor_class=Executor.get_class(vllm_config),
129151
log_stats=not engine_args.disable_log_stats,
130152
)
@@ -134,3 +156,128 @@ def signal_handler(signum, frame):
134156
finally:
135157
logger.info("Shutting down.")
136158
engine_manager.close()
159+
160+
161+
def run_multi_api_server(args: argparse.Namespace):
    """Run engine core process(es) plus multiple API-server processes.

    This (parent) process:
      * creates the shared listen socket up-front so every spawned API
        server can serve on the same address,
      * binds the ZMQ ROUTER handshake socket and starts the local
        engine core processes,
      * starts a DP coordinator process when data_parallel_size > 1,
      * spawns ``args.api_server_count`` API-server worker processes,
        each with its own engine input/output ZMQ address pair
        (all-to-all server/engine comms),
      * waits for all engine startup handshakes to complete, then joins
        the API-server workers.

    Raises:
        AssertionError: if called in headless mode, with
            ``api_server_count <= 1``, or from a non-zero DP rank.
    """
    assert not args.headless
    num_api_servers = args.api_server_count
    # This path is only taken from cmd() when api_server_count > 1;
    # make the invariant explicit instead of leaving it commented out.
    assert num_api_servers > 1

    # Create the listen socket before forking so it can be inherited by
    # (shared with) all API-server worker processes.
    listen_address, sock = setup_server(args)

    engine_args = AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
    parallel_config = vllm_config.parallel_config

    # Multi-API-server mode runs only on the head (DP rank 0) node.
    assert parallel_config.data_parallel_rank == 0

    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    host = parallel_config.data_parallel_master_ip
    # When all engines are local, IPC sockets can be used instead of TCP.
    local_only = local_engine_count == dp_size

    # One input and one output address per API server.
    input_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]
    output_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]

    addresses: dict[str, Any] = {
        "input_addresses": input_addresses,
        "output_addresses": output_addresses,
    }

    # Set up coordinator for dp > 1.
    coordinator = None
    stats_update_address = None
    if dp_size > 1:
        # TODO "ready" event for coordinator
        coordinator = DPCoordinator(parallel_config)
        addresses.update(coordinator.get_engine_socket_addresses())
        stats_update_address = coordinator.get_stats_publish_address()

    handshake_address = get_engine_client_zmq_addr(
        local_only, host, parallel_config.data_parallel_rpc_port)

    with zmq_socket_ctx(handshake_address, zmq.ROUTER,
                        bind=True) as handshake_socket:

        # Start local engines; any remote engines connect to the
        # handshake socket from their own (headless) nodes.
        if local_engine_count:
            local_engine_manager = CoreEngineProcManager(
                EngineCoreProc.run_engine_core,
                vllm_config=vllm_config,
                executor_class=Executor.get_class(vllm_config),
                log_stats=not engine_args.disable_log_stats,
                handshake_address=handshake_address,
                on_head_node=True,
                local_engine_count=local_engine_count,
                start_index=0,
                local_start_index=0)
        else:
            local_engine_manager = None

        # Start API servers. Each worker gets its own client_config with
        # a dedicated input/output address pair and a unique index.
        spawn_context = multiprocessing.get_context("spawn")
        api_server_workers: list[SpawnProcess] = []
        for i, (in_addr, out_addr) in enumerate(
                zip(input_addresses, output_addresses)):
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
                "client_index": i
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            # TODO check signal propagation
            proc = spawn_context.Process(target=run_api_server_worker,
                                         name=f"ApiServer_{i}",
                                         args=(listen_address, sock, args,
                                               client_config))
            api_server_workers.append(proc)
            proc.start()

        # Wait for engine handshakes to complete. Engines with index
        # below local_engine_count run on this node.
        core_engines = [
            CoreEngine(index=i, local=(i < local_engine_count))
            for i in range(dp_size)
        ]

        wait_for_engine_startup(
            handshake_socket,
            addresses,
            core_engines,
            parallel_config,
            vllm_config.cache_config,
            local_engine_manager,
            coordinator.proc if coordinator else None,
        )

        # TODO handle failures / clean shutdown here
        for proc in api_server_workers:
            proc.join()
267+
268+
def run_api_server_worker(listen_address,
                          sock,
                          args,
                          client_config=None,
                          **uvicorn_kwargs) -> None:
    """Entry point for one spawned API-server worker process.

    Prefixes this process's stdout/stderr (so interleaved logs from
    multiple API servers can be told apart) and then runs the uvicorn
    server loop on the shared listen socket.
    """
    from multiprocessing import current_process

    # Tag both output streams with this worker's process name and pid.
    worker_name = current_process().name
    worker_pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, worker_name, worker_pid)

    server_coro = run_server_worker(listen_address, sock, args, client_config,
                                    **uvicorn_kwargs)
    uvloop.run(server_coro)

0 commit comments

Comments
 (0)