Commit 9ca44ce

[V1] AsyncLLM data parallel WIP
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 4cb6fa0 commit 9ca44ce

File tree

6 files changed: +193 -44 lines changed


vllm/config.py
Lines changed: 15 additions & 5 deletions

@@ -37,7 +37,8 @@
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3
 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
-                        get_cpu_memory, random_uuid, resolve_obj_by_qualname)
+                        get_cpu_memory, get_open_port, random_uuid,
+                        resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup

@@ -1423,10 +1424,19 @@ def __post_init__(self) -> None:
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size
 
-        self.data_parallel_size = envs.VLLM_DP_SIZE
-        self.data_parallel_rank = envs.VLLM_DP_RANK
-        self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
-        self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
+        if self.data_parallel_size > 1:
+            import os
+            if os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") != "1":
+                raise ValueError(
+                    "VLLM_ENABLE_V1_MULTIPROCESSING can't be disabled when "
+                    "using data parallel.")
+            self.data_parallel_master_port = get_open_port()
+        else:
+            self.data_parallel_size = envs.VLLM_DP_SIZE
+            self.data_parallel_rank = envs.VLLM_DP_RANK
+            self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
+            self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
+
         self.world_size_across_dp = self.world_size * self.data_parallel_size
 
         if self.distributed_executor_backend == "external_launcher":
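When `data_parallel_size > 1` is set explicitly, the front-end process now picks the DP master port itself via `get_open_port()` instead of reading `VLLM_DP_MASTER_PORT`. A minimal sketch of how such a helper can work, assuming the usual bind-to-port-0 trick (the actual `vllm.utils.get_open_port` may differ in details such as retries or env overrides):

```python
import socket

def get_open_port() -> int:
    """Ask the OS for a free TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # port 0: the kernel assigns an unused ephemeral port
        return s.getsockname()[1]
```

Note the inherent race: the port is released before the DP master binds it, so another process could grab it in between; implementations typically retry or tolerate the occasional collision.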

vllm/engine/arg_utils.py
Lines changed: 7 additions & 0 deletions

@@ -113,6 +113,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    data_parallel_size: int = 1
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None

@@ -430,6 +431,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                         type=int,
                         default=EngineArgs.tensor_parallel_size,
                         help='Number of tensor parallel replicas.')
+    parser.add_argument('--data-parallel-size',
+                        '-dp',
+                        type=int,
+                        default=EngineArgs.data_parallel_size,
+                        help='Number of data parallel replicas.')
     parser.add_argument(
         '--max-parallel-loading-workers',
         type=int,

@@ -1170,6 +1176,7 @@ def create_engine_config(self,
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            data_parallel_size=self.data_parallel_size,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             tokenizer_pool_config=TokenizerPoolConfig.create_config(
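The new flag mirrors the existing `--tensor-parallel-size`/`-tp` pattern. A self-contained sketch of how the `-dp` alias behaves under plain argparse (standalone, not vLLM's `FlexibleArgumentParser`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data-parallel-size',
                    '-dp',
                    type=int,
                    default=1,
                    help='Number of data parallel replicas.')

# argparse derives the attribute name from the first long option,
# so both spellings populate args.data_parallel_size.
assert parser.parse_args(['-dp', '2']).data_parallel_size == 2
assert parser.parse_args([]).data_parallel_size == 1
```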

vllm/v1/core/scheduler.py
Lines changed: 12 additions & 3 deletions

@@ -20,6 +20,9 @@
 
 logger = init_logger(__name__)
 
+# Used to trigger dummy requests whose outputs should be ignored.
+DUMMY_REQ_ID = "__DUMMY_REQ_ID"
+
 
 class Scheduler:
 

@@ -483,6 +486,7 @@ def update_from_output(
 
         new_running: List[Request] = []
         outputs: List[EngineCoreOutput] = []
+        finished_requests: List[str] = []
 
         # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
         # loop can be a performance bottleneck. We should do our best to avoid

@@ -564,17 +568,21 @@ def update_from_output(
                 new_logprobs = logprobs.slice(req_index, req_index + 1)
 
             # Transmit partial if chunked prefill & prompt logprobs is enabled
-            if new_token_ids or prompt_logprobs_tensors is not None:
+            if (new_token_ids or prompt_logprobs_tensors is not None) \
+                    and req_id != DUMMY_REQ_ID:
                 # Add EngineCoreOutput for this Request.
+                finish_reason = request.get_finished_reason()
                 outputs.append(
                     EngineCoreOutput(
                         request_id=req_id,
                         new_token_ids=new_token_ids,
-                        finish_reason=request.get_finished_reason(),
+                        finish_reason=finish_reason,
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         stop_reason=request.stop_reason,
                         events=request.take_events()))
+                if finish_reason:
+                    finished_requests.append(req_id)
 
             self.scheduled_req_ids.remove(request.request_id)
             if not stopped:

@@ -583,6 +591,7 @@ def update_from_output(
         self.running = new_running
         return EngineCoreOutputs(
             outputs=outputs,
+            finished_requests=finished_requests,
             scheduler_stats=self.make_stats(),
         )
 

@@ -653,7 +662,7 @@ def get_num_unfinished_requests(self) -> int:
         return len(self.waiting) + len(self.running)
 
     def has_unfinished_requests(self) -> bool:
-        return self.get_num_unfinished_requests() > 0
+        return len(self.running) > 0 or len(self.waiting) > 0
 
     def get_num_unscheduled_requests(self) -> int:
         """Number of requests that are not being processed by the executor."""

vllm/v1/engine/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -133,6 +133,7 @@ class EngineCoreOutputs(
     timestamp: float = 0.0
 
     utility_output: Optional[UtilityOutput] = None
+    finished_requests: List[str] = []
 
     def __post_init__(self):
         if self.timestamp == 0.0:
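A mutable class-level default like `finished_requests: List[str] = []` would be a shared-state bug in a plain Python class, but `EngineCoreOutputs` appears to be a `msgspec.Struct` (note the `__post_init__` hook), and msgspec copies mutable defaults for each new instance. A quick demonstration of that property, assuming `msgspec` is installed:

```python
import msgspec
from typing import List

class Outputs(msgspec.Struct):  # illustrative stand-in, not the real class
    finished_requests: List[str] = []  # msgspec copies this per instance

a, b = Outputs(), Outputs()
a.finished_requests.append("req-1")
assert b.finished_requests == []  # no shared mutable state between instances
```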

vllm/v1/engine/core.py
Lines changed: 17 additions & 3 deletions

@@ -268,7 +268,10 @@ def __init__(
         ready_pipe.send({"status": "READY"})
 
     @staticmethod
-    def run_engine_core(*args, **kwargs):
+    def run_engine_core(*args,
+                        vllm_config: VllmConfig,
+                        dp_rank: int = 0,
+                        **kwargs):
         """Launch EngineCore busy loop in background process."""
 
         # Signal handler used for graceful termination.

@@ -289,6 +292,9 @@ def signal_handler(signum, frame):
         signal.signal(signal.SIGTERM, signal_handler)
         signal.signal(signal.SIGINT, signal_handler)
 
+        # Set data parallel rank for this engine process.
+        vllm_config.parallel_config.data_parallel_rank = dp_rank
+
         parent_process = psutil.Process().parent()
         engine_core = None
         try:

@@ -313,11 +319,17 @@ def run_busy_loop(self):
         step_fn = (self.step
                    if self.batch_queue is None else self.step_with_batch_queue)
 
+        dp_idle_mode = False
+
         # Loop until process is sent a SIGINT or SIGTERM
         while True:
             # 1) Poll the input queue until there is work to do.
             if not self.scheduler.has_unfinished_requests():
                 while True:
+                    if dp_idle_mode and self.input_queue.empty():
+                        # TODO if time has passed here, break to log stats
+                        self.execute_dummy_batch()
+                        continue
                     try:
                         req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
                         self._handle_client_request(*req)

@@ -327,14 +339,16 @@ def run_busy_loop(self):
                         # Break out the loop so we can log_stats in step().
                         if self.log_stats:
                             break
-                    except BaseException:
-                        raise
 
             # 2) Handle any new client requests.
             while not self.input_queue.empty():
                 req = self.input_queue.get_nowait()
                 self._handle_client_request(*req)
 
+            if self.scheduler.has_unfinished_requests():
+                # TODO client to reset this in coordinated way
+                dp_idle_mode = True
+
             # 3) Step the engine core.
             outputs = step_fn()
