Skip to content

Commit 0cdd213

Browse files
authored
[Misc] Improve Worker process title and logging prefix (#22205)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
1 parent 948dd34 commit 0cdd213

File tree

4 files changed: 37 additions (+37) and 23 deletions (-23).

vllm/utils/__init__.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3359,23 +3359,19 @@ def has_triton_kernels() -> bool:
33593359

33603360
def set_process_title(name: str,
33613361
suffix: str = "",
3362-
append: bool = False) -> None:
3362+
prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None:
33633363
"""
33643364
Set the current process title to a specific name with an
33653365
optional suffix.
33663366
33673367
Args:
33683368
name: The title to assign to the current process.
33693369
suffix: An optional suffix to append to the base name.
3370-
append: Whether to append to the existing process title.
3370+
prefix: A prefix to prepend to the front separated by `::`.
33713371
"""
33723372
if suffix:
33733373
name = f"{name}_{suffix}"
3374-
if append:
3375-
name = f"{setproctitle.getproctitle()}_{name}"
3376-
else:
3377-
name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}"
3378-
setproctitle.setproctitle(name)
3374+
setproctitle.setproctitle(f"{prefix}::{name}")
33793375

33803376

33813377
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:

vllm/v1/engine/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
224224

225225
def add_request(self, request: Request, request_wave: int = 0):
226226
"""Add request to the scheduler.
227-
227+
228228
`request_wave`: indicate which wave of requests this is expected to
229229
belong to in DP case
230230
"""
@@ -433,7 +433,7 @@ def save_tensorized_model(
433433
def preprocess_add_request(
434434
self, request: EngineCoreRequest) -> tuple[Request, int]:
435435
"""Preprocess the request.
436-
436+
437437
This function could be directly used in input processing thread to allow
438438
request initialization running in parallel with Model forward
439439
"""
@@ -697,7 +697,7 @@ def signal_handler(signum, frame):
697697
parallel_config: ParallelConfig = kwargs[
698698
"vllm_config"].parallel_config
699699
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
700-
set_process_title("DPEngineCore", str(dp_rank))
700+
set_process_title("EngineCore", f"DP{dp_rank}")
701701
decorate_logs()
702702
# Set data parallel rank for this engine process.
703703
parallel_config.data_parallel_rank = dp_rank

vllm/v1/engine/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def __init__(
116116
local_dp_ranks.append(local_index)
117117
self.processes.append(
118118
context.Process(target=target_fn,
119-
name=f"EngineCore_{global_index}",
119+
name=f"EngineCore_DP{global_index}",
120120
kwargs=common_kwargs | {
121121
"dp_rank": global_index,
122122
"local_dp_rank": local_index,

vllm/v1/executor/multiproc_executor.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
2727
MessageQueue)
2828
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
29+
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
30+
get_pp_group, get_tp_group)
2931
from vllm.executor.multiproc_worker_utils import (
3032
set_multiprocessing_worker_envs)
3133
from vllm.logger import init_logger
@@ -397,17 +399,6 @@ def __init__(
397399
wrapper.init_worker(all_kwargs)
398400
self.worker = wrapper
399401

400-
pp_size = vllm_config.parallel_config.pipeline_parallel_size
401-
tp_size = vllm_config.parallel_config.tensor_parallel_size
402-
pp_str = f"PP{rank // tp_size}" if pp_size > 1 else ""
403-
tp_str = f"TP{rank % tp_size}" if tp_size > 1 else ""
404-
suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}"
405-
process_name = "VllmWorker"
406-
if suffix:
407-
set_process_title(suffix, append=True)
408-
process_name = f"{process_name} {suffix}"
409-
decorate_logs(process_name)
410-
411402
# Initialize MessageQueue for receiving SchedulerOutput
412403
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
413404
input_shm_handle, self.worker.rank)
@@ -425,8 +416,14 @@ def __init__(
425416
name="WorkerAsyncOutputCopy")
426417
self.async_output_copy_thread.start()
427418

428-
# Initialize device and loads weights
419+
# Initialize device
429420
self.worker.init_device()
421+
422+
# Set process title and log prefix
423+
self.setup_proc_title_and_log_prefix(
424+
enable_ep=vllm_config.parallel_config.enable_expert_parallel)
425+
426+
# Load model
430427
self.worker.load_model()
431428

432429
@staticmethod
@@ -663,3 +660,24 @@ def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
663660

664661
if output_rank is None or self.rank == output_rank:
665662
self.handle_output(output)
663+
664+
@staticmethod
665+
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
666+
dp_size = get_dp_group().world_size
667+
dp_rank = get_dp_group().rank_in_group
668+
pp_size = get_pp_group().world_size
669+
pp_rank = get_pp_group().rank_in_group
670+
tp_size = get_tp_group().world_size
671+
tp_rank = get_tp_group().rank_in_group
672+
process_name = "Worker"
673+
if dp_size > 1:
674+
process_name += f"_DP{dp_rank}"
675+
if pp_size > 1:
676+
process_name += f"_PP{pp_rank}"
677+
if tp_size > 1:
678+
process_name += f"_TP{tp_rank}"
679+
if enable_ep:
680+
ep_rank = get_ep_group().rank_in_group
681+
process_name += f"_EP{ep_rank}"
682+
set_process_title(name=process_name)
683+
decorate_logs(process_name)

0 commit comments

Comments
 (0)