From 4bd5b1ef859bf9ceda6b3a4f83748a61a5893853 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Tue, 10 Jun 2025 12:05:10 +0000
Subject: [PATCH 1/4] Use common beam search scoring

Signed-off-by: Alex-Brooks
---
 vllm/entrypoints/llm.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index c11e627ee236..cde8de257bb6 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -15,7 +15,8 @@
 from typing_extensions import TypeVar, deprecated
 
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
-                              BeamSearchSequence, get_beam_search_score)
+                              BeamSearchSequence,
+                              create_sort_beams_key_function)
 from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                          is_init_field)
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
@@ -573,10 +574,11 @@ def beam_search(
         lora_requests = self._get_beam_search_lora_requests(
             lora_request, prompts)
 
-        def sort_beams_key(x: BeamSearchSequence) -> float:
-            return get_beam_search_score(x.tokens, x.cum_logprob,
-                                         tokenizer.eos_token_id,
-                                         length_penalty)
+        tokenizer = self.get_tokenizer()
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id,
+            length_penalty,
+        )
 
         def create_tokens_prompt_from_beam(
                 beam: BeamSearchSequence) -> TokensPrompt:
@@ -591,7 +593,6 @@ def create_tokens_prompt_from_beam(
                     "mm_processor_kwargs"] = beam.mm_processor_kwargs
             return TokensPrompt(**token_prompt_kwargs)
 
-        tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
         # following the huggingface transformers implementation
         # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa

From 7f6066f478823a67cc2978a2d2fa9d6cd63b5bdc Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Mon, 16 Jun 2025 12:07:37 +0000
Subject: [PATCH 2/4] Pull prompt token ids and mm data off of raw prompt

Signed-off-by: Alex-Brooks
---
 vllm/engine/protocol.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 727d59283643..8688fcc82cd9 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -88,9 +88,18 @@ async def beam_search(
         if processed_inputs["type"] == "embeds":
             raise NotImplementedError
 
-        prompt_token_ids = processed_inputs["prompt_token_ids"]
+        # This is a workaround to fix multimodal beam search; this is a
+        # bandaid fix for 2 small problems:
+        # 1. Multi_modal_data on the processed_inputs currently resolves to
+        #    `None`.
+        # 2. preprocessing above expands the multimodal placeholders. However,
+        #    this happens again in generation, so the double expansion causes
+        #    a mismatch.
+        # TODO - would be ideal to handle this more gracefully.
+        prompt_token_ids = prompt.get("prompt_token_ids")
+        multi_modal_data = prompt.get("multi_modal_data")
+
         prompt_text = processed_inputs.get("prompt")
-        multi_modal_data = processed_inputs.get("multi_modal_data")
         mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
 
         tokenized_length = len(prompt_token_ids)

From b88637477829a992585576017c0588175438c258 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Mon, 16 Jun 2025 13:01:44 +0000
Subject: [PATCH 3/4] Check responses in vision async beam search test

Signed-off-by: Alex-Brooks
---
 tests/entrypoints/openai/test_vision.py | 31 +++++++++++++++++++++----
 1 file changed, 27 insertions(+), 4 deletions(-)

diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py
index 4513d8b3420f..cb8ab608990b 100644
--- a/tests/entrypoints/openai/test_vision.py
+++ b/tests/entrypoints/openai/test_vision.py
@@ -25,6 +25,25 @@
     "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
 ]
 
+EXPECTED_MM_BEAM_SEARCH_RES = [
+    [
+        "The image shows a wooden boardwalk leading through a",
+        "The image shows a wooden boardwalk extending into a",
+    ],
+    [
+        "The image shows two parrots perched on",
+        "The image shows two birds perched on a cur",
+    ],
+    [
+        "The image shows a Venn diagram with three over",
+        "The image displays a Venn diagram with three over",
+    ],
+    [
+        "This image displays a gradient of colors ranging from",
+        "This image displays a gradient of colors transitioning from",
+    ],
+]
+
 
 @pytest.fixture(scope="module")
 def server():
@@ -270,10 +289,13 @@ async def test_single_chat_session_image_base64encoded(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS))))
 async def test_single_chat_session_image_base64encoded_beamsearch(
-        client: openai.AsyncOpenAI, model_name: str, image_url: str,
+        client: openai.AsyncOpenAI, model_name: str, image_idx: int,
         base64_encoded_image: dict[str, str]):
+    # NOTE: This test also validates that we pass MM data through beam search
+    image_url = TEST_IMAGE_URLS[image_idx]
+    expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]
 
     messages = [{
         "role":
@@ -297,10 +319,11 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
         messages=messages,
         n=2,
         max_completion_tokens=10,
+        temperature=0.0,
         extra_body=dict(use_beam_search=True))
     assert len(chat_completion.choices) == 2
-    assert chat_completion.choices[
-        0].message.content != chat_completion.choices[1].message.content
+    for actual, expected_str in zip(chat_completion.choices, expected_res):
+        assert actual.message.content == expected_str
 
 
 @pytest.mark.asyncio

From 2cf7fd8c73fd0d35b631ed02ada492a8c5784c27 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Mon, 16 Jun 2025 16:26:31 +0000
Subject: [PATCH 4/4] Fix beam search test output in CI

Signed-off-by: Alex-Brooks
---
 tests/entrypoints/openai/test_vision.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py
index cb8ab608990b..fd613842f986 100644
--- a/tests/entrypoints/openai/test_vision.py
+++ b/tests/entrypoints/openai/test_vision.py
@@ -36,7 +36,7 @@
     ],
     [
         "The image shows a Venn diagram with three over",
-        "The image displays a Venn diagram with three over",
+        "This image shows a Venn diagram with three over",
     ],
     [
         "This image displays a gradient of colors ranging from",
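
Note on patch 1: the inline `sort_beams_key` closure in `LLM.beam_search` is replaced by `create_sort_beams_key_function` from `vllm.beam_search`, so the sync and async entrypoints can rank beams with one shared scorer. The sketch below illustrates what that helper presumably factors out; it is inferred from the arguments the removed closure passed to `get_beam_search_score` and from the HuggingFace length-penalty convention linked in the surrounding comment, not copied from the vLLM source. The `BeamSearchSequence` stub here is an assumption trimmed to the two fields the key touches.

    # Minimal sketch (assumed, not the vLLM implementation) of the shared
    # beam-scoring helpers that patch 1 switches LLM.beam_search over to.
    from dataclasses import dataclass
    from typing import Callable


    @dataclass
    class BeamSearchSequence:
        # Only the fields the sort key needs; the real class carries more state.
        tokens: list[int]
        cum_logprob: float = 0.0


    def get_beam_search_score(tokens: list[int],
                              cumulative_logprob: float,
                              eos_token_id: int,
                              length_penalty: float = 1.0) -> float:
        """Length-penalized score: cum_logprob / seq_len ** length_penalty,
        not counting a trailing EOS token toward the length."""
        seq_len = len(tokens)
        if tokens and tokens[-1] == eos_token_id:
            seq_len -= 1
        return cumulative_logprob / (seq_len ** length_penalty)


    def create_sort_beams_key_function(
            eos_token_id: int,
            length_penalty: float) -> Callable[[BeamSearchSequence], float]:
        """Factory returning the key used to rank candidate beams."""

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        return sort_beams_key


    # Both entrypoints can then sort candidates with the same key instead of
    # each defining its own closure (toy values for illustration only).
    sort_key = create_sort_beams_key_function(eos_token_id=2, length_penalty=1.0)
    beams = [BeamSearchSequence([5, 7, 2], -1.2), BeamSearchSequence([5, 9], -0.9)]
    best_first = sorted(beams, key=sort_key, reverse=True)

With a single key factory, a scoring change (e.g. a different length penalty) only has to be made in one place, which is the point of the "use common beam search scoring" commit.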