
Commit 6e64b12

DarkLight1337 authored and qqma committed
[Frontend] Pass API server count to each process (vllm-project#23717)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: qqma <qqma@amazon.com>
1 parent bca6d5d commit 6e64b12

File tree

12 files changed: +221 −51 lines changed
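In short: `ParallelConfig` gains two internal fields, `_api_process_count` and `_api_process_rank`, which `EngineArgs` now plumbs through to the engine config; the data-parallel load-balancer tests launch their servers with `VLLM_SERVER_DEV_MODE=1` and read the values back via the `/server_info` endpoint to verify that every API server process knows the server count and its own rank.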

benchmarks/kernels/benchmark_w8a8_block_fp8.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -11,13 +11,13 @@
 from typing import Any
 
 import torch
-import triton
 from tqdm import tqdm
 
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     _w8a8_block_fp8_matmul,
 )
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser
 
 mp.set_start_method("spawn", force=True)
```
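The only change here is importing `triton` through the `vllm.triton_utils` shim instead of directly; presumably this keeps the benchmark importable in environments where Triton itself is absent, matching how the rest of the codebase pulls in Triton.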

examples/others/tensorize_vllm_model.py
Lines changed: 1 addition & 8 deletions

```diff
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import argparse
-import dataclasses
 import json
 import logging
 import os
@@ -327,12 +325,7 @@ def main():
 
 
     if args.command == "serialize":
-        eng_args_dict = {f.name: getattr(args, f.name) for f in
-                         dataclasses.fields(EngineArgs)}
-
-        engine_args = EngineArgs.from_cli_args(
-            argparse.Namespace(**eng_args_dict)
-        )
+        engine_args = EngineArgs.from_cli_args(args)
 
         input_dir = tensorizer_dir.rstrip('/')
         suffix = args.suffix if args.suffix else uuid.uuid4().hex
```
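This simplification works because of the `EngineArgs.from_cli_args` change in `vllm/engine/arg_utils.py` below: the constructor now skips dataclass fields that are absent from the namespace, so the script's own argparse `Namespace` can be passed directly instead of being rebuilt from `dataclasses.fields(EngineArgs)`.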

tests/entrypoints/test_api_server_process_manager.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -60,7 +60,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
     global WORKER_RUNTIME_SECONDS
     WORKER_RUNTIME_SECONDS = 0.5
 
-    # Copy the args to avoid mutating the
+    # Copy the args to avoid mutating them
     args = api_server_args.copy()
 
     if not with_stats_update:
```

tests/v1/test_external_lb_dp.py
Lines changed: 48 additions & 4 deletions

```diff
@@ -9,6 +9,7 @@
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 
 from tests.utils import RemoteOpenAIServer
 from vllm.platforms import current_platform
@@ -70,6 +71,8 @@ def start_server(r: int, sargs: list[str]):
                 sargs,
                 auto_port=False,
                 env_dict={
+                    "VLLM_SERVER_DEV_MODE":
+                    "1",
                     current_platform.device_control_env_var:
                     ",".join(
                         str(
@@ -127,11 +130,19 @@ def default_server_args():
 
 
 @pytest.fixture(scope="module", params=[1, 4])
-def servers(request, default_server_args):
+def server_manager(request, default_server_args):
     api_server_count = request.param
-    with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
-                                 default_server_args) as server_list:
-        yield server_list
+    server_manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE,
+                                             api_server_count,
+                                             default_server_args)
+
+    with server_manager:
+        yield server_manager
+
+
+@pytest.fixture
+def servers(server_manager):
+    return server_manager.servers
 
 
 @pytest_asyncio.fixture
@@ -144,6 +155,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
     ]
 
 
+def _get_parallel_config(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("server_info?config_format=json"))
+    response.raise_for_status()
+
+    vllm_config = response.json()["vllm_config"]
+    return vllm_config["parallel_config"]
+
+
+def test_external_lb_server_info(server_manager):
+    servers = server_manager.servers
+    api_server_count = server_manager.api_server_count
+
+    for i, (server, _) in enumerate(servers):
+        print(f"Testing {i=}")
+
+        # Each request will hit one of the API servers
+        # `n_reqs` is set so that there is a good chance each server
+        # receives at least one request
+        n_reqs = 2 * api_server_count * api_server_count
+        parallel_configs = [
+            _get_parallel_config(server) for _ in range(n_reqs)
+        ]
+        api_process_counts = [
+            c["_api_process_count"] for c in parallel_configs
+        ]
+        api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
+
+        assert all(c == api_server_count
+                   for c in api_process_counts), api_process_counts
+        assert all(0 <= r < api_server_count
+                   for r in api_process_ranks), api_process_ranks
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
```

tests/v1/test_hybrid_lb_dp.py
Lines changed: 49 additions & 5 deletions

```diff
@@ -9,6 +9,7 @@
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 
 from tests.utils import RemoteOpenAIServer
 from tests.v1.test_utils import check_request_balancing
@@ -92,6 +93,8 @@ def start_server(node: int, sargs: list[str]):
                 sargs,
                 auto_port=False,
                 env_dict={
+                    "VLLM_SERVER_DEV_MODE":
+                    "1",
                     current_platform.device_control_env_var:
                     ",".join(
                         str(
@@ -150,12 +153,20 @@ def default_server_args():
 
 
 @pytest.fixture(scope="module", params=[1, 4])
-def servers(request, default_server_args):
+def server_manager(request, default_server_args):
     api_server_count = request.param
-    with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
-                               default_server_args, DP_SIZE_LOCAL,
-                               TP_SIZE) as server_list:
-        yield server_list
+    server_manager = HybridLBServerManager(MODEL_NAME, DP_SIZE,
+                                           api_server_count,
+                                           default_server_args, DP_SIZE_LOCAL,
+                                           TP_SIZE)
+
+    with server_manager:
+        yield server_manager
+
+
+@pytest.fixture
+def servers(server_manager):
+    return server_manager.servers
 
 
 @pytest_asyncio.fixture
@@ -168,6 +179,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
     ]
 
 
+def _get_parallel_config(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("server_info?config_format=json"))
+    response.raise_for_status()
+
+    vllm_config = response.json()["vllm_config"]
+    return vllm_config["parallel_config"]
+
+
+def test_hybrid_dp_server_info(server_manager):
+    servers = server_manager.servers
+    api_server_count = server_manager.api_server_count
+
+    for i, (server, _) in enumerate(servers):
+        print(f"Testing {i=}")
+
+        # Each request will hit one of the API servers
+        # `n_reqs` is set so that there is a good chance each server
+        # receives at least one request
+        n_reqs = 2 * api_server_count * api_server_count
+        parallel_configs = [
+            _get_parallel_config(server) for _ in range(n_reqs)
+        ]
+        api_process_counts = [
+            c["_api_process_count"] for c in parallel_configs
+        ]
+        api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
+
+        assert all(c == api_server_count
+                   for c in api_process_counts), api_process_counts
+        assert all(0 <= r < api_server_count
+                   for r in api_process_ranks), api_process_ranks
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
```

tests/v1/test_internal_lb_dp.py
Lines changed: 49 additions & 8 deletions

```diff
@@ -10,6 +10,7 @@
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 
 from tests.utils import RemoteOpenAIServer
 from tests.v1.test_utils import check_request_balancing
@@ -101,6 +102,8 @@ def start_server(sidx: int, r: int, sargs: list[str]):
                 sargs,
                 auto_port=False,
                 env_dict={
+                    "VLLM_SERVER_DEV_MODE":
+                    "1",
                     current_platform.device_control_env_var:
                     ",".join(
                         str(
@@ -214,7 +217,10 @@ def start_api_server():
                 self.model_name,
                 api_server_args,
                 auto_port=False,
-                env_dict={})  # No GPUs needed for API-only server
+                env_dict={
+                    "VLLM_SERVER_DEV_MODE": "1",
+                    # No GPUs needed for API-only server
+                })
             server.__enter__()
             print(f"API-only server started successfully with "
                   f"{self.api_server_count} API servers")
@@ -293,14 +299,21 @@ def default_server_args():
 
 
 @pytest.fixture(scope="module", params=[1, 4])
-def servers(request, default_server_args):
+def server_manager(request, default_server_args):
     api_server_count = request.param
-    with MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
-                                          api_server_count,
-                                          default_server_args,
-                                          DP_SIZE // NUM_NODES,
-                                          TP_SIZE) as server_list:
-        yield server_list
+    server_manager = MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
+                                                      api_server_count,
+                                                      default_server_args,
+                                                      DP_SIZE // NUM_NODES,
+                                                      TP_SIZE)
+
+    with server_manager:
+        yield server_manager
+
+
+@pytest.fixture
+def servers(server_manager):
+    return server_manager.servers
 
 
 @pytest.fixture(scope="module", params=[1, 4])
@@ -331,6 +344,34 @@ async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer,
     yield client
 
 
+def _get_parallel_config(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("server_info?config_format=json"))
+    response.raise_for_status()
+
+    vllm_config = response.json()["vllm_config"]
+    return vllm_config["parallel_config"]
+
+
+def test_multinode_dp_server_info(server_manager):
+    head_server = server_manager.servers[0][0]
+    api_server_count = server_manager.api_server_count
+
+    # Each request will hit one of the API servers
+    # `n_reqs` is set so that there is a good chance each server
+    # receives at least one request
+    n_reqs = 2 * api_server_count * api_server_count
+    parallel_configs = [
+        _get_parallel_config(head_server) for _ in range(n_reqs)
+    ]
+    api_process_counts = [c["_api_process_count"] for c in parallel_configs]
+    api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
+
+    assert all(c == api_server_count
+               for c in api_process_counts), api_process_counts
+    assert all(0 <= r < api_server_count
+               for r in api_process_ranks), api_process_ranks
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
```

vllm/config/parallel.py
Lines changed: 25 additions & 0 deletions

```diff
@@ -193,6 +193,25 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""
 
+    _api_process_count: int = 1
+    """
+    The number of API processes initialized.
+
+    Note:
+        This is an internal config that is only valid for and
+        should only be set by API server scale-out.
+    """
+
+    _api_process_rank: int = 0
+    """
+    The rank of this API process, or `-1` for engine core processes
+    under API server scale-out.
+
+    Note:
+        This is an internal config that is only valid for and
+        should only be set by API server scale-out.
+    """
+
     @property
     def world_size_across_dp(self) -> int:
         """world_size_across_dp is TPxPPxDP, it is the size of the world
@@ -428,6 +447,12 @@ def __post_init__(self) -> None:
         if self.distributed_executor_backend is None and self.world_size == 1:
             self.distributed_executor_backend = "uni"
 
+        if not -1 <= self._api_process_rank < self._api_process_count:
+            raise ValueError(
+                "Invalid value of `_api_process_rank`. "
+                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
+                f"but found: {self._api_process_rank}")
+
     @property
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
```
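The `__post_init__` check encodes a compact invariant: `_api_process_rank` is either `-1` (engine core processes under API server scale-out) or a valid index in `[0, _api_process_count)`. A minimal standalone sketch of the same check (illustrative names, not vLLM's classes):

```python
def validate_api_process_rank(rank: int, count: int) -> None:
    # Mirrors the chained comparison in ParallelConfig.__post_init__:
    # -1 <= rank < count accepts -1 (engine core) and 0..count-1 (API servers).
    if not -1 <= rank < count:
        raise ValueError(
            f"Expected to be `-1` or `[0, {count})`, but found: {rank}")

validate_api_process_rank(-1, 1)   # OK: engine core process
validate_api_process_rank(3, 4)    # OK: fourth of four API servers
# validate_api_process_rank(4, 4)  # would raise ValueError
```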

vllm/engine/arg_utils.py
Lines changed: 8 additions & 1 deletion

```diff
@@ -333,6 +333,8 @@ class EngineArgs:
     enable_eplb: bool = ParallelConfig.enable_eplb
     expert_placement_strategy: ExpertPlacementStrategy = \
         ParallelConfig.expert_placement_strategy
+    _api_process_count: int = ParallelConfig._api_process_count
+    _api_process_rank: int = ParallelConfig._api_process_rank
     num_redundant_experts: int = EPLBConfig.num_redundant_experts
     eplb_window_size: int = EPLBConfig.window_size
     eplb_step_interval: int = EPLBConfig.step_interval
@@ -952,7 +954,10 @@ def from_cli_args(cls, args: argparse.Namespace):
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
-        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        engine_args = cls(**{
+            attr: getattr(args, attr)
+            for attr in attrs if hasattr(args, attr)
+        })
         return engine_args
 
     def create_model_config(self) -> ModelConfig:
@@ -1366,6 +1371,8 @@ def create_engine_config(
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
             decode_context_parallel_size=self.decode_context_parallel_size,
+            _api_process_count=self._api_process_count,
+            _api_process_rank=self._api_process_rank,
         )
 
         speculative_config = self.create_speculative_config(
```
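The `hasattr` filter is what lets callers pass a namespace that does not define every `EngineArgs` field (the tensorizer example above being one such caller): absent fields simply keep their dataclass defaults. A toy sketch of the pattern, assuming nothing about vLLM beyond what the hunk shows:

```python
import argparse
import dataclasses

@dataclasses.dataclass
class ToyArgs:
    model: str = "some-model"
    _api_process_count: int = 1  # private field, never added to the CLI parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "ToyArgs":
        attrs = [f.name for f in dataclasses.fields(cls)]
        # Skip fields absent from the namespace so defaults apply.
        return cls(**{a: getattr(args, a) for a in attrs if hasattr(args, a)})

ns = argparse.Namespace(model="my-model")  # no _api_process_count attribute
print(ToyArgs.from_cli_args(ns))
# ToyArgs(model='my-model', _api_process_count=1)
```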
