11import asyncio
2- from typing import Any , List , Optional
2+ from typing import Any , Callable , List , Optional , Union
3+
4+ import cloudpickle
35
46from vllm .executor .executor_base import DistributedExecutorBase
57from vllm .executor .multiproc_worker_utils import (
911from vllm .model_executor .layers .sampler import SamplerOutput
1012from vllm .sequence import ExecuteModelRequest
1113from vllm .utils import (_run_task_with_lock , get_distributed_init_method ,
12- get_ip , get_open_port , make_async )
14+ get_ip , get_open_port , make_async , run_method )
1315from vllm .worker .worker_base import WorkerWrapperBase
1416
1517logger = init_logger (__name__ )
@@ -107,7 +109,7 @@ def _driver_execute_model(
107109
108110 def _run_workers (
109111 self ,
110- method : str ,
112+ method : Union [ str , Callable ] ,
111113 * args ,
112114 async_run_tensor_parallel_workers_only : bool = False ,
113115 max_concurrent_workers : Optional [int ] = None ,
@@ -121,6 +123,11 @@ def _run_workers(
121123 It will also be run asynchronously and return a list of futures
122124 rather than blocking on the results.
123125 """
126+ if isinstance (method , str ):
127+ sent_method = method
128+ else :
129+ sent_method = cloudpickle .dumps (method )
130+ del method
124131
125132 if max_concurrent_workers :
126133 raise NotImplementedError (
@@ -129,18 +136,18 @@ def _run_workers(
129136 if async_run_tensor_parallel_workers_only :
130137 # Run only non-driver workers and just return futures.
131138 return [
132- worker .execute_method (method , * args , ** kwargs )
139+ worker .execute_method (sent_method , * args , ** kwargs )
133140 for worker in self .non_driver_workers
134141 ]
135142
136143 # Start all remote workers first.
137144 worker_outputs = [
138- worker .execute_method (method , * args , ** kwargs )
145+ worker .execute_method (sent_method , * args , ** kwargs )
139146 for worker in self .workers
140147 ]
141148
142- driver_worker_method = getattr (self .driver_worker , method )
143- driver_worker_output = driver_worker_method ( * args , ** kwargs )
149+ driver_worker_output = run_method (self .driver_worker , sent_method ,
150+ args , kwargs )
144151
145152 # Get the results of the workers.
146153 return [driver_worker_output
0 commit comments