11# SPDX-License-Identifier: Apache-2.0
22
33import argparse
4+ import multiprocessing
5+ import os
46import signal
7+ import sys
8+ from multiprocessing .context import SpawnProcess
9+ from typing import Any
510
611import uvloop
12+ import zmq
713
814import vllm .envs as envs
915from vllm import AsyncEngineArgs
1016from vllm .entrypoints .cli .types import CLISubcommand
11- from vllm .entrypoints .openai .api_server import run_server
17+ from vllm .entrypoints .openai .api_server import (run_server , run_server_worker ,
18+ setup_server )
1219from vllm .entrypoints .openai .cli_args import (make_arg_parser ,
1320 validate_parsed_serve_args )
21+ from vllm .executor .multiproc_worker_utils import _add_prefix
1422from vllm .logger import init_logger
1523from vllm .usage .usage_lib import UsageContext
16- from vllm .utils import FlexibleArgumentParser , get_tcp_uri
24+ from vllm .utils import FlexibleArgumentParser , get_tcp_uri , zmq_socket_ctx
25+ from vllm .v1 .engine .coordinator import DPCoordinator
1726from vllm .v1 .engine .core import EngineCoreProc
1827from vllm .v1 .engine .core_client import CoreEngineProcManager
1928from vllm .v1 .executor .abstract import Executor
29+ from vllm .v1 .utils import (CoreEngine , get_engine_client_zmq_addr ,
30+ wait_for_engine_startup )
2031
2132logger = init_logger (__name__ )
2233
@@ -34,9 +45,12 @@ def cmd(args: argparse.Namespace) -> None:
3445 if hasattr (args , 'model_tag' ) and args .model_tag is not None :
3546 args .model = args .model_tag
3647
37- if args .headless :
48+ if args .headless or args . api_server_count < 1 :
3849 run_headless (args )
50+ elif args .api_server_count > 1 :
51+ run_multi_api_server (args )
3952 else :
53+ # Single API server (this process).
4054 uvloop .run (run_server (args ))
4155
4256 def validate (self , args : argparse .Namespace ) -> None :
@@ -67,6 +81,11 @@ def subparser_init(
6781 type = int ,
6882 default = 0 ,
6983 help = 'Starting data parallel rank for secondary nodes.' )
84+ serve_parser .add_argument ('--api-server-count' ,
85+ '-asc' ,
86+ type = int ,
87+ default = 1 ,
88+ help = 'How many API server processes to run.' )
7089 serve_parser .add_argument (
7190 "--config" ,
7291 type = str ,
@@ -86,6 +105,9 @@ def cmd_init() -> list[CLISubcommand]:
86105
87106def run_headless (args : argparse .Namespace ):
88107
108+ if args .api_server_count > 1 :
109+ raise RuntimeError ("api_server_count can't be set in headless mode" )
110+
89111 # Create the EngineConfig.
90112 engine_args = AsyncEngineArgs .from_cli_args (args )
91113 usage_context = UsageContext .OPENAI_API_SERVER
@@ -98,7 +120,7 @@ def run_headless(args: argparse.Namespace):
98120 local_engine_count = parallel_config .data_parallel_size_local
99121 host = parallel_config .data_parallel_master_ip
100122 port = engine_args .data_parallel_rpc_port # add to config too
101- input_address = get_tcp_uri (host , port )
123+ handshake_address = get_tcp_uri (host , port )
102124
103125 if local_engine_count <= 0 :
104126 raise RuntimeError ("data_parallel_size_local must be > 0 in "
@@ -114,7 +136,7 @@ def signal_handler(signum, frame):
114136
115137 logger .info (
116138 "Launching %d data parallel engine(s) in headless mode, "
117- "with head node address %s." , local_engine_count , input_address )
139+ "with head node address %s." , local_engine_count , handshake_address )
118140
119141 # Create the engines.
120142 engine_manager = CoreEngineProcManager (
@@ -124,7 +146,7 @@ def signal_handler(signum, frame):
124146 local_start_index = 0 ,
125147 vllm_config = vllm_config ,
126148 on_head_node = False ,
127- input_address = input_address ,
149+ handshake_address = handshake_address ,
128150 executor_class = Executor .get_class (vllm_config ),
129151 log_stats = not engine_args .disable_log_stats ,
130152 )
@@ -134,3 +156,128 @@ def signal_handler(signum, frame):
134156 finally :
135157 logger .info ("Shutting down." )
136158 engine_manager .close ()
159+
160+
def run_multi_api_server(args: argparse.Namespace):
    """Launch the data-parallel engine core(s) plus multiple front-end API
    server processes that share a single pre-bound listen socket.

    Entered from ``cmd`` when ``--api-server-count`` > 1. This process acts
    as the head node: it binds the engine handshake socket, starts any local
    engine processes, spawns one API server worker per requested server, and
    blocks until the workers exit.

    Args:
        args: Parsed CLI namespace; must not have ``headless`` set.
    """

    assert not args.headless
    num_api_servers = args.api_server_count
    # assert num_api_servers > 1

    # Bind the shared listen socket up front so all spawned API server
    # workers can serve on the same address.
    listen_address, sock = setup_server(args)

    engine_args = AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
    parallel_config = vllm_config.parallel_config

    # Multi-API-server mode only runs on the head node (DP rank 0);
    # secondary nodes use run_headless instead.
    assert parallel_config.data_parallel_rank == 0

    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    host = parallel_config.data_parallel_master_ip
    # If every engine is local, IPC addresses can be used instead of TCP
    # (presumably what get_engine_client_zmq_addr does with this flag —
    # confirm against its implementation).
    local_only = local_engine_count == dp_size

    # Set up input and output addresses.
    # One input/output ZMQ address pair per API server worker.
    input_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]
    output_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]

    addresses: dict[str, Any] = {
        "input_addresses": input_addresses,
        "output_addresses": output_addresses,
    }

    # Set up coordinator for dp > 1.
    coordinator = None
    stats_update_address = None
    if dp_size > 1:
        # TODO "ready" event for coordinator
        coordinator = DPCoordinator(parallel_config)
        addresses.update(coordinator.get_engine_socket_addresses())
        stats_update_address = coordinator.get_stats_publish_address()

    # Address the engines connect back to for the startup handshake.
    handshake_address = get_engine_client_zmq_addr(
        local_only, host, parallel_config.data_parallel_rpc_port)

    with zmq_socket_ctx(handshake_address, zmq.ROUTER,
                        bind=True) as handshake_socket:

        # Start local engines.
        if not local_engine_count:
            local_engine_manager = None
        else:
            local_engine_manager = CoreEngineProcManager(
                EngineCoreProc.run_engine_core,
                vllm_config=vllm_config,
                executor_class=Executor.get_class(vllm_config),
                log_stats=not engine_args.disable_log_stats,
                handshake_address=handshake_address,
                on_head_node=True,
                local_engine_count=local_engine_count,
                start_index=0,
                local_start_index=0)

        # Start API servers.
        # Each worker gets its own input/output address pair and a unique
        # client_index so engines can route responses back to it.
        spawn_context = multiprocessing.get_context("spawn")
        api_server_workers: list[SpawnProcess] = []
        for i, in_addr, out_addr in zip(range(num_api_servers),
                                        input_addresses, output_addresses):
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
                "client_index": i
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            # TODO check signal propagation
            proc = spawn_context.Process(target=run_api_server_worker,
                                         name=f"ApiServer_{i}",
                                         args=(listen_address, sock, args,
                                               client_config))
            api_server_workers.append(proc)
            proc.start()

        # Wait for engine handshakes to complete.
        # Engines with index < local_engine_count run in this node's
        # process manager; the rest are remote (headless) engines.
        core_engines = [
            CoreEngine(index=i, local=(i < local_engine_count))
            for i in range(dp_size)
        ]

        wait_for_engine_startup(
            handshake_socket,
            addresses,
            core_engines,
            parallel_config,
            vllm_config.cache_config,
            local_engine_manager,
            coordinator.proc if coordinator else None,
        )

        # TODO handle failures / clean shutdown here
        for proc in api_server_workers:
            proc.join()
266+
267+
def run_api_server_worker(listen_address,
                          sock,
                          args,
                          client_config=None,
                          **uvicorn_kwargs) -> None:
    """Entry point executed in each spawned API server process.

    Tags this process's stdout/stderr with its worker name and pid so the
    interleaved logs of concurrent API servers can be told apart, then runs
    the server event loop to completion.
    """
    # Add process-specific prefix to stdout and stderr.
    worker_name = multiprocessing.current_process().name
    worker_pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, worker_name, worker_pid)

    server_coro = run_server_worker(listen_address, sock, args, client_config,
                                    **uvicorn_kwargs)
    uvloop.run(server_coro)