Skip to content

Commit

Permalink
[mypy] Misc. typing improvements (#7417)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Aug 13, 2024
1 parent 198d6a2 commit 9ba85bc
Show file tree
Hide file tree
Showing 16 changed files with 74 additions and 75 deletions.
16 changes: 12 additions & 4 deletions tests/tensorizer_loader/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib
import functools
import gc
from typing import Callable, TypeVar

import pytest
import ray
import torch
from typing_extensions import ParamSpec

from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
Expand All @@ -22,20 +24,26 @@ def cleanup():
torch.cuda.empty_cache()


def retry_until_skip(n):
_P = ParamSpec("_P")
_R = TypeVar("_R")

def decorator_retry(func):

def retry_until_skip(n: int):

def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:

@functools.wraps(func)
def wrapper_retry(*args, **kwargs):
def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
for i in range(n):
try:
return func(*args, **kwargs)
except AssertionError:
gc.collect()
torch.cuda.empty_cache()
if i == n - 1:
pytest.skip("Skipping test after attempts..")
pytest.skip(f"Skipping test after {n} attempts.")

raise AssertionError("Code should not be reached")

return wrapper_retry

Expand Down
30 changes: 8 additions & 22 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import asyncio
import os
import socket
import sys
from functools import partial
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
Tuple, TypeVar)
from typing import AsyncIterator, Tuple

import pytest

Expand All @@ -13,26 +11,11 @@

from .utils import error_on_warning

if sys.version_info < (3, 10):
if TYPE_CHECKING:
_AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
_AwaitableT_co = TypeVar("_AwaitableT_co",
bound=Awaitable[Any],
covariant=True)

class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):

def __anext__(self) -> _AwaitableT_co:
...

def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT":
return i.__anext__()


@pytest.mark.asyncio
async def test_merge_async_iterators():

async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
async def mock_async_iterator(idx: int):
try:
while True:
yield f"item from iterator {idx}"
Expand All @@ -41,8 +24,10 @@ async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
print(f"iterator {idx} cancelled")

iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
*iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))
merged_iterator = merge_async_iterators(*iterators,
is_cancelled=partial(asyncio.sleep,
0,
result=False))

async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:
Expand All @@ -56,7 +41,8 @@ async def stream_output(generator: AsyncIterator[Tuple[int, str]]):

for iterator in iterators:
try:
await asyncio.wait_for(anext(iterator), 1)
# Can use anext() in python >= 3.10
await asyncio.wait_for(iterator.__anext__(), 1)
except StopAsyncIteration:
# All iterators should be cancelled and print this message.
print("Iterator was cancelled normally")
Expand Down
11 changes: 8 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import openai
import ray
import requests
from transformers import AutoTokenizer
from typing_extensions import ParamSpec

from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
Expand Down Expand Up @@ -360,13 +361,17 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
time.sleep(5)


def fork_new_process_for_each_test(f):
_P = ParamSpec("_P")


def fork_new_process_for_each_test(
f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""

@functools.wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os.setpgrp()
Expand Down
8 changes: 4 additions & 4 deletions vllm/inputs/registry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import functools
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
TypeVar)
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type

from torch import nn
from transformers import PretrainedConfig
from typing_extensions import TypeVar

from vllm.logger import init_logger

Expand All @@ -17,7 +17,7 @@

logger = init_logger(__name__)

C = TypeVar("C", bound=PretrainedConfig)
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)


@dataclass(frozen=True)
Expand All @@ -44,7 +44,7 @@ def get_multimodal_config(self) -> "MultiModalConfig":

return multimodal_config

def get_hf_config(self, hf_config_type: Type[C]) -> C:
def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
"""
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_internvl_num_patches(image_size: int, patch_size: int,


def get_max_internvl_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config

use_thumbnail = hf_config.use_thumbnail
Expand All @@ -187,7 +187,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs

model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config

image_size = vision_config.image_size
Expand Down Expand Up @@ -260,7 +260,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):

image_feature_size = get_max_internvl_image_tokens(ctx)
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers.configuration_utils import PretrainedConfig
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
Expand Down Expand Up @@ -404,7 +404,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:


def get_max_minicpmv_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
return getattr(hf_config, "query_num", 64)


Expand All @@ -420,7 +420,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):


def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()

seq_data = dummy_seq_data_for_minicpmv(seq_len)
mm_data = dummy_image_for_minicpmv(hf_config)
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
def get_max_phi3v_image_tokens(ctx: InputContext):

return get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
ctx.get_hf_config(),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
Expand Down Expand Up @@ -391,7 +391,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs

model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()

image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
Expand Down
5 changes: 2 additions & 3 deletions vllm/multimodal/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

import torch
from PIL import Image
from transformers import PreTrainedTokenizerBase

from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import is_list_of

from .base import MultiModalInputs, MultiModalPlugin
Expand Down Expand Up @@ -40,7 +39,7 @@ def repeat_and_pad_token(


def repeat_and_pad_image_tokens(
tokenizer: PreTrainedTokenizerBase,
tokenizer: AnyTokenizer,
prompt: Optional[str],
prompt_token_ids: List[int],
*,
Expand Down
10 changes: 7 additions & 3 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@

import os
from functools import lru_cache, wraps
from typing import List, Tuple
from typing import Callable, List, Tuple, TypeVar

import pynvml
from typing_extensions import ParamSpec

from vllm.logger import init_logger

from .interface import Platform, PlatformEnum

logger = init_logger(__name__)

_P = ParamSpec("_P")
_R = TypeVar("_R")

# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA


def with_nvml_context(fn):
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:

@wraps(fn)
def wrapper(*args, **kwargs):
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
Expand Down
18 changes: 8 additions & 10 deletions vllm/transformers_utils/detokenizer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Dict, List, Optional, Tuple, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing import Dict, List, Optional, Tuple

from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)

from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup

# Used eg. for marking rejected tokens in spec decoding.
INVALID_TOKEN_ID = -1
Expand All @@ -16,8 +15,7 @@ class Detokenizer:
def __init__(self, tokenizer_group: BaseTokenizerGroup):
self.tokenizer_group = tokenizer_group

def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

Expand Down Expand Up @@ -174,7 +172,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]):


def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: AnyTokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
Expand Down Expand Up @@ -213,7 +211,7 @@ def _convert_tokens_to_string_with_added_encoders(


def convert_prompt_ids_to_tokens(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: AnyTokenizer,
prompt_ids: List[int],
skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
Expand All @@ -240,7 +238,7 @@ def convert_prompt_ids_to_tokens(
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: AnyTokenizer,
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int,
Expand Down
6 changes: 3 additions & 3 deletions vllm/transformers_utils/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async

from .tokenizer_group import AnyTokenizer

logger = init_logger(__name__)

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
"""Get tokenizer with cached properties.
Expand Down Expand Up @@ -141,7 +141,7 @@ def get_tokenizer(


def get_lora_tokenizer(lora_request: LoRARequest, *args,
**kwargs) -> Optional[PreTrainedTokenizer]:
**kwargs) -> Optional[AnyTokenizer]:
if lora_request is None:
return None
try:
Expand Down
3 changes: 1 addition & 2 deletions vllm/transformers_utils/tokenizer_group/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from .tokenizer_group import TokenizerGroup

if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
from .ray_tokenizer_group import RayTokenizerGroupPool
else:
RayTokenizerGroupPool = None # type: ignore

Expand Down
Loading

0 comments on commit 9ba85bc

Please sign in to comment.