Skip to content

Commit 6d917d0

Browse files
authored
Enable mypy checking on V1 code (#11105)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 93abf23 commit 6d917d0

21 files changed

+160
-121
lines changed

tools/mypy.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,4 @@ run_mypy vllm/plugins
2929
run_mypy vllm/prompt_adapter
3030
run_mypy vllm/spec_decode
3131
run_mypy vllm/worker
32+
run_mypy vllm/v1

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def forward(
135135
assert k_scale == 1.0 and v_scale == 1.0, (
136136
"key/v_scale is not supported in FlashAttention.")
137137

138+
assert output is not None, "Output tensor must be provided."
139+
138140
if attn_metadata is None:
139141
# Profiling run.
140142
return output

vllm/v1/core/kv_cache_manager.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import defaultdict
2-
from typing import Dict, List, Optional
2+
from typing import Dict, Iterable, List, Optional
33

44
from vllm.logger import init_logger
55
from vllm.utils import cdiv
@@ -263,12 +263,13 @@ def free(self, request: Request) -> None:
263263
"""
264264
# Default to [] in case a request is freed (aborted) before alloc.
265265
blocks = self.req_to_blocks.pop(request.request_id, [])
266+
ordered_blocks: Iterable[KVCacheBlock] = blocks
266267
if self.enable_caching:
267268
# Free blocks in reverse order so that the tail blocks are
268269
# freed first.
269-
blocks = reversed(blocks)
270+
ordered_blocks = reversed(blocks)
270271

271-
for block in blocks:
272+
for block in ordered_blocks:
272273
block.decr_ref()
273274
if block.ref_cnt == 0:
274275
self.free_block_queue.append(block)
@@ -396,8 +397,7 @@ def _cache_full_blocks(
396397
f"{request.request_id}({request})")
397398

398399
# Compute the hash of the current block.
399-
block_hash = hash_block_tokens(prev_block_hash_value,
400-
tuple(block_tokens))
400+
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
401401

402402
# Update the block hash and add the full block to the cache.
403403
blk.block_hash = block_hash

vllm/v1/core/kv_cache_utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""KV-Cache Utilities."""
2+
from collections.abc import Sequence
23
from dataclasses import dataclass
34
from typing import List, NamedTuple, Optional, Tuple
45

@@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
1314
collision happens when the hash value is the same.
1415
"""
1516
hash_value: int
16-
token_ids: Tuple[int]
17+
token_ids: Tuple[int, ...]
1718

1819

1920
@dataclass
@@ -79,8 +80,8 @@ def __init__(self, blocks: List[KVCacheBlock]) -> None:
7980
self.num_free_blocks = len(blocks)
8081

8182
# Initialize the doubly linked list of free blocks.
82-
self.free_list_head = blocks[0]
83-
self.free_list_tail = blocks[-1]
83+
self.free_list_head: Optional[KVCacheBlock] = blocks[0]
84+
self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
8485
for i in range(self.num_free_blocks):
8586
if i > 0:
8687
blocks[i].prev_free_block = blocks[i - 1]
@@ -159,7 +160,7 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]:
159160

160161

161162
def hash_block_tokens(parent_block_hash: Optional[int],
162-
curr_block_token_ids: Tuple[int]) -> BlockHashType:
163+
curr_block_token_ids: Sequence[int]) -> BlockHashType:
163164
"""Computes a hash value corresponding to the contents of a block and
164165
the contents of the preceding block(s). The hash value is used for
165166
prefix caching. We use LRU cache for this function to avoid recomputing
@@ -171,19 +172,19 @@ def hash_block_tokens(parent_block_hash: Optional[int],
171172
Args:
172173
parent_block_hash: The hash of the parent block. None
173174
if this is the first block.
174-
curr_block_token_ids: A tuple of token ids in the current
175+
curr_block_token_ids: A sequence of token ids in the current
175176
block. The current block is assumed to be full.
176177
177178
Returns:
178179
The hash value of the block and the token ids in the block.
179180
The entire tuple is used as the hash key of the block.
180181
"""
181182
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
182-
curr_block_token_ids)
183+
tuple(curr_block_token_ids))
183184

184185

185186
def hash_request_tokens(block_size: int,
186-
token_ids: List[int]) -> List[BlockHashType]:
187+
token_ids: Sequence[int]) -> List[BlockHashType]:
187188
"""Computes hash values of a chain of blocks given a sequence of
188189
token IDs. The hash value is used for prefix caching.
189190
@@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
198199
parent_block_hash_value = None
199200
for start in range(0, len(token_ids), block_size):
200201
end = start + block_size
201-
block_token_ids = tuple(token_ids[start:end])
202+
block_token_ids = token_ids[start:end]
202203
# Do not hash the block if it is not full.
203204
if len(block_token_ids) < block_size:
204205
break

vllm/v1/core/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def schedule(self) -> "SchedulerOutput":
152152
break
153153
if not can_schedule:
154154
break
155+
assert new_blocks is not None
155156

156157
# Schedule the request.
157158
scheduled_running_reqs.append(request)

vllm/v1/engine/__init__.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,19 @@ class EngineCoreRequest:
3636
prompt: Optional[str]
3737
prompt_token_ids: List[int]
3838
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
39-
mm_hashes: Optional[List[Optional[str]]]
39+
mm_hashes: Optional[List[str]]
4040
mm_placeholders: Optional[MultiModalPlaceholderDict]
4141
sampling_params: SamplingParams
4242
eos_token_id: Optional[int]
4343
arrival_time: float
4444
lora_request: Optional[LoRARequest]
4545

4646

47-
class EngineCoreOutput(msgspec.Struct,
48-
array_like=True,
49-
omit_defaults=True,
50-
gc=False):
47+
class EngineCoreOutput(
48+
msgspec.Struct,
49+
array_like=True, # type: ignore[call-arg]
50+
omit_defaults=True, # type: ignore[call-arg]
51+
gc=False): # type: ignore[call-arg]
5152

5253
request_id: str
5354
new_token_ids: List[int]
@@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
5657
stop_reason: Union[int, str, None] = None
5758

5859

59-
class EngineCoreOutputs(msgspec.Struct,
60-
array_like=True,
61-
omit_defaults=True,
62-
gc=False):
60+
class EngineCoreOutputs(
61+
msgspec.Struct,
62+
array_like=True, # type: ignore[call-arg]
63+
omit_defaults=True, # type: ignore[call-arg]
64+
gc=False): # type: ignore[call-arg]
6365

6466
#NOTE(Nick): We could consider ways to make this more compact,
6567
# e.g. columnwise layout and using an int enum for finish/stop reason
@@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
8183
ADD = b'\x00'
8284
ABORT = b'\x01'
8385
PROFILE = b'\x02'
86+
87+
88+
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]

vllm/v1/engine/async_llm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
asyncio_mode=True,
8282
)
8383

84-
self.output_handler = None
84+
self.output_handler: Optional[asyncio.Task] = None
8585

8686
def __del__(self):
8787
self.shutdown()
@@ -126,7 +126,8 @@ def shutdown(self):
126126
handler.cancel()
127127

128128
@classmethod
129-
def _get_executor_cls(cls, vllm_config: VllmConfig):
129+
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
130+
executor_class: Type[Executor]
130131
distributed_executor_backend = (
131132
vllm_config.parallel_config.distributed_executor_backend)
132133
if distributed_executor_backend == "mp":
@@ -361,10 +362,10 @@ async def check_health(self) -> None:
361362
logger.debug("Called check_health.")
362363

363364
async def start_profile(self) -> None:
364-
await self.engine_core.profile(True)
365+
await self.engine_core.profile_async(True)
365366

366367
async def stop_profile(self) -> None:
367-
await self.engine_core.profile(False)
368+
await self.engine_core.profile_async(False)
368369

369370
@property
370371
def is_running(self) -> bool:
@@ -380,7 +381,7 @@ def errored(self) -> bool:
380381

381382
@property
382383
def dead_error(self) -> BaseException:
383-
return Exception
384+
return Exception() # TODO: implement
384385

385386

386387
# Retain V0 name for backwards compatibility.

vllm/v1/engine/core.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from dataclasses import dataclass
77
from multiprocessing.process import BaseProcess
8-
from typing import List, Tuple, Type, Union
8+
from typing import List, Tuple, Type
99

1010
import zmq
1111
import zmq.asyncio
@@ -20,7 +20,7 @@
2020
from vllm.v1.core.scheduler import Scheduler
2121
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
2222
EngineCoreProfile, EngineCoreRequest,
23-
EngineCoreRequestType)
23+
EngineCoreRequestType, EngineCoreRequestUnion)
2424
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
2525
from vllm.v1.executor.abstract import Executor
2626
from vllm.v1.request import Request, RequestStatus
@@ -97,8 +97,10 @@ def add_request(self, request: EngineCoreRequest):
9797
# Note that the cache here is mirrored with the client side of the
9898
# MM mapper, so anything that has a hash must have a HIT cache
9999
# entry here as well.
100-
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
101-
request.mm_inputs, request.mm_hashes)
100+
assert request.mm_inputs is not None
101+
request.mm_inputs, request.mm_hashes = (
102+
self.mm_input_mapper_server.process_inputs(
103+
request.mm_inputs, request.mm_hashes))
102104

