 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from itertools import islice, repeat
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Union)
 
-import cloudpickle
 import msgspec
 
 import vllm.envs as envs
@@ -212,17 +213,15 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                     num_gpus=num_gpus,
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
-                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rpc_rank=rank)
+                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
             else:
                 worker = ray.remote(
                     num_cpus=0,
                     num_gpus=0,
                     resources={current_platform.ray_device_key: num_gpus},
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
-                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rpc_rank=rank)
+                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
             worker_metadata.append(
                 RayWorkerMetaData(worker=worker, created_rank=rank))
 
@@ -244,7 +243,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
                 self.driver_worker = RayWorkerWrapper(
-                    vllm_config=self.vllm_config, rpc_rank=0)
+                    vllm_config=self.vllm_config)
                 worker_metadata.pop(i)
                 break
 
@@ -283,11 +282,6 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         for i, item in enumerate(sorted_worker_metadata):
             item.adjusted_rank = i + start_rank
         self.workers = [item.worker for item in sorted_worker_metadata]
-        rerank_mapping = {
-            item.created_rank: item.adjusted_rank
-            for item in sorted_worker_metadata
-        }
-        self._run_workers("adjust_rank", rerank_mapping)
 
         # Get the set of GPU IDs used on each node.
         worker_node_and_gpu_ids = []
@@ -328,10 +322,10 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
328322 " each node." )
329323
330324 # Set environment variables for the driver and workers.
331- all_args_to_update_environment_variables = [{
325+ all_args_to_update_environment_variables = [( {
332326 current_platform .device_control_env_var :
333327 "," .join (map (str , node_gpus [node_id ])),
334- } for (node_id , _ ) in worker_node_and_gpu_ids ]
328+ }, ) for (node_id , _ ) in worker_node_and_gpu_ids ]
335329
336330 # Environment variables to copy from driver to workers
337331 env_vars_to_copy = [
@@ -347,7 +341,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
             # TODO: refactor platform-specific env vars
             for name in env_vars_to_copy:
                 if name in os.environ:
-                    args[name] = os.environ[name]
+                    args[0][name] = os.environ[name]
 
         logger.info("non_carry_over_env_vars from config: %s",
                     self.non_carry_over_env_vars)
@@ -362,7 +356,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
             all_args_to_update_environment_variables)
 
         self._run_workers("update_environment_variables",
-                          self._get_env_vars_to_be_updated())
+                          all_args=self._get_env_vars_to_be_updated())
 
         if len(node_gpus) == 1:
             # in single node case, we don't need to get the IP address.
@@ -390,7 +384,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
             )
             all_kwargs.append(kwargs)
-        self._run_workers("init_worker", all_kwargs)
+        self._run_workers("init_worker", all_kwargs=all_kwargs)
 
         self._run_workers("init_device")
         self._run_workers("load_model",
@@ -466,6 +460,8 @@ def _run_workers(
         method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
+        all_args: Optional[List[Tuple[Any, ...]]] = None,
+        all_kwargs: Optional[List[Dict[str, Any]]] = None,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
@@ -479,11 +475,6 @@ def _run_workers(
           rather than blocking on the results.
         - args/kwargs: All workers share the same args/kwargs
         """
-        if isinstance(method, str):
-            sent_method = method
-        else:
-            sent_method = cloudpickle.dumps(method)
-        del method
         if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
                 "async_run_tensor_parallel_workers_only is not supported for "
@@ -493,13 +484,27 @@ def _run_workers(
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
+        count = len(self.workers) if not \
+            async_run_tensor_parallel_workers_only \
+            else len(self.non_driver_workers)
+
+        # If using SPMD worker, all workers are the same, so we should execute
+        # the args on all workers. Otherwise, we skip the first worker's args
+        # because those args will go to the driver worker.
+        first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
+        all_worker_args = repeat(args, count) if all_args is None \
+            else islice(all_args, first_worker_args_index, None)
+        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
+            else islice(all_kwargs, first_worker_args_index, None)
+
         # Start the ray workers first.
         ray_workers = self.workers
         if async_run_tensor_parallel_workers_only:
             ray_workers = self.non_driver_workers
         ray_worker_outputs = [
-            worker.execute_method.remote(sent_method, *args, **kwargs)
-            for worker in ray_workers
+            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
+            for (worker, worker_args, worker_kwargs
+                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
         ]
 
         if async_run_tensor_parallel_workers_only:
@@ -511,9 +516,13 @@ def _run_workers(
         # so we only explicitly execute on the driver worker if using a
         # non-SPMD worker class.
         if not self.use_ray_spmd_worker:
+            driver_args = args if all_args is None else all_args[0]
+            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
             # Start the driver worker after all the ray workers.
             driver_worker_output = [
-                self.driver_worker.execute_method(sent_method, *driver_args,
+                self.driver_worker.execute_method(method, *driver_args,
+                                                  **driver_kwargs)
             ]
 
         # Get the results of the ray workers.
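The core of this change is the restored per-worker argument dispatch in _run_workers: when all_args/all_kwargs are supplied, entry 0 goes to the driver worker and the remaining entries are zipped onto the Ray workers; otherwise every worker receives the shared *args/**kwargs. Below is a minimal standalone sketch of that repeat/islice pattern; the helper name dispatch_args and the toy rank values are illustrative assumptions, not vLLM APIs.

from itertools import islice, repeat
from typing import Any, Dict, List, Optional, Tuple


def dispatch_args(
    num_ray_workers: int,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    all_args: Optional[List[Tuple[Any, ...]]] = None,
    all_kwargs: Optional[List[Dict[str, Any]]] = None,
    use_spmd_worker: bool = False,
):
    """Yield one (args, kwargs) pair per Ray worker."""
    # In non-SPMD mode, index 0 of all_args/all_kwargs belongs to the driver
    # worker, so the remote workers consume the list starting at index 1.
    first_index = 0 if use_spmd_worker else 1
    worker_args = repeat(args, num_ray_workers) if all_args is None \
        else islice(all_args, first_index, None)
    worker_kwargs = repeat(kwargs, num_ray_workers) if all_kwargs is None \
        else islice(all_kwargs, first_index, None)
    return zip(worker_args, worker_kwargs)


# One kwargs dict per rank; index 0 would go to the driver worker and the
# rest to the three Ray workers, mirroring the
# self._run_workers("init_worker", all_kwargs=all_kwargs) call above.
all_kwargs = [{"rank": r} for r in range(4)]
for worker_args, worker_kwargs in dispatch_args(3, (), {},
                                                all_kwargs=all_kwargs):
    print(worker_args, worker_kwargs)  # (), {'rank': 1} ... (), {'rank': 3}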