Merged
31 changes: 27 additions & 4 deletions tests/entrypoints/openai/test_vision.py
@@ -25,6 +25,25 @@
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]

EXPECTED_MM_BEAM_SEARCH_RES = [
[
"The image shows a wooden boardwalk leading through a",
"The image shows a wooden boardwalk extending into a",
],
[
"The image shows two parrots perched on",
"The image shows two birds perched on a cur",
],
[
"The image shows a Venn diagram with three over",
"This image shows a Venn diagram with three over",
],
[
"This image displays a gradient of colors ranging from",
"This image displays a gradient of colors transitioning from",
],
]


@pytest.fixture(scope="module")
def server():
@@ -270,10 +289,13 @@ async def test_single_chat_session_image_base64encoded(

@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS))))
async def test_single_chat_session_image_base64encoded_beamsearch(
client: openai.AsyncOpenAI, model_name: str, image_url: str,
client: openai.AsyncOpenAI, model_name: str, image_idx: int,
base64_encoded_image: dict[str, str]):
# NOTE: This test also validates that we pass MM data through beam search
image_url = TEST_IMAGE_URLS[image_idx]
expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]

messages = [{
"role":
@@ -297,10 +319,11 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
messages=messages,
n=2,
max_completion_tokens=10,
temperature=0.0,
extra_body=dict(use_beam_search=True))
assert len(chat_completion.choices) == 2
assert chat_completion.choices[
0].message.content != chat_completion.choices[1].message.content
for actual, expected_str in zip(chat_completion.choices, expected_res):
assert actual.message.content == expected_str


@pytest.mark.asyncio
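The updated test pins `temperature=0.0` and compares each returned beam against the per-image prefixes in `EXPECTED_MM_BEAM_SEARCH_RES`, instead of only asserting that the two beams differ. Below is a minimal standalone sketch of the same request pattern against a vLLM OpenAI-compatible server; the server URL and model id are placeholders (the test parametrizes `MODEL_NAME`), and the image URL is one of the `TEST_IMAGE_URLS` above.

```python
# Minimal sketch, not part of the diff: a beam-search chat completion with an
# image, mirroring what test_single_chat_session_image_base64encoded_beamsearch
# exercises. Assumes a vLLM OpenAI-compatible server is already running.
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    messages = [{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }]
    chat_completion = await client.chat.completions.create(
        model="<your-vision-model>",   # placeholder; the test uses MODEL_NAME
        messages=messages,
        n=2,                           # two beams come back as two choices
        max_completion_tokens=10,
        temperature=0.0,               # keeps beam search deterministic
        extra_body=dict(use_beam_search=True),
    )
    for choice in chat_completion.choices:
        print(choice.message.content)


asyncio.run(main())
```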
13 changes: 11 additions & 2 deletions vllm/engine/protocol.py
@@ -88,9 +88,18 @@ async def beam_search(
if processed_inputs["type"] == "embeds":
raise NotImplementedError

prompt_token_ids = processed_inputs["prompt_token_ids"]
# This is a workaround to fix multimodal beam search; it is a
# band-aid fix for two small problems:
# 1. multi_modal_data on the processed_inputs currently resolves to
#    `None`.
# 2. The preprocessing above already expands the multimodal placeholders.
#    However, this happens again in generation, so the double expansion
#    causes a mismatch.
# TODO - it would be ideal to handle this more gracefully.
prompt_token_ids = prompt.get("prompt_token_ids")
multi_modal_data = prompt.get("multi_modal_data")

prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

tokenized_length = len(prompt_token_ids)
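The comment in the hunk above explains why the token ids and multimodal data are now read from the raw `prompt` instead of `processed_inputs`: the processor output loses `multi_modal_data`, and it has already expanded the image placeholders, which generation would then expand a second time. A toy illustration of that double-expansion mismatch follows; the placeholder token id and expansion factor are invented, not vLLM's.

```python
# Toy illustration (not vLLM code) of the double-expansion problem described
# in the comment above. Token ids and the expansion factor are made up.
IMAGE_PLACEHOLDER = 32000   # hypothetical image placeholder token id
FEATURES_PER_IMAGE = 4      # hypothetical number of image feature tokens


def expand_placeholders(token_ids: list[int]) -> list[int]:
    """Replace each placeholder token with FEATURES_PER_IMAGE copies."""
    out: list[int] = []
    for tok in token_ids:
        out.extend([tok] * FEATURES_PER_IMAGE if tok == IMAGE_PLACEHOLDER else [tok])
    return out


raw = [1, IMAGE_PLACEHOLDER, 15, 42]
once = expand_placeholders(raw)     # what preprocessing produces
twice = expand_placeholders(once)   # what a second expansion in generation does
print(len(raw), len(once), len(twice))  # 4 7 19 -> lengths no longer line up
```

Taking `prompt_token_ids` and `multi_modal_data` from the raw prompt keeps expansion to the single pass that happens during generation.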
13 changes: 7 additions & 6 deletions vllm/entrypoints/llm.py
@@ -15,7 +15,8 @@
from typing_extensions import TypeVar, deprecated

from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
BeamSearchSequence,
create_sort_beams_key_function)
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
is_init_field)
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
@@ -573,10 +574,11 @@ def beam_search(
lora_requests = self._get_beam_search_lora_requests(
lora_request, prompts)

def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)
tokenizer = self.get_tokenizer()
sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id,
length_penalty,
)

Review comment (Contributor Author): Small refactor to share this between sync / async because it's identical.

def create_tokens_prompt_from_beam(
beam: BeamSearchSequence) -> TokensPrompt:
@@ -591,7 +593,6 @@ def create_tokens_prompt_from_beam(
"mm_processor_kwargs"] = beam.mm_processor_kwargs
return TokensPrompt(**token_prompt_kwargs)

tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
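The llm.py change above drops the local `sort_beams_key` closure in favor of `create_sort_beams_key_function` from `vllm.beam_search`, so the synchronous `LLM.beam_search` and the async engine's beam search in `vllm/engine/protocol.py` score and sort beams the same way, as the review comment notes. The sketch below reconstructs what such a factory plausibly looks like from the closure it replaces; the real helper in `vllm.beam_search` may differ in details.

```python
# Hedged reconstruction of create_sort_beams_key_function, inferred from the
# removed closure above; the actual implementation may differ.
from typing import Callable, Optional

from vllm.beam_search import BeamSearchSequence, get_beam_search_score


def create_sort_beams_key_function(
    eos_token_id: Optional[int],
    length_penalty: float,
) -> Callable[[BeamSearchSequence], float]:
    def sort_beams_key(x: BeamSearchSequence) -> float:
        # Higher score = better beam; shared by the sync and async paths.
        return get_beam_search_score(x.tokens, x.cum_logprob,
                                     eos_token_id, length_penalty)

    return sort_beams_key
```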