103105
req = Request.from_engine_core_request(request)
104106

@@ -128,7 +130,7 @@ def step(self) -> List[EngineCoreOutput]:
128130
def shutdown(self):
129131
self.model_executor.shutdown()
130132

131-
def profile(self, is_start=True):
133+
def profile(self, is_start: bool = True):
132134
self.model_executor.profile(is_start)
133135

134136

@@ -161,8 +163,8 @@ def __init__(
161163
# and to overlap some serialization/deserialization with the
162164
# model forward pass.
163165
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
164-
self.input_queue = queue.Queue()
165-
self.output_queue = queue.Queue()
166+
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
167+
self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
166168
threading.Thread(target=self.process_input_socket,
167169
args=(input_path, ),
168170
daemon=True).start()
@@ -318,9 +320,7 @@ def _log_stats(self):
318320

319321
self._last_logging_time = now
320322

321-
def _handle_client_request(
322-
self, request: Union[EngineCoreRequest, EngineCoreProfile,
323-
List[str]]) -> None:
323+
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
324324
"""Handle an EngineCoreRequest, EngineCoreProfile, or abort request (list of request IDs) from the client."""
325325

326326
if isinstance(request, EngineCoreRequest):

vllm/v1/engine/core_client.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import atexit
22
import os
3-
from typing import List, Union
3+
from typing import List, Optional
44

55
import msgspec
66
import zmq
@@ -10,8 +10,9 @@
1010
from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
1111
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
1212
EngineCoreProfile, EngineCoreRequest,
13-
EngineCoreRequestType)
14-
from vllm.v1.engine.core import EngineCore, EngineCoreProc
13+
EngineCoreRequestType, EngineCoreRequestUnion)
14+
from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
15+
EngineCoreProcHandle)
1516
from vllm.v1.serial_utils import PickleEncoder
1617

