2 changes: 1 addition & 1 deletion tests/engine/test_arg_utils.py
@@ -10,7 +10,7 @@


@pytest.mark.parametrize(("arg", "expected"), [
-    (None, None),
+    (None, dict()),
    ("image=16", {
        "image": 16
    }),
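Judging from the updated parametrization, the parser under test now normalizes a missing value to an empty dict instead of passing None through, so downstream code can look up per-modality limits without a None check. A minimal sketch of that contract (parse_limits is a hypothetical stand-in, not the actual function under test):

from typing import Optional


def parse_limits(val: Optional[str]) -> dict[str, int]:
    # Hypothetical sketch: None normalizes to {} rather than None.
    if val is None:
        return dict()
    key, _, value = val.partition("=")
    return {key: int(value)}


assert parse_limits(None) == dict()
assert parse_limits("image=16") == {"image": 16}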
4 changes: 3 additions & 1 deletion tests/entrypoints/openai/test_audio.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

+import json
+
import openai
import pytest
import pytest_asyncio
@@ -27,7 +29,7 @@ def server():
        "--enforce-eager",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
-        str({"audio": MAXIMUM_AUDIOS}),
+        json.dumps({"audio": MAXIMUM_AUDIOS}),
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
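The same str(...) to json.dumps(...) substitution recurs in the remaining test files below. The distinction matters because str() on a dict yields a Python repr with single quotes, which json.loads rejects, while json.dumps emits valid double-quoted JSON. A quick standalone illustration, plain Python and independent of the vLLM test harness:

import json

limits = {"audio": 2}

# str() gives a Python repr with single quotes -- not valid JSON.
print(str(limits))         # {'audio': 2}

# json.dumps() gives double-quoted, JSON-parseable output.
print(json.dumps(limits))  # {"audio": 2}

json.loads(json.dumps(limits))  # round-trips fine
try:
    json.loads(str(limits))
except json.JSONDecodeError as exc:
    print(f"str() output is not JSON: {exc}")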
4 changes: 3 additions & 1 deletion tests/entrypoints/openai/test_video.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

+import json
+
import openai
import pytest
import pytest_asyncio
@@ -31,7 +33,7 @@ def server():
        "--enforce-eager",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
-        str({"video": MAXIMUM_VIDEOS}),
+        json.dumps({"video": MAXIMUM_VIDEOS}),
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
4 changes: 3 additions & 1 deletion tests/entrypoints/openai/test_vision.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

+import json
+
import openai
import pytest
import pytest_asyncio
@@ -35,7 +37,7 @@ def server():
        "--enforce-eager",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
-        str({"image": MAXIMUM_IMAGES}),
+        json.dumps({"image": MAXIMUM_IMAGES}),
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
4 changes: 3 additions & 1 deletion tests/entrypoints/openai/test_vision_embedding.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

+import json
+
import pytest
import requests
from PIL import Image
@@ -37,7 +39,7 @@ def server():
        "--enforce-eager",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
-        str({"image": MAXIMUM_IMAGES}),
+        json.dumps({"image": MAXIMUM_IMAGES}),
        "--chat-template",
        str(vlm2vec_jinja_path),
    ]
3 changes: 2 additions & 1 deletion tests/models/decoder_only/audio_language/test_ultravox.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

+import json
from typing import Optional

import numpy as np
@@ -50,7 +51,7 @@ def server(request, audio_assets):
    args = [
        "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
        "--limit-mm-per-prompt",
-        str({"audio": len(audio_assets)}), "--trust-remote-code"
+        json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
    ] + [
        f"--{key.replace('_','-')}={value}"
        for key, value in request.param.items()
7 changes: 3 additions & 4 deletions vllm/config.py
@@ -10,7 +10,6 @@
import textwrap
import warnings
from collections import Counter
-from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
                         replace)
@@ -355,7 +354,7 @@ def __init__(
        disable_cascade_attn: bool = False,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, list[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+        limit_mm_per_prompt: Optional[dict[str, int]] = None,
        use_async_output_proc: bool = True,
        config_format: ConfigFormat = ConfigFormat.AUTO,
        hf_token: Optional[Union[bool, str]] = None,
@@ -578,7 +577,7 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str,
            self.tokenizer = s3_tokenizer.dir

    def _init_multimodal_config(
-        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
+        self, limit_mm_per_prompt: Optional[dict[str, int]]
    ) -> Optional["MultiModalConfig"]:
        if self.registry.is_multimodal_model(self.architectures):
            return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
@@ -2730,7 +2729,7 @@ def verify_with_model_config(self, model_config: ModelConfig):
class MultiModalConfig:
    """Controls the behavior of multimodal models."""

-    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
+    limit_per_prompt: dict[str, int] = field(default_factory=dict)
    """
    The maximum number of input items allowed per prompt for each modality.
    This should be a JSON string that will be parsed into a dictionary.
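Narrowing Mapping[str, int] to dict[str, int] matches what the code actually builds: both limit_mm_per_prompt or {} and field(default_factory=dict) produce plain dicts. The default_factory pattern kept in the last hunk is also what gives each config instance its own empty dict; a minimal standalone sketch (MMConfigSketch is a stand-in, not the real MultiModalConfig):

from dataclasses import dataclass, field


@dataclass
class MMConfigSketch:
    # default_factory=dict builds a fresh dict per instance; a shared
    # mutable default would be the classic dataclass pitfall.
    limit_per_prompt: dict[str, int] = field(default_factory=dict)


a, b = MMConfigSketch(), MMConfigSketch()
a.limit_per_prompt["image"] = 4
assert b.limit_per_prompt == {}  # the default is not shared across instances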
18 changes: 9 additions & 9 deletions vllm/engine/arg_utils.py
@@ -7,7 +7,7 @@
import re
import threading
from dataclasses import MISSING, dataclass, fields
-from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping,
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal,
                    Optional, Tuple, Type, TypeVar, Union, cast, get_args,
                    get_origin)
@@ -112,14 +112,14 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:


def optional_dict(val: str) -> Optional[dict[str, int]]:
-    try:
+    if re.match("^{.*}$", val):
        return optional_arg(val, json.loads)
-    except ValueError:
-        logger.warning(
-            "Failed to parse JSON string. Attempting to parse as "
-            "comma-separated key=value pairs. This will be deprecated in a "
-            "future release.")
-        return nullable_kvs(val)
+
+    logger.warning(
+        "Failed to parse JSON string. Attempting to parse as "
+        "comma-separated key=value pairs. This will be deprecated in a "
+        "future release.")
+    return nullable_kvs(val)

@dataclass
@@ -191,7 +191,7 @@ class EngineArgs:
        TokenizerPoolConfig.pool_type
    tokenizer_pool_extra_config: dict[str, Any] = \
        get_field(TokenizerPoolConfig, "extra_config")
-    limit_mm_per_prompt: Mapping[str, int] = \
+    limit_mm_per_prompt: dict[str, int] = \
        get_field(MultiModalConfig, "limit_per_prompt")
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
    disable_mm_preprocessor_cache: bool = False
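The reworked optional_dict dispatches on input shape instead of catching exceptions: a value that looks like a JSON object (matching ^{.*}$) goes through json.loads, and anything else falls through to the legacy comma-separated key=value parser with a deprecation warning. A self-contained sketch of that control flow, using simplified stand-ins for optional_arg and nullable_kvs:

import json
import re
import warnings
from typing import Optional


def nullable_kvs_sketch(val: str) -> Optional[dict[str, int]]:
    # Simplified stand-in for the legacy comma-separated key=value parser.
    out: dict[str, int] = {}
    for pair in val.split(","):
        key, _, value = pair.partition("=")
        out[key.strip()] = int(value)
    return out


def optional_dict_sketch(val: str) -> Optional[dict[str, int]]:
    if re.match("^{.*}$", val):  # looks like a JSON object
        return json.loads(val)

    warnings.warn("comma-separated key=value pairs are deprecated; "
                  "pass a JSON string such as '{\"image\": 5}' instead")
    return nullable_kvs_sketch(val)


assert optional_dict_sketch('{"image": 5}') == {"image": 5}  # JSON path
assert optional_dict_sketch("image=5,video=2") == {          # legacy path
    "image": 5,
    "video": 2,
}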