@@ -1,5 +1,10 @@
+import os
 from typing import Any, Dict, List, Optional, Tuple
 
+import torch
+import torch.distributed as dist
+
+import vllm.envs as envs
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
@@ -16,7 +21,7 @@ def _init_executor(self) -> None:
         """Initialize the worker and load the model.
         """
         self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
-                                               rank=0)
+                                               rpc_rank=0)
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
         local_rank = 0
@@ -55,3 +60,77 @@ def check_health(self) -> None:
 
 
 UniProcExecutorAsync = UniProcExecutor
+
+
+class ExecutorWithExternalLauncher(UniProcExecutor):
+    """An executor that uses external launchers to launch engines,
+    specially designed for torchrun-compatible launchers, for
+    offline inference with tensor parallelism.
+
+    See https://github.com/vllm-project/vllm/issues/11400 for
+    the motivation, and examples/offline_inference/torchrun_example.py
+    for a usage example.
+
+    The key idea: although this is tensor-parallel inference, we create
+    only one worker per executor. Users launch multiple engines with a
+    torchrun-compatible launcher, and all these engines work together
+    to process the same prompts. When scheduling is deterministic, all
+    the engines generate the same outputs and do not need to
+    synchronize state with each other.
+    """
+    uses_ray: bool = False
+
+    def _init_executor(self) -> None:
+        """Initialize the worker and load the model.
+        """
+        assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
+            ("ExecutorWithExternalLauncher does not "
+             "support pipeline parallelism.")
+        assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
+            ("ExecutorWithExternalLauncher needs deterministic "
+             "execution, so it "
+             "does not support delay_factor in scheduling.")
+        assert not envs.VLLM_USE_V1, \
+            ("V1 architecture cannot guarantee deterministic execution, "
+             "so it is not supported in ExecutorWithExternalLauncher.")
+        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
+                                               rpc_rank=0)
+        # Engines are launched by torchrun-compatible launchers,
+        # so we can use the env:// init method.
+        # Required env vars:
+        # - RANK
+        # - MASTER_ADDR
+        # - MASTER_PORT
+        distributed_init_method = "env://"
+        rank = int(os.environ["RANK"])
+        local_rank = rank
+        is_driver_worker = True
+        kwargs = dict(
+            vllm_config=self.vllm_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=is_driver_worker,
+        )
+        self.collective_rpc("init_worker", args=([kwargs], ))
+        self.collective_rpc("init_device")
+        self.collective_rpc("load_model")
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """
+        Determine the number of available KV blocks.
+        Add an additional all_reduce to get the min across all ranks.
+        Note that even if we have the same `gpu_memory_utilization` and
+        `swap_space`, the available memory in every rank might still
+        differ because NCCL can take different amounts of memory in
+        different ranks. Therefore, it is necessary to test if all ranks
+        agree on the same KV cache configuration.
+        """
+        a, b = super().determine_num_available_blocks()
+        from vllm.distributed.parallel_state import get_world_group
+        cpu_group = get_world_group().cpu_group
+        a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64)
+        b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64)
+        dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        return a_tensor.item(), b_tensor.item()
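
Usage note (not part of the diff): a minimal sketch of how an engine backed by this executor might be driven from a torchrun-compatible launcher, in the spirit of examples/offline_inference/torchrun_example.py referenced in the class docstring. The model name and the distributed_executor_backend="external_launcher" setting are illustrative assumptions here, not confirmed by this diff.

    # Launch command (torchrun sets RANK, MASTER_ADDR and MASTER_PORT,
    # which _init_executor above consumes via the "env://" init method):
    #
    #   torchrun --nproc-per-node=2 run_inference.py

    # run_inference.py
    from vllm import LLM, SamplingParams

    # Every rank builds its own engine with a single worker; the assumed
    # "external_launcher" backend value is what would route engine
    # construction to ExecutorWithExternalLauncher.
    llm = LLM(
        model="facebook/opt-125m",  # illustrative model choice
        tensor_parallel_size=2,
        distributed_executor_backend="external_launcher",  # assumption
    )

    # All ranks process the same prompts; with greedy sampling and
    # deterministic scheduling, every rank produces identical outputs.
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0))

Each rank ends up with exactly one worker, matching the single-worker design of UniProcExecutor that this class inherits from.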