Skip to content

Commit

Permalink
[VLM] Report multi_modal_placeholders in output (vllm-project#10407)
Browse files Browse the repository at this point in the history
Signed-off-by: Linkun Chen <lkchen+anyscale@github.com>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
  • Loading branch information
lk-chen authored and mfournioux committed Nov 20, 2024
1 parent ca3c089 commit dad15c3
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 10 deletions.
79 changes: 78 additions & 1 deletion tests/models/decoder_only/vision_language/test_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import pytest
from mistral_common.multimodal import download_image
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor

from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams,
TextPrompt, TokensPrompt)
from vllm.multimodal import MultiModalDataBuiltins
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sequence import Logprob, SampleLogprobs

from ....utils import VLLM_PATH, large_gpu_test
Expand Down Expand Up @@ -49,6 +53,20 @@ def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
}]


def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]:
return [{
"role":
"user",
"content": [{
"type": "text",
"content": PROMPT,
}, *({
"type": "image",
"image": download_image(url)
} for url in urls)],
}]


def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
msg = _create_msg_format(urls)

Expand All @@ -70,6 +88,23 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
return engine_inputs


def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt:
msg = _create_msg_format_hf(urls)

tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")
prompt = tokenizer.apply_chat_template(msg)

images = []
for chunk in msg[0]["content"]:
if chunk["type"] == "image":
images.append(chunk["image"])

mm_data = MultiModalDataBuiltins(image=images)
engine_inputs = TextPrompt(prompt=prompt, multi_modal_data=mm_data)

return engine_inputs


MSGS = [
_create_msg_format(IMG_URLS[:1]),
_create_msg_format(IMG_URLS[:2]),
Expand Down Expand Up @@ -191,3 +226,45 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
outputs_1_lst=logprobs,
name_0="h100_ref",
name_1="output")


@large_gpu_test(min_gb=24)
@pytest.mark.parametrize(
"prompt,expected_ranges",
[(_create_engine_inputs_hf(IMG_URLS[:1]), [{
"offset": 10,
"length": 494
}]),
(_create_engine_inputs_hf(IMG_URLS[1:4]), [{
"offset": 10,
"length": 266
}, {
"offset": 276,
"length": 1056
}, {
"offset": 1332,
"length": 418
}])])
def test_multi_modal_placeholders(
vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None:
with vllm_runner(
"mistral-community/pixtral-12b",
max_model_len=8192,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
) as vllm_model:
outputs = vllm_model.model.generate(prompt)

assert len(outputs) == 1, f"{len(outputs)=}"
output: RequestOutput = outputs[0]
assert hasattr(output,
"multi_modal_placeholders"), f"{output.__dict__=}"
assert "image" in output.multi_modal_placeholders, \
f"{output.multi_modal_placeholders.keys()=}"
image_placeholder_ranges: list[
PlaceholderRange] = output.multi_modal_placeholders["image"]
assert len(image_placeholder_ranges) == len(
expected_ranges), f"{image_placeholder_ranges=}"
for real_range, expected_range in zip(image_placeholder_ranges,
expected_ranges):
assert real_range == expected_range, \
f"{real_range=} {expected_range=}"
16 changes: 15 additions & 1 deletion vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import IntermediateTensors, SequenceData
Expand Down Expand Up @@ -773,15 +774,28 @@ def input_processor_for_pixtral_hf(
replace_tokens[-1] = image_end_id
replace_tokens_list.append(replace_tokens)

reverse_offsets: List[int] = []
# Backward iteration for replacement without affecting known indices
for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
reversed(replace_tokens_list)):
reverse_offsets.append(
len(new_token_ids) - placeholder_idx + len(replace_tokens))
new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens

placeholder_ranges: List[PlaceholderRange] = []
for reverse_offset, replace_tokens in zip(reversed(reverse_offsets),
replace_tokens_list):
placeholder_ranges.append(
PlaceholderRange(
offset=len(new_token_ids) - reverse_offset,
length=len(replace_tokens),
))

# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})


class PixtralHFMLP(nn.Module):
Expand Down
30 changes: 22 additions & 8 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Union

from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceGroupBase, SequenceStatus)
Expand Down Expand Up @@ -103,10 +104,13 @@ def __init__(
encoder_prompt: Optional[str] = None,
encoder_prompt_token_ids: Optional[List[int]] = None,
num_cached_tokens: Optional[int] = None,
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.multi_modal_placeholders = multi_modal_placeholders or {}
self.prompt_logprobs = prompt_logprobs
self.outputs = outputs
self.finished = finished
Expand Down Expand Up @@ -275,17 +279,26 @@ def from_seq_group(
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)

init_args = (seq_group.request_id, prompt, prompt_token_ids,
prompt_logprobs, outputs, finished, seq_group.metrics,
seq_group.lora_request, encoder_prompt,
encoder_prompt_token_ids, num_cached_tokens)
init_kwargs = {
"request_id": seq_group.request_id,
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
"prompt_logprobs": prompt_logprobs,
"outputs": outputs,
"finished": finished,
"metrics": seq_group.metrics,
"lora_request": seq_group.lora_request,
"encoder_prompt": encoder_prompt,
"encoder_prompt_token_ids": encoder_prompt_token_ids,
"num_cached_tokens": num_cached_tokens,
"multi_modal_placeholders": seq_group.multi_modal_placeholders
}

if use_cache:
request_output = seq_group.cached_request_output
request_output.__init__(*init_args) # type: ignore

request_output.__init__(**init_kwargs) # type: ignore
else:
request_output = cls(*init_args)
request_output = cls(**init_kwargs) # type: ignore

return request_output

Expand All @@ -300,7 +313,8 @@ def __repr__(self) -> str:
f"finished={self.finished}, "
f"metrics={self.metrics}, "
f"lora_request={self.lora_request}, "
f"num_cached_tokens={self.num_cached_tokens})")
f"num_cached_tokens={self.num_cached_tokens}, "
f"multi_modal_placeholders={self.multi_modal_placeholders})")


class EmbeddingRequestOutput:
Expand Down

0 comments on commit dad15c3

Please sign in to comment.