2 changes: 1 addition & 1 deletion examples/offline_inference/encoder_decoder_multimodal.py
@@ -56,7 +56,7 @@ def run_florence2():
def run_mllama():
engine_args = EngineArgs(
model="meta-llama/Llama-3.2-11B-Vision-Instruct",
max_model_len=4096,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={"image": 1},
dtype="half",
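For reference, a minimal offline-inference sketch that mirrors the updated EngineArgs above; the prompt format and image asset are illustrative, and the larger max_model_len is what leaves room for the image tokens alongside the text.

# Hedged sketch, assuming vLLM's offline API as used by these examples;
# the prompt string and image asset are illustrative only.
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    max_model_len=8192,  # leaves headroom for image tokens plus text
    max_num_seqs=2,
    limit_mm_per_prompt={"image": 1},
    dtype="half",
)

prompt = "<|image|><|begin_of_text|>Describe this image."
outputs = llm.generate(
    {"prompt": prompt,
     "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)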
2 changes: 1 addition & 1 deletion examples/offline_inference/vision_language.py
@@ -556,7 +556,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
# The configuration below has been confirmed to launch on a single L40 GPU.
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_model_len=8192,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
4 changes: 2 additions & 2 deletions examples/offline_inference/vision_language_multi_image.py
@@ -318,8 +318,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
# The configuration below has been confirmed to launch on a single L40 GPU.
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=16,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
)

3 changes: 2 additions & 1 deletion tests/engine/test_short_mm_context.py
@@ -18,7 +18,8 @@
def test_context_length_too_short(vllm_runner, image_assets, model):
images = [asset.pil_image for asset in image_assets]

with pytest.raises(ValueError, match="too long to fit into the model"):
with pytest.raises(ValueError,
match="longer than the maximum model length"):
vllm_model = vllm_runner(
model,
max_model_len=128, # LLaVA has a feature size of 576
2 changes: 1 addition & 1 deletion tests/entrypoints/llm/test_prompt_validation.py
@@ -15,7 +15,7 @@ def v1(run_with_both_engines):

def test_empty_prompt():
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='Prompt cannot be empty'):
with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
llm.generate([""])


2 changes: 1 addition & 1 deletion tests/entrypoints/openai/test_prompt_validation.py
@@ -17,7 +17,7 @@ async def test_empty_prompt():
client = remote_server.get_async_client()

with pytest.raises(openai.BadRequestError,
match=re.compile('.+Prompt cannot be empty.+')):
match="decoder prompt cannot be empty"):
await client.completions.create(model=model_name,
prompt="",
max_tokens=5,
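The same check also surfaces through the OpenAI-compatible server as a 400 error. A hedged client-side sketch, assuming a vLLM server is already serving this model locally; the base URL and API key are placeholders.

# Hedged sketch; assumes an OpenAI-compatible vLLM server is already running
# at the address below (base URL, API key, and model name are illustrative).
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
try:
    client.completions.create(
        model="openai-community/gpt2",
        prompt="",  # empty prompt
        max_tokens=5,
    )
except openai.BadRequestError as exc:
    print(exc)  # message now includes "decoder prompt cannot be empty"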
8 changes: 4 additions & 4 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -211,7 +211,7 @@ def _run_test(
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
max_model_len=8192,
max_num_seqs=3,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
@@ -422,7 +422,7 @@ def test_bnb_regression(
llm = LLM(
model=model,
dtype=dtype,
max_model_len=4096,
max_model_len=8192,
max_num_seqs=2,
quantization="bitsandbytes",
)
@@ -475,7 +475,7 @@ def test_explicit_implicit_prompt(
llm = LLM(
model=model,
dtype=dtype,
max_model_len=4096,
max_model_len=8192,
max_num_seqs=2,
tensor_parallel_size=1,
)
@@ -506,7 +506,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
model,
dtype=dtype,
max_model_len=4096,
max_model_len=8192,
max_num_seqs=2,
tensor_parallel_size=1,
limit_mm_per_prompt={"image":
63 changes: 46 additions & 17 deletions vllm/engine/llm_engine.py
@@ -8,7 +8,7 @@
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
Iterable, List, Literal, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload

@@ -30,7 +30,7 @@
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType)
PromptType, SingletonInputs)
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@@ -40,6 +40,7 @@
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
@@ -2029,29 +2030,57 @@ def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else:
prompt_inputs = encoder_inputs or decoder_inputs
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")

prompt_ids = prompt_inputs["prompt_token_ids"]
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")

if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
if prompt_type == "encoder" and self.tokenizer is not None:
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
model_config = self.model_config

if self.model_config.is_multimodal_model:
max_prompt_len = self.model_config.max_model_len
if model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config, tokenizer=tokenizer)
assert isinstance(mm_processor, EncDecMultiModalProcessor)

if len(prompt_ids) > max_prompt_len:
raise ValueError(
f"The prompt (total length {len(prompt_ids)}) is too long "
f"to fit into the model (context length {max_prompt_len}). "
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper

prompt_ids = prompt_inputs["prompt_token_ids"]

if not prompt_ids:
raise ValueError(f"The {prompt_type} prompt cannot be empty")

max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) >= max_prompt_len:
if self.model_config.is_multimodal_model:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.")
else:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens.")

raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")

# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
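Stripped of the engine plumbing, the rule that the new _validate_model_input enforces for each prompt looks like the following; the function below is a simplified illustration, not the vLLM implementation, and its names are made up for clarity.

# Simplified sketch of the per-prompt validation rule above (illustrative names).
from typing import Literal

def validate_prompt(prompt_token_ids: list[int],
                    max_model_len: int,
                    *,
                    prompt_type: Literal["encoder", "decoder"],
                    is_multimodal: bool) -> None:
    # Empty prompts are rejected per prompt type.
    if not prompt_token_ids:
        raise ValueError(f"The {prompt_type} prompt cannot be empty")

    # Over-long prompts are rejected with a hint on how to fix the config.
    if len(prompt_token_ids) >= max_model_len:
        if is_multimodal:
            suggestion = ("Make sure that `max_model_len` is no smaller than "
                          "the number of text tokens plus multimodal tokens.")
        else:
            suggestion = ("Make sure that `max_model_len` is no smaller than "
                          "the number of text tokens.")
        raise ValueError(
            f"The {prompt_type} prompt (length {len(prompt_token_ids)}) is "
            f"longer than the maximum model length of {max_model_len}. "
            f"{suggestion}")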
13 changes: 6 additions & 7 deletions vllm/multimodal/profiling.py
@@ -213,8 +213,12 @@ def get_encoder_dummy_data(

total_len = len(encoder_prompt_token_ids)

# Encoder-decoder multimodal models only support v0
if total_len > seq_len:
processor = cast(EncDecMultiModalProcessor, self.processor)
if processor.pad_dummy_encoder_prompt:
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
# NOTE: Whisper allows total_len > seq_len.
elif total_len > seq_len and not envs.VLLM_USE_V1:
# `max_num_batched_tokens` is defined by `SchedulerConfig`
logger.warning_once(
"The encoder sequence length used for profiling ("
@@ -229,11 +233,6 @@
"increase `max_model_len`, reduce `max_num_seqs`, "
"and/or reduce `mm_counts`.")

processor = cast(EncDecMultiModalProcessor, self.processor)
if processor.pad_dummy_encoder_prompt:
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)

return DummyEncoderData(encoder_prompt_token_ids)

def get_decoder_dummy_data(
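Because the padding branch now runs before the length warning, processors that pad their dummy encoder prompt (such as Whisper's) are padded up to seq_len instead of triggering the warning. A standalone sketch of the padding arithmetic, with illustrative names:

# Standalone sketch of the padding rule above; names are illustrative.
def pad_dummy_encoder_prompt(encoder_prompt_token_ids: list[int],
                             seq_len: int) -> list[int]:
    total_len = len(encoder_prompt_token_ids)
    # Pad up to seq_len when the prompt is shorter; otherwise add nothing.
    num_tokens_to_pad = max(total_len, seq_len) - total_len
    return encoder_prompt_token_ids + [0] * num_tokens_to_pad

assert pad_dummy_encoder_prompt([1, 2, 3], 6) == [1, 2, 3, 0, 0, 0]
assert pad_dummy_encoder_prompt([1, 2, 3], 2) == [1, 2, 3]  # already long enough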
76 changes: 49 additions & 27 deletions vllm/v1/engine/processor.py
@@ -2,16 +2,17 @@

import time
from collections.abc import Mapping
from typing import Optional, Union
from typing import Literal, Optional, Union

from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -287,41 +288,62 @@ def _validate_model_inputs(self,
lora_request: Optional[LoRARequest] = None):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else:
prompt_inputs = encoder_inputs or decoder_inputs
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")

prompt_ids = prompt_inputs["prompt_token_ids"]
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")

if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)

max_input_id = max(prompt_ids)
max_allowed = self.tokenizer.get_lora_tokenizer(
lora_request).max_token_id
if max_input_id > max_allowed:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
if prompt_type == "encoder":
model_config = self.model_config

if len(prompt_ids) >= self.model_config.max_model_len:
raise ValueError(
f"Prompt length of {len(prompt_ids)} is longer than the "
f"maximum model length of {self.model_config.max_model_len}.")
if model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config, tokenizer=tokenizer)
assert isinstance(mm_processor, EncDecMultiModalProcessor)

if self.model_config.is_multimodal_model:
max_prompt_len = self.model_config.max_model_len
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper

if len(prompt_ids) > max_prompt_len:
raise ValueError(
f"The prompt (total length {len(prompt_ids)}) is too long "
f"to fit into the model (context length {max_prompt_len}). "
prompt_ids = prompt_inputs["prompt_token_ids"]

if not prompt_ids:
raise ValueError(f"The {prompt_type} prompt cannot be empty")

max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(f"Token id {max_input_id} is out of vocabulary")

max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) >= max_prompt_len:
if self.model_config.is_multimodal_model:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.")
else:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens.")

raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")

# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
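Finally, a hedged usage sketch of how the new decoder-prompt check surfaces from the offline API; the model name and token count are illustrative only.

# Hedged sketch; model and prompt length are illustrative.
from vllm import LLM

llm = LLM(model="openai-community/gpt2", max_model_len=128, enforce_eager=True)
try:
    # 500 prompt tokens against a 128-token context window
    llm.generate([{"prompt_token_ids": list(range(500))}])
except ValueError as exc:
    # e.g. "The decoder prompt (length 500) is longer than the maximum model
    # length of 128. Make sure that `max_model_len` is no smaller than the
    # number of text tokens."
    print(exc)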