Skip to content

Commit 31f027c

Browse files
yewentao256 and albertoperdomo2
authored and committed
[CI] Fix mypy for vllm/executor (vllm-project#26845)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent b14a212 commit 31f027c

File tree

4 files changed

+23
-11
lines changed

4 files changed

+23
-11
lines changed

tools/pre_commit/mypy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"vllm/assets",
2929
"vllm/distributed",
3030
"vllm/entrypoints",
31+
"vllm/executor",
3132
"vllm/inputs",
3233
"vllm/logging_utils",
3334
"vllm/multimodal",
@@ -44,7 +45,6 @@
4445
"vllm/attention",
4546
"vllm/compilation",
4647
"vllm/engine",
47-
"vllm/executor",
4848
"vllm/inputs",
4949
"vllm/lora",
5050
"vllm/model_executor",

vllm/executor/executor_base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from vllm.sequence import ExecuteModelRequest
1919
from vllm.tasks import SupportedTask
2020
from vllm.utils import make_async
21-
from vllm.v1.outputs import PoolerOutput, SamplerOutput
21+
from vllm.v1.outputs import SamplerOutput
2222
from vllm.v1.worker.worker_base import WorkerBase
2323

2424
logger = init_logger(__name__)
@@ -54,7 +54,7 @@ def __init__(
5454
self._init_executor()
5555
self.is_sleeping = False
5656
self.sleeping_tags: set[str] = set()
57-
self.kv_output_aggregator = None
57+
self.kv_output_aggregator: KVOutputAggregator | None = None
5858

5959
@abstractmethod
6060
def _init_executor(self) -> None:
@@ -143,8 +143,9 @@ def supported_tasks(self) -> tuple[SupportedTask, ...]:
143143

144144
def execute_model(
145145
self, execute_model_req: ExecuteModelRequest
146-
) -> list[SamplerOutput | PoolerOutput] | None:
146+
) -> list[SamplerOutput]:
147147
output = self.collective_rpc("execute_model", args=(execute_model_req,))
148+
assert output[0] is not None
148149
return output[0]
149150

150151
def stop_remote_worker_execution_loop(self) -> None:

vllm/executor/ray_distributed_executor.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar
217217
num_gpus=num_gpus,
218218
scheduling_strategy=scheduling_strategy,
219219
**ray_remote_kwargs,
220-
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config, rpc_rank=rank)
220+
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
221+
vllm_config=self.vllm_config, rpc_rank=rank
222+
)
221223
else:
222224
worker = ray.remote(
223225
num_cpus=0,
224226
num_gpus=0,
225227
resources={current_platform.ray_device_key: num_gpus},
226228
scheduling_strategy=scheduling_strategy,
227229
**ray_remote_kwargs,
228-
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config, rpc_rank=rank)
230+
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
231+
vllm_config=self.vllm_config, rpc_rank=rank
232+
)
229233
worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))
230234

231235
worker_ips = ray.get(
@@ -303,7 +307,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
303307
continue
304308
worker_node_and_gpu_ids.append(
305309
ray.get(worker.get_node_and_gpu_ids.remote())
306-
) # type: ignore
310+
) # type: ignore[attr-defined]
307311

308312
node_workers = defaultdict(list) # node id -> list of worker ranks
309313
node_gpus = defaultdict(list) # node id -> list of gpu ids
@@ -495,7 +499,9 @@ def _run_workers(
495499
if async_run_tensor_parallel_workers_only:
496500
ray_workers = self.non_driver_workers
497501
ray_worker_outputs = [
498-
worker.execute_method.remote(sent_method, *args, **kwargs)
502+
worker.execute_method.remote( # type: ignore[attr-defined]
503+
sent_method, *args, **kwargs
504+
)
499505
for worker in ray_workers
500506
]
501507

@@ -715,7 +721,7 @@ async def _driver_execute_model_async(
715721
tasks.append(
716722
asyncio.create_task(
717723
_run_task_with_lock(
718-
driver_worker.execute_method.remote,
724+
driver_worker.execute_method.remote, # type: ignore[attr-defined]
719725
self.pp_locks[pp_rank],
720726
"execute_model",
721727
execute_model_req,
@@ -733,7 +739,7 @@ async def _start_worker_execution_loop(self):
733739
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1"
734740
)
735741
coros = [
736-
worker.execute_method.remote("start_worker_execution_loop")
742+
worker.execute_method.remote("start_worker_execution_loop") # type: ignore[attr-defined]
737743
for worker in self.non_driver_workers
738744
]
739745
return await asyncio.gather(*coros)

vllm/executor/ray_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,17 @@ def execute_model_spmd(
9090

9191
execute_model_req = self.input_decoder.decode(serialized_req)
9292

93+
assert self.worker is not None, "Worker is not initialized"
94+
9395
# TODO(swang): This is needed right now because Ray Compiled Graph
9496
# executes on a background thread, so we need to reset torch's
9597
# current device.
9698
if not self.compiled_dag_cuda_device_set:
99+
assert self.worker.device is not None
97100
current_platform.set_device(self.worker.device)
98101
self.compiled_dag_cuda_device_set = True
99102

100-
output = self.worker._execute_model_spmd(
103+
output = self.worker._execute_model_spmd( # type: ignore[attr-defined]
101104
execute_model_req, intermediate_tensors
102105
)
103106
# Pipeline model request and output to the next pipeline stage.
@@ -119,6 +122,7 @@ def setup_device_if_necessary(self):
119122
# Not needed
120123
pass
121124
else:
125+
assert self.worker.device is not None
122126
current_platform.set_device(self.worker.device)
123127

124128
self.compiled_dag_cuda_device_set = True
@@ -139,6 +143,7 @@ def execute_model_ray(
139143
scheduler_output, intermediate_tensors = scheduler_output
140144
else:
141145
scheduler_output, intermediate_tensors = scheduler_output, None
146+
assert self.worker.model_runner is not None
142147
output = self.worker.model_runner.execute_model(
143148
scheduler_output, intermediate_tensors
144149
)

0 commit comments

Comments (0)