1718
logger = init_logger(__name__)
@@ -59,7 +60,7 @@ def get_output(self) -> List[EngineCoreOutput]:
5960
def add_request(self, request: EngineCoreRequest) -> None:
6061
raise NotImplementedError
6162

62-
async def profile(self, is_start=True) -> None:
63+
def profile(self, is_start: bool = True) -> None:
6364
raise NotImplementedError
6465

6566
def abort_requests(self, request_ids: List[str]) -> None:
@@ -71,6 +72,9 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
7172
async def add_request_async(self, request: EngineCoreRequest) -> None:
7273
raise NotImplementedError
7374

75+
async def profile_async(self, is_start: bool = True) -> None:
76+
raise NotImplementedError
77+
7478
async def abort_requests_async(self, request_ids: List[str]) -> None:
7579
raise NotImplementedError
7680

@@ -105,7 +109,7 @@ def shutdown(self):
105109
def __del__(self):
106110
self.shutdown()
107111

108-
def profile(self, is_start=True) -> None:
112+
def profile(self, is_start: bool = True) -> None:
109113
self.engine_core.profile(is_start)
110114

111115

@@ -133,7 +137,10 @@ def __init__(
133137
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
134138

135139
# ZMQ setup.
136-
self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())
140+
if asyncio_mode:
141+
self.ctx = zmq.asyncio.Context()
142+
else:
143+
self.ctx = zmq.Context() # type: ignore[attr-defined]
137144

