[mypy] Add mypy type annotation part 1 #4006

Merged: 11 commits, Apr 12, 2024
31 changes: 31 additions & 0 deletions .github/workflows/mypy.yaml
@@ -0,0 +1,31 @@
name: mypy

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  ruff:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1 isort==5.13.2 mypy==1.9.0
    - name: Mypy
      run: |
        mypy vllm/core/
9 changes: 6 additions & 3 deletions format.sh
@@ -93,9 +93,12 @@ fi
echo 'vLLM yapf: Done'

# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
echo 'vLLM mypy:'
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml

CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**'
8 changes: 6 additions & 2 deletions pyproject.toml
@@ -43,13 +43,17 @@ ignore = [
]

[tool.mypy]
python_version = "3.8"
python_version = "3.9"

ignore_missing_imports = true
check_untyped_defs = true

files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
"vllm/worker/.*"
]


[tool.codespell]
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -7,7 +7,7 @@ codespell==2.2.6
isort==5.13.2

# type checking
mypy==0.991
mypy==1.9.0
types-PyYAML
types-requests
types-setuptools
11 changes: 6 additions & 5 deletions vllm/core/block_manager_v1.py
@@ -231,10 +231,10 @@ def __init__(

if self.enable_caching:
logger.info("Automatic prefix caching is enabled.")
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
num_gpu_blocks)
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
else:
self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
@@ -588,7 +588,8 @@ def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
for b in takewhile(lambda b: b.computed, block_table[:-1])
]

def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(self,
seqs: List[Sequence]) -> Sequence[int]:
"""Return the block ids that are common for a given sequence group.

Used in prefill (can skip prefill of some blocks).
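Note on the allocator change above: mypy pins an attribute to the type of its first assignment, so when the if/else branches assign different allocator subclasses, the attribute needs an explicit annotation with the shared base class. A minimal sketch of the pattern (the class bodies here are illustrative stand-ins, not the vLLM implementations):

class BlockAllocatorBase:
    def allocate(self) -> int:
        raise NotImplementedError

class CachedBlockAllocator(BlockAllocatorBase):
    def allocate(self) -> int:
        return 0

class UncachedBlockAllocator(BlockAllocatorBase):
    def allocate(self) -> int:
        return 1

class BlockSpaceManager:
    def __init__(self, enable_caching: bool) -> None:
        if enable_caching:
            # Annotate with the base class; otherwise mypy infers
            # CachedBlockAllocator here and rejects the else branch with
            # an incompatible-assignment error.
            self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator()
        else:
            self.gpu_allocator = UncachedBlockAllocator()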
25 changes: 15 additions & 10 deletions vllm/core/scheduler.py
@@ -42,8 +42,8 @@ class SchedulingBudget:
"""
token_budget: int
max_num_seqs: int
_requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
_requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0

@@ -133,7 +133,7 @@ def is_empty(self) -> bool:
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)

def _sort_by_lora_ids(self) -> bool:
def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@@ -337,7 +337,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
self.free_seq(seq)

def has_unfinished_seqs(self) -> bool:
return self.waiting or self.running or self.swapped
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0

def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -404,7 +405,7 @@ def _schedule_running(
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.pop(seq_group.lora_int_id)
curr_loras.remove(seq_group.lora_int_id)

if running_queue:
# Preempt the lowest-priority sequence groups.
@@ -496,7 +497,7 @@ def _schedule_swapped(
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)

leftover_swapped = deque()
leftover_swapped: deque = deque()
while swapped_queue:
seq_group = swapped_queue[0]

@@ -507,7 +508,9 @@
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if (lora_int_id > 0 and lora_int_id not in curr_loras
assert curr_loras is not None
assert self.lora_config is not None
if (lora_int_id > 0 and (lora_int_id not in curr_loras)
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
@@ -593,7 +596,7 @@ def _schedule_prefills(
# Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue])

leftover_waiting_sequences = deque()
leftover_waiting_sequences: deque = deque()
while self._passed_delay(time.time()) and waiting_queue:
seq_group = waiting_queue[0]

@@ -635,6 +638,8 @@
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
assert curr_loras is not None
assert self.lora_config is not None
if (self.lora_enabled and lora_int_id > 0
and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
@@ -780,7 +785,7 @@ def _schedule_chunked_prefill(self):
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
curr_loras = set()
curr_loras: Set[int] = set()

remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
@@ -1087,7 +1092,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int:

def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool,
budget: SchedulingBudget) -> Tuple[int, bool]:
budget: SchedulingBudget) -> int:
"""Get the next new tokens to compute for a given sequence group
that's in a given `status`.

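Note on the scheduler changes above: a few mypy patterns recur here. Empty containers such as deque() and set() need an explicit annotation because mypy cannot infer an element type from an empty literal; Optional attributes like lora_config are narrowed with an assert before being dereferenced; and curr_loras is a set, so members are dropped with remove() rather than dict-style pop(key). A minimal sketch of all three, using simplified stand-in names:

from collections import deque
from typing import Deque, Optional, Set

def schedule_step(max_loras: Optional[int], lora_int_id: int) -> None:
    # Without these annotations mypy reports "Need type annotation"
    # for the empty containers.
    leftover_swapped: Deque[str] = deque()
    curr_loras: Set[int] = set()

    curr_loras.add(lora_int_id)

    # assert narrows Optional[int] to int for the comparison below.
    assert max_loras is not None
    if len(curr_loras) >= max_loras:
        leftover_swapped.append("request-0")

    # set.pop() takes no argument; remove(x) is the set API for
    # dropping a specific member.
    curr_loras.remove(lora_int_id)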
10 changes: 5 additions & 5 deletions vllm/distributed/communication_op.py
@@ -1,5 +1,5 @@
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed import ProcessGroup
@@ -144,7 +144,7 @@ def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
@@ -157,10 +157,10 @@ def broadcast_tensor_dict(

rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
assert value.is_cuda, (
@@ -190,10 +190,10 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=group)
metadata_list = recv_metadata_list[0]
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
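Note on the broadcast changes above: declaring metadata_list with its element type before the append loop gives mypy something concrete to check, and on the receiving side the code asserts that the broadcast slot was filled instead of rebinding metadata_list to a value of a different inferred type. A rough sketch of the same idea with simplified, illustrative signatures (not the vLLM API):

from typing import Any, Dict, List, Optional, Tuple

def pack_metadata(tensor_dict: Dict[str, Any]) -> List[Tuple[str, Any]]:
    # Pre-declare the element type so mypy can verify every append.
    metadata_list: List[Tuple[str, Any]] = []
    for key, value in tensor_dict.items():
        metadata_list.append((key, value))
    return metadata_list

def unpack_metadata(
        recv_metadata_list: List[Optional[List[Tuple[str, Any]]]]
) -> Dict[str, Any]:
    # The broadcast fills recv_metadata_list[0] in place; narrow the
    # Optional before iterating over it.
    received = recv_metadata_list[0]
    assert received is not None
    return {key: value for key, value in received}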
1 change: 1 addition & 0 deletions vllm/entrypoints/api_server.py
@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)

# Streaming case
1 change: 1 addition & 0 deletions vllm/entrypoints/llm.py
@@ -170,6 +170,7 @@ def generate(
multi_modal_data.data = multi_modal_data.data.to(torch.float16)

# Add requests to the engine.
assert prompt_token_ids is not None
num_requests = len(prompts) if prompts is not None else len(
prompt_token_ids)
for i in range(num_requests):
5 changes: 3 additions & 2 deletions vllm/executor/ray_gpu_executor.py
@@ -276,7 +276,7 @@ def _run_workers(
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_args: Optional[tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
@@ -291,6 +291,7 @@ def _run_workers(
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
@@ -369,7 +370,7 @@ async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_args: Optional[tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
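Note on the driver_args change above: positional arguments collected through *args arrive as a tuple, so a parameter meant to default to them type-checks more cleanly as Optional[Tuple[Any, ...]] than as Optional[List[Any]]. A minimal sketch with a simplified signature (illustrative only):

from typing import Any, Optional, Tuple

def run_workers(method: str,
                *args: Any,
                driver_args: Optional[Tuple[Any, ...]] = None) -> None:
    if driver_args is None:
        # args already has type Tuple[Any, ...], so this assignment
        # satisfies mypy without a conversion.
        driver_args = args
    print(method, *driver_args)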
3 changes: 1 addition & 2 deletions vllm/sampling_params.py
@@ -5,7 +5,6 @@
from typing import Callable, List, Optional, Union

import torch
from pydantic import conint

_SAMPLING_EPS = 1e-5

@@ -127,7 +126,7 @@ def __init__(
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None,
truncate_prompt_tokens: Optional[int] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
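Note on dropping conint above: mypy rejects a call expression such as conint(ge=1) inside a type annotation, so the parameter is typed as a plain Optional[int]. If the ge=1 bound is still wanted, it has to be enforced at runtime; a hedged sketch of one way to do that (the check below is illustrative, not necessarily what vLLM does):

from typing import Optional

class SamplingParams:
    def __init__(self, truncate_prompt_tokens: Optional[int] = None) -> None:
        # Replaces the pydantic conint(ge=1) constraint with an explicit
        # runtime check that mypy can type-check.
        if truncate_prompt_tokens is not None and truncate_prompt_tokens < 1:
            raise ValueError("truncate_prompt_tokens must be >= 1, got "
                             f"{truncate_prompt_tokens}")
        self.truncate_prompt_tokens = truncate_prompt_tokens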