
Commit 9e4af2f

yewentao256 authored and hmellor committed
[CI] Fix mypy for vllm/distributed (vllm-project#26593)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
1 parent f60d915 commit 9e4af2f

File tree

14 files changed (+122 lines, -65 lines)

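Most of the hunks below apply a small set of mypy-friendly idioms rather than new behavior: narrowing Optional values before use, silencing imports of optional dependencies, and replacing blanket "# type: ignore" comments with typing.cast. As a quick orientation, here is a minimal, self-contained sketch of the most common idiom, assert-based narrowing; the Metadata/get_metadata names are illustrative placeholders, not vLLM APIs.

# Minimal sketch of assert-based narrowing (placeholder names, not vLLM APIs).
from dataclasses import dataclass


@dataclass
class Metadata:
    num_tokens: int


def get_metadata(available: bool) -> Metadata | None:
    return Metadata(num_tokens=8) if available else None


def count_tokens() -> int:
    metadata = get_metadata(available=True)
    # Without this assert, mypy reports: Item "None" of "Metadata | None"
    # has no attribute "num_tokens".
    assert metadata is not None
    return metadata.num_tokens


if __name__ == "__main__":
    print(count_tokens())  # 8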

tools/pre_commit/mypy.py

Lines changed: 1 addition & 1 deletion
@@ -26,6 +26,7 @@
 FILES = [
     "vllm/*.py",
     "vllm/assets",
+    "vllm/distributed",
     "vllm/entrypoints",
     "vllm/inputs",
     "vllm/logging_utils",
@@ -42,7 +43,6 @@
     "tests",
     "vllm/attention",
     "vllm/compilation",
-    "vllm/distributed",
     "vllm/engine",
     "vllm/executor",
     "vllm/inputs",

vllm/config/kv_transfer.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ class KVTransferConfig:
     engine_id: str | None = None
     """The engine id for KV transfers."""

-    kv_buffer_device: str | None = "cuda"
+    kv_buffer_device: str = "cuda"
     """The device used by kv connector to buffer the KV cache. Choices are
     'cuda' and 'cpu'."""

vllm/distributed/device_communicators/all2all.py

Lines changed: 24 additions & 12 deletions
@@ -15,9 +15,11 @@
 from .base_device_communicator import All2AllManagerBase, Cache

 if has_flashinfer_all2all():
-    from flashinfer.comm import Mapping
-    from flashinfer.comm.mnnvl import MnnvlConfig
-    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+    from flashinfer.comm import Mapping  # type: ignore[import-not-found]
+    from flashinfer.comm.mnnvl import MnnvlConfig  # type: ignore[import-not-found]
+    from flashinfer.comm.trtllm_alltoall import (
+        MnnvlMoe,  # type: ignore[import-not-found]
+    )

 logger = init_logger(__name__)

@@ -65,6 +67,7 @@ def dispatch(
     ) -> tuple[torch.Tensor, torch.Tensor]:
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

         hidden_states = self.naive_multicast(
@@ -81,6 +84,7 @@ def combine(
         ep_rank = self.rank if is_sequence_parallel else self.dp_rank

         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

@@ -113,7 +117,10 @@ def dispatch(
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None

         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -130,7 +137,10 @@ def combine(
         """
         Reduce-scatter hidden_states across all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None

         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
@@ -155,7 +165,7 @@ def __init__(self, cpu_group):
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
-            from pplx_kernels.nvshmem import (
+            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                 nvshmem_alloc_empty_unique_id,
                 nvshmem_get_unique_id,
                 nvshmem_init,
@@ -182,7 +192,7 @@ def __init__(self, cpu_group):
         self.handle_cache = Cache()

     def get_handle(self, kwargs):
-        import pplx_kernels as pplx
+        import pplx_kernels as pplx  # type: ignore[import-not-found]

         return self.handle_cache.get_or_create(
             kwargs,
@@ -208,7 +218,9 @@ def destroy(self):
             handle.destroy()

         if self.internode:
-            from pplx_kernels.nvshmem import nvshmem_finalize
+            from pplx_kernels.nvshmem import (
+                nvshmem_finalize,  # type: ignore[import-not-found]
+            )

             logger.debug("PPLX NVSHMEM finalize")
             nvshmem_finalize()
@@ -288,7 +300,7 @@ def get_handle(self, kwargs):
             "args are computed in the Manager itself."
         )

-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         buffer_kwargs = self._make_all2all_kwargs()
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -298,7 +310,7 @@ def get_handle(self, kwargs):
         return handle

     def set_num_sms(self, num_sms: int):
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         # Right now the buffers are sized for only what the kernels were
         # created with. So we can only reduce the number of SMS used
@@ -332,7 +344,7 @@ def _make_all2all_kwargs(
             num_global_experts: Number of experts in the model.
             num_local_experts: Number of experts in an EP rank.
         """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         # Defaults for internode and intranode are taken from DeepEP tests.
         num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@@ -358,7 +370,7 @@ def get_handle(self, kwargs):
         The kwargs for DeepEPLLAll2AllManager is dictated by
         _make_all2all_kwargs.
         """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         buffer_kwargs = self._make_all2all_kwargs(**kwargs)
         logger.debug("DeepEP all2all args %s", buffer_kwargs)

vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 12 additions & 7 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from contextlib import contextmanager
+from typing import cast

 import torch
 import torch.distributed as dist
@@ -118,15 +119,18 @@ def __init__(
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
-        device_capability = current_platform.get_device_capability().as_version_str()
+        device_capability = current_platform.get_device_capability()
         if (
             current_platform.is_cuda()
             and symm_mem_enabled
-            and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
+            and device_capability is not None
         ):
-            max_size = min(
-                CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
-            )
+            device_capability_str = device_capability.as_version_str()
+            if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
+                max_size = min(
+                    CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size],
+                    max_size,
+                )
         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
             device_ids = list(map(int, cuda_visible_devices.split(",")))
@@ -213,6 +217,7 @@ def register_graph_buffers(self):
         # We cannot directly use `dist.all_gather_object` here
         # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data: list[list[list[int] | None]]
         all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
         all_data[self.rank] = [handle, offset]
         ranks = sorted(dist.get_process_group_ranks(group=self.group))
@@ -221,8 +226,8 @@ def register_graph_buffers(self):
                 all_data[i], src=rank, group=self.group, device="cpu"
             )
         # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
+        handles = cast(list[list[int]], [d[0] for d in all_data])
+        offsets = cast(list[list[int]], [d[1] for d in all_data])
         ops.register_graph_buffers(self._ptr, handles, offsets)

     def should_custom_ar(self, inp: torch.Tensor):
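
The register_graph_buffers hunks swap blanket "# type: ignore" comments for typing.cast plus an explicit variable annotation. A small standalone sketch of the same idea, with made-up data standing in for the real gathered handles and offsets:

# Sketch of using typing.cast instead of a blanket "# type: ignore" when
# unpacking gathered data; the values here are illustrative only.
from typing import cast

# Each entry starts as [None, None]; in the real code a broadcast fills it in.
all_data: list[list[list[int] | None]] = [[None, None] for _ in range(2)]
# Here two fake ranks are filled by hand to keep the sketch self-contained.
all_data[0] = [[1, 2], [0, 16]]
all_data[1] = [[3, 4], [16, 32]]

# Once every entry is populated the None placeholders are gone, so a cast
# documents the concrete element type for mypy without silencing checks.
handles = cast(list[list[int]], [d[0] for d in all_data])
offsets = cast(list[list[int]], [d[1] for d in all_data])
print(handles, offsets)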

vllm/distributed/device_communicators/symm_mem.py

Lines changed: 8 additions & 3 deletions
@@ -52,9 +52,14 @@ def __init__(
         self.device = device
         self.group = group
         self.world_size = dist.get_world_size(self.group)
-        self.device_capability = (
-            current_platform.get_device_capability().as_version_str()
-        )
+        capability = current_platform.get_device_capability()
+        if capability is None:
+            logger.warning(
+                "SymmMemCommunicator: device capability is unknown, "
+                "communicator is not available."
+            )
+            return
+        self.device_capability = capability.as_version_str()
         if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
             logger.warning(
                 "SymmMemCommunicator: Device capability %s not supported, "

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 12 additions & 2 deletions
@@ -3,7 +3,7 @@

 import importlib
 from collections.abc import Callable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast

 import vllm.envs as envs
 from vllm.distributed.kv_transfer.kv_connector.base import (
@@ -48,6 +48,8 @@ def create_connector(
         )

         kv_transfer_config = config.kv_transfer_config
+        if kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set to create a connector")
         connector_cls = cls.get_connector_class(kv_transfer_config)
         logger.info(
             "Creating v1 connector with name: %s and engine_id: %s",
@@ -70,14 +72,22 @@ def get_connector_class(
     ) -> type[KVConnectorBaseType]:
         """Get the connector class by name."""
         connector_name = kv_transfer_config.kv_connector
+        if connector_name is None:
+            raise ValueError("Connector name is not set in KVTransferConfig")
         if connector_name in cls._registry:
             connector_cls = cls._registry[connector_name]()
         else:
             connector_module_path = kv_transfer_config.kv_connector_module_path
             if connector_module_path is None:
                 raise ValueError(f"Unsupported connector type: {connector_name}")
             connector_module = importlib.import_module(connector_module_path)
-            connector_cls = getattr(connector_module, connector_name)
+            try:
+                connector_cls = getattr(connector_module, connector_name)
+            except AttributeError as e:
+                raise AttributeError(
+                    f"Class {connector_name} not found in {connector_module_path}"
+                ) from e
+            connector_cls = cast(type[KVConnectorBaseType], connector_cls)
         return connector_cls
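
The factory now raises explicitly when the config or connector name is missing and narrows the getattr result with cast. A simplified, hypothetical sketch of the dynamic-loading part (the module/class handling mirrors the hunk, but BaseConnector is a stand-in type, not the real registry):

# Sketch of narrowing a dynamically imported attribute for mypy.
import importlib
from typing import cast


class BaseConnector:
    pass


def load_connector(module_path: str, class_name: str) -> type[BaseConnector]:
    module = importlib.import_module(module_path)
    try:
        connector_cls = getattr(module, class_name)
    except AttributeError as e:
        raise AttributeError(f"Class {class_name} not found in {module_path}") from e
    # getattr() returns Any, so the cast documents the expected type for mypy.
    return cast(type[BaseConnector], connector_cls)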

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 7 additions & 7 deletions
@@ -151,21 +151,21 @@ def update_finished_set(
         aggregated_kv_connector_stats = None
         invalid_block_ids = set[int]()
         for model_runner_output in outputs:
-            output = model_runner_output.kv_connector_output
-            if not output:
+            kv_output = model_runner_output.kv_connector_output
+            if not kv_output:
                 continue
             update_finished_set(
-                output.finished_sending, self._send_remaining_count, finished_sending
+                kv_output.finished_sending, self._send_remaining_count, finished_sending
             )
             update_finished_set(
-                output.finished_recving, self._recv_remaining_count, finished_recving
+                kv_output.finished_recving, self._recv_remaining_count, finished_recving
             )

             # Aggregate kv_connector_stats from all workers.
             if aggregated_kv_connector_stats is None:
                 # Use the first worker's kv_connector_stats as accumulator.
-                aggregated_kv_connector_stats = output.kv_connector_stats
-            elif kv_connector_stats := output.kv_connector_stats:
+                aggregated_kv_connector_stats = kv_output.kv_connector_stats
+            elif kv_connector_stats := kv_output.kv_connector_stats:
                 if aggregated_kv_connector_stats is None:
                     aggregated_kv_connector_stats = kv_connector_stats
                 else:
@@ -176,7 +176,7 @@ def update_finished_set(
                     aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                 )

-            invalid_block_ids |= output.invalid_block_ids
+            invalid_block_ids |= kv_output.invalid_block_ids

         # select output of the worker specified by output_rank
         output = outputs[output_rank]
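
The rename from output to kv_output presumably keeps each variable bound to a single type, since the same function later assigns output = outputs[output_rank], which is a different type. A toy sketch of why that matters to mypy, with simplified stand-ins for the real output classes:

# Sketch of keeping one variable name per type; classes are simplified
# stand-ins for the vLLM output types, not the real definitions.
from dataclasses import dataclass


@dataclass
class KVConnectorOutput:
    finished_sending: set[str]


@dataclass
class ModelRunnerOutput:
    kv_connector_output: KVConnectorOutput | None


def collect(outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
    finished: set[str] = set()
    for model_runner_output in outputs:
        # A distinct name keeps this variable's inferred type stable for mypy.
        kv_output = model_runner_output.kv_connector_output
        if not kv_output:
            continue
        finished |= kv_output.finished_sending
    # The name "output" is then free for the value actually returned.
    output = outputs[0]
    return output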

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 4 additions & 0 deletions
@@ -95,6 +95,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         )
         self._connector_metadata: KVConnectorMetadata | None = None
         self._vllm_config = vllm_config
+        if vllm_config.kv_transfer_config is not None:
+            self._kv_transfer_config = vllm_config.kv_transfer_config
+        else:
+            raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
         self._role = role

     @property
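
Storing the narrowed config as self._kv_transfer_config lets subclasses (e.g. the multi-connector below) use a non-Optional attribute without repeating the None check. A bare-bones sketch of that shape, with simplified stand-in classes rather than the real vLLM config objects:

# Sketch of narrowing an optional config once in a base class.
class KVTransferConfig:
    engine_id: str | None = None


class VllmConfig:
    kv_transfer_config: KVTransferConfig | None = None


class ConnectorBase:
    def __init__(self, vllm_config: VllmConfig) -> None:
        if vllm_config.kv_transfer_config is None:
            raise ValueError("kv_transfer_config must be set")
        # Typed as KVTransferConfig (not Optional), so subclasses can use it
        # without repeating the None check.
        self._kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config


class MultiConnector(ConnectorBase):
    def engine_id(self) -> str | None:
        return self._kv_transfer_config.engine_id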

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

Lines changed: 3 additions & 4 deletions
@@ -86,13 +86,11 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         super().__init__(vllm_config=vllm_config, role=role)
         self._connectors: list[KVConnectorBase_V1] = []
         self._ktc_kv_transfer_config = []
-        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "connectors"
-        )
+        ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
         assert ktcs is not None
         for ktc in ktcs:
             temp_config = copy.copy(vllm_config)
-            engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
+            engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
             temp_config.kv_transfer_config = KVTransferConfig(
                 **ktc, engine_id=engine_id
             )
@@ -296,6 +294,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
             str: the required KV cache layout. e.g. HND, or NHD.
             None if the connector does not require a specific layout.
         """
+        assert vllm_config.kv_transfer_config is not None
         ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "connectors"
         )

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 12 additions & 6 deletions
@@ -297,6 +297,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             + vllm_config.parallel_config.data_parallel_rank
             * vllm_config.parallel_config.tensor_parallel_size
         )
+        assert vllm_config.kv_transfer_config is not None
         self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
         logger.info("Initializing NIXL Scheduler %s", engine_id)

@@ -340,7 +341,8 @@ def get_num_new_matched_tokens(

         if params is not None and params.get("do_remote_prefill"):
             # Remote prefill: get all prompt blocks from remote.
-            count = len(request.prompt_token_ids) - num_computed_tokens
+            token_ids = request.prompt_token_ids or []
+            count = len(token_ids) - num_computed_tokens
             if count > 0:
                 return count, True

@@ -521,6 +523,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size

+        if vllm_config.kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set for NixlConnector")
+
         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
         )
@@ -577,17 +582,18 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.use_host_buffer = self.kv_buffer_device == "cpu"
         # support for oot platform which can't register nixl memory
         # type based on kv_buffer_device
-        self.nixl_memory_type = current_platform.get_nixl_memory_type()
-        if self.nixl_memory_type is None:
+        nixl_memory_type = current_platform.get_nixl_memory_type()
+        if nixl_memory_type is None:
             if self.kv_buffer_device == "cuda":
-                self.nixl_memory_type = "VRAM"
+                nixl_memory_type = "VRAM"
             elif self.kv_buffer_device == "cpu":
-                self.nixl_memory_type = "DRAM"
-        if self.nixl_memory_type is None:
+                nixl_memory_type = "DRAM"
+        if nixl_memory_type is None:
             raise RuntimeError(
                 f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                 "is not supported."
             )
+        self.nixl_memory_type = nixl_memory_type

         # Note: host xfer buffer ops when use_host_buffer is True
         self.copy_blocks: CopyBlocksOp | None = None
0 commit comments