138145
# Path for IPC.
139146
ready_path = get_open_zmq_ipc_path()
@@ -149,11 +156,13 @@ def __init__(
149156
self.input_socket.bind(input_path)
150157

151158
# Start EngineCore in background process.
159+
self.proc_handle: Optional[EngineCoreProcHandle]
152160
self.proc_handle = EngineCoreProc.make_engine_core_process(
153161
*args,
154-
input_path=input_path,
155-
output_path=output_path,
156-
ready_path=ready_path,
162+
input_path=
163+
input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
164+
output_path=output_path, # type: ignore[misc]
165+
ready_path=ready_path, # type: ignore[misc]
157166
**kwargs,
158167
)
159168
atexit.register(self.shutdown)
@@ -204,10 +213,8 @@ def get_output(self) -> List[EngineCoreOutput]:
204213
engine_core_outputs = self.decoder.decode(frame.buffer).outputs
205214
return engine_core_outputs
206215

207-
def _send_input(
208-
self, request_type: EngineCoreRequestType,
209-
request: Union[EngineCoreRequest, EngineCoreProfile,
210-
List[str]]) -> None:
216+
def _send_input(self, request_type: EngineCoreRequestType,
217+
request: EngineCoreRequestUnion) -> None:
211218

212219
# (RequestType, SerializedRequest)
213220
msg = (request_type.value, self.encoder.encode(request))
@@ -219,7 +226,7 @@ def add_request(self, request: EngineCoreRequest) -> None:
219226
def abort_requests(self, request_ids: List[str]) -> None:
220227
self._send_input(EngineCoreRequestType.ABORT, request_ids)
221228

222-
def profile(self, is_start=True) -> None:
229+
def profile(self, is_start: bool = True) -> None:
223230
self._send_input(EngineCoreRequestType.PROFILE,
224231
EngineCoreProfile(is_start))
225232

@@ -237,10 +244,8 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
237244

238245
return engine_core_outputs
239246

240-
async def _send_input(
241-
self, request_type: EngineCoreRequestType,
242-
request: Union[EngineCoreRequest, EngineCoreProfile,
243-
List[str]]) -> None:
247+
async def _send_input(self, request_type: EngineCoreRequestType,
248+
request: EngineCoreRequestUnion) -> None:
244249

245250
msg = (request_type.value, self.encoder.encode(request))
246251
await self.input_socket.send_multipart(msg, copy=False)
@@ -252,6 +257,6 @@ async def abort_requests_async(self, request_ids: List[str]) -> None:
252257
if len(request_ids) > 0:
253258
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
254259

255-
async def profile(self, is_start=True) -> None:
260+
async def profile_async(self, is_start: bool = True) -> None:
256261
await self._send_input(EngineCoreRequestType.PROFILE,
257262
EngineCoreProfile(is_start))

0 commit comments

Comments
 (0)