From a3691b6b5eb7e60039a8ff34550be5a7e8365394 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 8 Oct 2024 08:12:56 -0600 Subject: [PATCH] [Core][Frontend] Add Support for Inference Time mm_processor_kwargs (#9131) Signed-off-by: Alex-Brooks --- examples/offline_inference_vision_language.py | 1 + tests/multimodal/test_processor_kwargs.py | 110 +++++++++++------- tests/test_inputs.py | 26 +++++ tests/test_utils.py | 32 ++++- vllm/core/scheduler.py | 1 + vllm/engine/llm_engine.py | 7 ++ vllm/entrypoints/llm.py | 9 ++ vllm/inputs/data.py | 67 +++++++++-- vllm/inputs/preprocess.py | 70 ++++++++--- vllm/inputs/registry.py | 13 ++- vllm/multimodal/audio.py | 4 +- vllm/multimodal/base.py | 31 +++-- vllm/multimodal/image.py | 24 +++- vllm/multimodal/registry.py | 13 ++- vllm/multimodal/video.py | 24 ++-- vllm/sequence.py | 14 +++ vllm/utils.py | 95 ++++++++++++--- vllm/worker/cpu_model_runner.py | 8 +- vllm/worker/model_runner.py | 4 +- vllm/worker/neuron_model_runner.py | 5 +- vllm/worker/openvino_model_runner.py | 6 +- 21 files changed, 443 insertions(+), 121 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index efad7e33793df..5dd539c3d5ee4 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -105,6 +105,7 @@ def run_phi3v(question: str, modality: str): trust_remote_code=True, max_model_len=4096, max_num_seqs=2, + # Note - mm_processor_kwargs can also be passed to generate/chat calls mm_processor_kwargs={"num_crops": 16}, ) stop_token_ids = None diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 5529ccd4fa570..efc6903c373b6 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -74,11 +74,11 @@ def mm_model_cls(): # lambda whose signature matches max token calcs extra & mapper + extra kwargs get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { - "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) + "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) } -### Test for default processor logic & mm_processor_kwargs wrapping +### Tests for default processor logic & mm_processor_kwargs wrapping def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() @@ -89,23 +89,46 @@ def test_default_processor_is_a_noop(): assert proc_inputs is proc_outputs -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_processor_default_kwargs(use_processor_mock, num_crops): - """Ensure input processors can use processor kwargs.""" - dummy_registry = InputRegistry() +def _get_num_crops_info(init_num_crops: int, inference_num_crops: int): + """Get the init / inference kwargs and expected num_crops for this test.""" # If we have a value for num_crops, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops + init_kwargs = None if init_num_crops is None else { + "num_crops": init_num_crops } - expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops - ctx = build_model_context(DUMMY_MODEL_ID, - mm_processor_kwargs=mm_processor_kwargs) - processor = 
dummy_registry.create_input_processor(ctx.model_config) + inference_kwargs = None if inference_num_crops is None else { + "num_crops": inference_num_crops + } + if inference_num_crops is not None: + expected_seq_count = inference_num_crops + elif init_num_crops is not None: + expected_seq_count = init_num_crops + else: + expected_seq_count = DEFAULT_NUM_CROPS + return init_kwargs, inference_kwargs, expected_seq_count + + +@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ + (None, None), + (NUM_CROPS_OVERRIDE, None), + (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), +]) +def test_input_processor_kwargs(use_processor_mock, init_num_crops, + inference_num_crops): + """Ensure input processors can use processor kwargs.""" + dummy_registry = InputRegistry() + + init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( + init_num_crops, inference_num_crops) - num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) - assert num_crops_val == expected_num_crops + ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) + processor = dummy_registry.create_input_processor(ctx.model_config) + num_crops_val = processor( + LLMInputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=inference_kwargs)) + assert num_crops_val == expected_seq_count @pytest.mark.parametrize( @@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, mm_processor_kwargs): """Ensure that input processors filter out invalid mm_processor_kwargs""" dummy_registry = InputRegistry() + # Should filter out the init time kwargs ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + # Should filter out the inference time kwargs + num_crops_val = processor( + LLMInputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=mm_processor_kwargs)) assert num_crops_val == DEFAULT_NUM_CROPS @@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_custom_mapper_kwarg_overrides(image_assets, num_crops): +@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ + (None, None), + (NUM_CROPS_OVERRIDE, None), + (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), +]) +def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, + inference_num_crops): """Ensure custom mappers can use processor kwargs.""" - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops - } - expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( + init_num_crops, inference_num_crops) + ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - mm_processor_kwargs=mm_processor_kwargs, + mm_processor_kwargs=init_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) - # Patch the image registry for phi3v with our lambda that is compatible - # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. 
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}
 
-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_default_input_mapper",
-            {mm_model_cls(): custom_mapper},
-    ):
-        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echos
+    # our num_crops value back from the mm_processor_kwargs.
+    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
+        mm_model_cls())
+    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
+                                          inference_kwargs)
 
     assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
@@ -316,6 +346,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
 def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                 mm_processor_kwargs):
     """Ensure that custom mappers filters out invalid mm_processor_kwargs"""
+    # Should filter out the init time kwargs
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
                               mm_processor_kwargs=mm_processor_kwargs,
@@ -323,17 +354,16 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
 
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}
 
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echos
+    # our num_crops value back from the mm_processor_kwargs.
+ mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( + mm_model_cls()) + # Should filter out the inference time kwargs + mapped_inputs = mm_registry.map_input( + ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs) assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 3725d8687f255..fff7c5fc04285 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -2,6 +2,7 @@ import pytest +from vllm.inputs import zip_enc_dec_prompts from vllm.inputs.parse import parse_and_batch_prompt STRING_INPUTS = [ @@ -51,3 +52,28 @@ def test_parse_single_batch_token_consistent(token_input: List[int]): def test_parse_single_batch_string_slice(inputs_slice: slice): assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) + + +# yapf: disable +@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [ + (None, [{}, {}]), + ({}, [{}, {}]), + ({"foo": 100}, [{"foo": 100}, {"foo": 100}]), + ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]), +]) +# yapf: enable +def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): + """Test mm_processor_kwargs init for zipping enc/dec prompts.""" + encoder_prompts = ['An encoder prompt', 'Another encoder prompt'] + decoder_prompts = ['A decoder prompt', 'Another decoder prompt'] + zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts, + mm_processor_kwargs) + assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts) + for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts, + expected_mm_kwargs, + zipped_prompts): + assert isinstance(zipped, dict) + assert len(zipped.keys()) == 3 + assert zipped['encoder_prompt'] == enc + assert zipped['decoder_prompt'] == dec + assert zipped['mm_processor_kwargs'] == exp_kwargs diff --git a/tests/test_utils.py b/tests/test_utils.py index f3017a8582ea8..268e6f8194abb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,7 +7,7 @@ import pytest from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs, - get_open_port, merge_async_iterators) + get_open_port, merge_async_iterators, supports_kw) from .utils import error_on_warning @@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config): with pytest.raises(ValueError): parser_with_config.parse_args( ['serve', '--config', './data/test_config.yaml']) + + +# yapf: enable +@pytest.mark.parametrize( + "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", + [ + # Tests for positional argument support + (lambda foo: None, "foo", True, True, False), + (lambda foo: None, "foo", False, True, True), + # Tests for positional or keyword / keyword only + (lambda foo=100: None, "foo", True, True, False), + (lambda *, foo: None, "foo", False, True, True), + # Tests to make sure the names of variadic params are NOT supported + (lambda *args: None, "args", False, True, False), + (lambda **kwargs: None, "kwargs", False, True, False), + # Tests for if we allow var kwargs to add support + (lambda foo: None, "something_else", False, True, False), + (lambda foo, **kwargs: None, "something_else", False, True, True), + (lambda foo, **kwargs: None, "kwargs", True, True, False), + (lambda foo, **kwargs: None, "foo", True, True, False), + ]) +# yapf: disable +def test_supports_kw(callable,kw_name,requires_kw_only, + allow_var_kwargs,is_supported): + assert supports_kw( + callable=callable, + kw_name=kw_name, + 
requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs + ) == is_supported diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5cdb490e305f5..e930f807280f0 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1309,6 +1309,7 @@ def schedule( # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6372d4b5d2117..510ffac6f6892 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -811,6 +811,13 @@ def add_request( ) processed_inputs = self.input_processor(preprocessed_inputs) + # This is a bit of a hack - copy the mm_processor_kwargs that were + # used in the input processor to the processed output, since these + # kwargs are presumed to be immutable and the values should be aligned + # between the input processor (here) and the input mapper. + processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( + "mm_processor_kwargs") + self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b0a8a66ec133f..7ad352cd87526 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -472,6 +472,7 @@ def chat( add_generation_prompt: bool = True, continue_final_message: bool = False, tools: Optional[List[Dict[str, Any]]] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> List[RequestOutput]: """ Generate responses for a chat conversation. @@ -501,6 +502,8 @@ def chat( continue_final_message: If True, continues the final message in the conversation instead of starting a new one. Cannot be `True` if `add_generation_prompt` is also `True`. + mm_processor_kwargs: Multimodal processor kwarg overrides for this + chat request. Only used for offline requests. Returns: A list of ``RequestOutput`` objects containing the generated @@ -522,6 +525,9 @@ def chat( tokenizer = self.get_tokenizer() model_config = self.llm_engine.get_model_config() + # NOTE: _parse_chat_message_content_parts() currently doesn't + # handle mm_processor_kwargs, since there is no implementation in + # the chat message parsing for it. conversation, mm_data = parse_chat_messages( msgs, model_config, tokenizer) @@ -554,6 +560,9 @@ def chat( if mm_data is not None: prompt["multi_modal_data"] = mm_data + if mm_processor_kwargs is not None: + prompt["mm_processor_kwargs"] = mm_processor_kwargs + prompts.append(prompt) return self.generate( diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index dfbcf95264875..724cdd2e6e802 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,5 +1,5 @@ -from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple, - Union) +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, + Optional, Tuple, Union) from typing_extensions import NotRequired, TypedDict, TypeVar @@ -19,6 +19,14 @@ class TextPrompt(TypedDict): if the model supports it. """ + mm_processor_kwargs: NotRequired[Dict[str, Any]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. 
+ """ + class TokensPrompt(TypedDict): """Schema for a tokenized prompt.""" @@ -32,6 +40,14 @@ class TokensPrompt(TypedDict): if the model supports it. """ + mm_processor_kwargs: NotRequired[Dict[str, Any]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ @@ -74,7 +90,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): according to any of the :class:`SingletonPrompt` schemas, and are not required to have the same schema. - Only the encoder prompt may have multi-modal data. + Only the encoder prompt may have multi-modal data. mm_processor_kwargs + should be at the top-level, and should not be set in the encoder/decoder + prompts, since they are agnostic to the encoder/decoder. Note that an :class:`ExplicitEncoderDecoderPrompt` may not be used as an input to a decoder-only model, @@ -87,6 +105,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] + mm_processor_kwargs: NotRequired[Dict[str, Any]] + PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] """ @@ -121,6 +141,14 @@ class LLMInputs(TypedDict): if the model supports it. """ + mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + class EncoderDecoderLLMInputs(LLMInputs): """ @@ -152,22 +180,43 @@ class EncoderDecoderLLMInputs(LLMInputs): def build_explicit_enc_dec_prompt( encoder_prompt: _T1, decoder_prompt: Optional[_T2], + mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: - return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, - decoder_prompt=decoder_prompt) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + return ExplicitEncoderDecoderPrompt( + encoder_prompt=encoder_prompt, + decoder_prompt=decoder_prompt, + mm_processor_kwargs=mm_processor_kwargs) def zip_enc_dec_prompts( enc_prompts: Iterable[_T1], dec_prompts: Iterable[Optional[_T2]], + mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]], + Dict[str, Any]]] = None, ) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of - :class:`ExplicitEncoderDecoderPrompt` instances. - """ + :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs + may also be provided; if a dict is passed, the same dictionary will be + used for every encoder/decoder prompt. If an iterable is provided, it will + be zipped with the encoder/decoder prompts. 
+ """ + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + if isinstance(mm_processor_kwargs, Dict): + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, + mm_processor_kwargs) + for (encoder_prompt, + decoder_prompt) in zip(enc_prompts, dec_prompts) + ] return [ - build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) - for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts) + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, + mm_proc_kwargs) + for (encoder_prompt, decoder_prompt, mm_proc_kwargs + ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs) ] diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index d4474a10f542d..22adb1631d410 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing_extensions import assert_never @@ -20,9 +20,11 @@ logger = init_logger(__name__) PromptComponents = Tuple[Optional[str], List[int], - Optional["MultiModalDataDict"]] + Optional["MultiModalDataDict"], Optional[Dict[str, + Any]]] DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional["MultiModalDataDict"]] + Optional["MultiModalDataDict"], + Optional[Dict[str, Any]]] class InputPreprocessor: @@ -227,6 +229,7 @@ def _extract_prompt_components( * prompt * prompt_token_ids * multi_modal_data + * mm_processor_kwargs (request-level input processor/mapper overrides) ''' parsed = parse_singleton_prompt(prompt) @@ -239,10 +242,12 @@ def _extract_prompt_components( lora_request=lora_request, ) multi_modal_data = None + mm_processor_kwargs = None elif parsed["type"] == "tokens": prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") elif parsed["type"] == "text": prompt_text = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( @@ -251,10 +256,12 @@ def _extract_prompt_components( lora_request=lora_request, ) multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) async def _extract_prompt_components_async( self, @@ -273,10 +280,12 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) multi_modal_data = None + mm_processor_kwargs = None elif parsed["type"] == "tokens": prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") elif parsed["type"] == "text": prompt_text = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( @@ -285,18 +294,21 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) def _build_enc_dec_llm_inputs( self, encoder_comps: PromptComponents, decoder_comps: DecoderPromptComponents, + 
mm_processor_kwargs: Dict[str, Any], ) -> EncoderDecoderLLMInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps + encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps if decoder_mm_data is not None: raise ValueError( @@ -314,6 +326,7 @@ def _build_enc_dec_llm_inputs( prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, multi_modal_data=decoder_mm_data, + mm_processor_kwargs=mm_processor_kwargs, encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, encoder_multi_modal_data=encoder_mm_data, @@ -367,21 +380,30 @@ def _process_encoder_decoder_prompt( ) if (decoder_input := prompt["decoder_prompt"]) is None: - decoder_comps = None, None, None + decoder_comps = None, None, None, None else: decoder_comps = self._extract_prompt_components( decoder_input, request_id=request_id, ) + # Handle this carefully in case it was directly initialized by user + mm_processor_kwargs = prompt.get("mm_processor_kwargs", {}) else: encoder_comps = self._extract_prompt_components( prompt, request_id=request_id, ) - - decoder_comps = None, None, None - - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + # If there are no decoder components, we assume the + # mm_processor_kwargs are in the encoder prompt + mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ + -1] is not None else {} + decoder_comps = None, None, None, None + + return self._build_enc_dec_llm_inputs( + encoder_comps, + decoder_comps, + mm_processor_kwargs, + ) async def _process_encoder_decoder_prompt_async( self, @@ -400,7 +422,7 @@ async def _process_encoder_decoder_prompt_async( if (decoder_input := prompt["decoder_prompt"]) is None: encoder_comps = await encoder_task - decoder_comps = None, None, None + decoder_comps = None, None, None, None else: decoder_task = self._extract_prompt_components_async( decoder_input, @@ -409,29 +431,39 @@ async def _process_encoder_decoder_prompt_async( encoder_comps, decoder_comps = await asyncio.gather( encoder_task, decoder_task) + mm_processor_kwargs = prompt["mm_processor_kwargs"] else: encoder_comps = await self._extract_prompt_components_async( prompt, request_id=request_id, ) - - decoder_comps = None, None, None - - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + # If there are no decoder components, we assume the + # mm_processor_kwargs are in the encoder prompt + mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ + -1] is not None else {} + decoder_comps = None, None, None, None + + return self._build_enc_dec_llm_inputs( + encoder_comps, + decoder_comps, + mm_processor_kwargs, + ) def _build_decoder_only_llm_inputs( self, prompt_comps: PromptComponents, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> LLMInputs: - prompt, prompt_token_ids, multi_modal_data = prompt_comps + (prompt, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) = prompt_comps prompt_token_ids = self._apply_prompt_adapter( prompt_token_ids, prompt_adapter_request=prompt_adapter_request) return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs) def _process_decoder_only_prompt( self, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 590ff54aea560..5bd3e1c86f66c 100644 --- a/vllm/inputs/registry.py +++ 
b/vllm/inputs/registry.py @@ -9,7 +9,8 @@ from typing_extensions import TypeVar from vllm.logger import init_logger -from vllm.utils import get_allowed_kwarg_only_overrides, print_warning_once +from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, + resolve_mm_processor_kwargs) from .data import LLMInputs @@ -293,8 +294,14 @@ def process_input(self, model_config: "ModelConfig", model_cls, _ = get_model_architecture(model_config) processor = self._get_model_input_processor(model_cls) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - processor, overrides=model_config.mm_processor_kwargs) + # Handle multimodal processor kwargs with priority: + # Inference kwargs -> Init kwargs -> {} + # If it's empty, it'll fall back to the default kwarg values + mm_processor_kwargs = resolve_mm_processor_kwargs( + model_config.mm_processor_kwargs, + inputs.get("mm_processor_kwargs"), + processor, + ) return processor(InputContext(model_config), inputs, **mm_processor_kwargs) diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index b4bf4b4541db8..04d71826f29fa 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -8,8 +8,8 @@ class AudioPlugin(MultiModalPlugin): def get_data_key(self) -> str: return "audio" - def _default_input_mapper(self, ctx: InputContext, - data: object) -> MultiModalInputs: + def _default_input_mapper(self, ctx: InputContext, data: object, + **mm_processor_kwargs) -> MultiModalInputs: raise NotImplementedError("There is no default audio input mapper") def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 8bcb38ef241ed..84e71cbf60df7 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,7 +1,7 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type, +from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypedDict, TypeVar, Union, cast, final) import numpy as np @@ -15,7 +15,7 @@ from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, - json_map_leaves) + json_map_leaves, resolve_mm_processor_kwargs) logger = init_logger(__name__) @@ -200,6 +200,7 @@ def _default_input_mapper( self, ctx: InputContext, data: MultiModalData[object], + **mm_processor_kwargs, ) -> MultiModalInputs: """ Return a dictionary to be passed as keyword arguments to @@ -243,7 +244,8 @@ def wrapper(model_cls: N) -> N: return wrapper def map_input(self, model_config: ModelConfig, - data: MultiModalData[object]) -> MultiModalInputs: + data: MultiModalData[object], + mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs: """ Transform the data into a dictionary of model inputs using the input mapper registered for that model. @@ -263,19 +265,26 @@ def map_input(self, model_config: ModelConfig, model_cls, _ = get_model_architecture(model_config) mapper = self._input_mappers.get(model_cls) - # Only get processor kwargs at mapping time if we are not using the - # input mapper; no overrides are used on the default here because they - # should be passed to the huggingface resource at initialization time. 
- if mapper is not None and mapper != self._default_input_mapper: - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - mapper, overrides=model_config.mm_processor_kwargs) - else: - mm_processor_kwargs = {} if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") + # In the case of the default mapper, we have to get resource + # processor through its HuggingFace autoclass; since this goes + # through **kwargs, we can't inspect it the same way, so we allow + # drop mm_processor_kwargs based on signature inspection + # if we're using the default mapper. + # + # This should be safe in general due to the sanitation, since the + # transformers resource should filter unused kwargs anyway. + uses_default_mapper = mapper == self._default_input_mapper + mm_processor_kwargs = resolve_mm_processor_kwargs( + model_config.mm_processor_kwargs, + mm_processor_kwargs, + callable=mapper, + allow_var_kwargs=uses_default_mapper, + ) return mapper(InputContext(model_config), data, **mm_processor_kwargs) @abstractmethod diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 7ca64152e481a..5f74bcea65ce2 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,4 +1,5 @@ from functools import lru_cache +from typing import Any, Dict, Optional import torch from PIL import Image @@ -23,11 +24,13 @@ class ImagePlugin(MultiModalPlugin): def get_data_key(self) -> str: return "image" - def _get_hf_image_processor(self, model_config: ModelConfig): - mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None - else model_config.mm_processor_kwargs) - # We don't explicitly check kwarg overrides to the HF class - # since the automodel just takes kwargs, so we can't inspect it + def _get_hf_image_processor( + self, + model_config: ModelConfig, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + ): + if mm_processor_kwargs is None: + mm_processor_kwargs = {} return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, @@ -37,6 +40,7 @@ def _default_input_mapper( self, ctx: InputContext, data: MultiModalData[object], + **mm_processor_kwargs, ) -> MultiModalInputs: model_config = ctx.model_config @@ -46,12 +50,20 @@ def _default_input_mapper( # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): - image_processor = self._get_hf_image_processor(model_config) + image_processor = self._get_hf_image_processor( + model_config, + mm_processor_kwargs, + ) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") try: + # NOTE: It may make sense to forward the mm_processor_kwargs + # here too. 
For now, to keep it simple, we only allow it be + # used for the initialization call though, just in case the + # signatures of the preprocessor initializer don't match + # preprocess() batch_data = image_processor \ .preprocess(data, return_tensors="pt") \ .data diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 3940e1671b57a..5e9b8bd518de3 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,6 +1,6 @@ import functools from collections import UserDict -from typing import Dict, Mapping, Optional, Sequence +from typing import Any, Dict, Mapping, Optional, Sequence from vllm.config import ModelConfig from vllm.logger import init_logger @@ -96,8 +96,12 @@ def register_image_input_mapper( """ return self.register_input_mapper("image", mapper) - def map_input(self, model_config: ModelConfig, - data: MultiModalDataDict) -> MultiModalInputs: + def map_input( + self, + model_config: ModelConfig, + data: MultiModalDataDict, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + ) -> MultiModalInputs: """ Apply an input mapper to the data passed to the model. @@ -123,7 +127,8 @@ def map_input(self, model_config: ModelConfig, f"`--limit-mm-per-prompt`, but found {num_items} items " "in the same prompt.") - input_dict = plugin.map_input(model_config, data_value) + input_dict = plugin.map_input(model_config, data_value, + mm_processor_kwargs) for input_key, input_tensor in input_dict.items(): if input_key in merged_dict: raise ValueError(f"The input mappers (keys={set(data)}) " diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 39e75dbaf6872..4a9dbf20c8ec5 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import List, Union +from typing import Any, Dict, List, Optional, Union import numpy as np @@ -36,11 +36,13 @@ class VideoPlugin(ImagePlugin): def get_data_key(self) -> str: return "video" - def _get_hf_video_processor(self, model_config: ModelConfig): - mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None - else model_config.mm_processor_kwargs) - # We don't explicitly check kwarg overrides to the HF class - # since the automodel just takes kwargs, so we can't inspect it + def _get_hf_video_processor( + self, + model_config: ModelConfig, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + ): + if mm_processor_kwargs is None: + mm_processor_kwargs = {} return cached_get_video_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, @@ -50,16 +52,24 @@ def _default_input_mapper( self, ctx: InputContext, data: MultiModalData[object], + **mm_processor_kwargs, ) -> MultiModalInputs: model_config = ctx.model_config # single video input as np.ndarray if isinstance(data, np.ndarray): - video_processor = self._get_hf_video_processor(model_config) + video_processor = self._get_hf_video_processor( + model_config, + mm_processor_kwargs, + ) if video_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") try: + # NOTE: Similar to image; it may be a good idea to filter and + # pass mm_processor_kwargs here too, but for now we don't to + # avoid extra complexity if the initializer and preprocess + # signatures of the processor don't align batch_data = video_processor(data, return_tensors="pt").data except Exception: logger.error("Failed to process image (%s)", data) diff --git a/vllm/sequence.py b/vllm/sequence.py index 9116408a001ff..0c27ffca36cfd 100644 --- 
a/vllm/sequence.py +++ b/vllm/sequence.py @@ -481,6 +481,10 @@ def multi_modal_data(self) -> "MultiModalDataDict": EncoderDecoderLLMInputs, inputs).get("encoder_multi_modal_data")) or {} + @property + def mm_processor_kwargs(self) -> Dict[str, Any]: + return self.inputs.get("mm_processor_kwargs") or {} + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -710,6 +714,14 @@ def multi_modal_data(self) -> "MultiModalDataDict": # We use the multi-modal data of an arbitrary sequence. return self.seqs[0].multi_modal_data + @property + def mm_processor_kwargs(self) -> Dict[str, Any]: + # As with multi-modal data, all sequences in the group should have the + # same processor kwargs (i.e., mm_processor_kwargs are optionally + # provided per request; note that are independent of whether the model + # decoder-only or an encoder-decoder). + return self.seqs[0].mm_processor_kwargs + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -949,6 +961,7 @@ class SequenceGroupMetadata( used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. + mm_processor_kwargs: Multimodal input processor / mapper overrides. encoder_seq_data: Optional sequence data for encoder prompt (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder @@ -975,6 +988,7 @@ class SequenceGroupMetadata( # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. multi_modal_data: Optional[Any] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[List[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None diff --git a/vllm/utils.py b/vllm/utils.py index bec2f951d69db..314fec0a65c7b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1277,18 +1277,87 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, return await task(*args, **kwargs) -def supports_kw(callable: Callable[..., object], kw_name: str) -> bool: +def supports_kw( + callable: Callable[..., object], + kw_name: str, + requires_kw_only: bool = False, + allow_var_kwargs: bool = True, +) -> bool: + """Check if a keyword is a valid kwarg for a callable; if requires_kw_only + disallows kwargs names that can also be positional arguments. + """ params = inspect.signature(callable).parameters - if kw_name in params: - return True + if not params: + return False + + param_val = params.get(kw_name) + + # Types where the it may be valid, i.e., explicitly defined & nonvariadic + passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY)) + + if param_val: + is_sig_param = param_val.kind in passable_kw_types + # We want kwargs only, but this is passable as a positional arg + if (requires_kw_only and is_sig_param + and param_val.kind != inspect.Parameter.KEYWORD_ONLY): + return False + if ((requires_kw_only + and param_val.kind == inspect.Parameter.KEYWORD_ONLY) + or (not requires_kw_only and is_sig_param)): + return True + + # If we're okay with var-kwargs, it's supported as long as + # the kw_name isn't something like *args, **kwargs + if allow_var_kwargs: + # Get the last param; type is ignored here because params is a proxy + # mapping, but it wraps an ordered dict, and they appear in order. 
+        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
+        last_param = params[next(reversed(params))]  # type: ignore
+        return (last_param.kind == inspect.Parameter.VAR_KEYWORD
+                and last_param.name != kw_name)
+    return False
+
+
+def resolve_mm_processor_kwargs(
+    init_kwargs: Optional[Dict[str, Any]],
+    inference_kwargs: Optional[Dict[str, Any]],
+    callable: Callable[..., object],
+    allow_var_kwargs: bool = False,
+) -> Dict[str, Any]:
+    """Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
+    those that are not explicit keywords to the given callable (if one is
+    given; otherwise no filtering is done), then merges the kwarg dicts,
+    giving priority to inference_kwargs if there are any collisions.
+
+    In the case that no kwarg overrides are provided, returns an empty
+    dict so that it can still be kwarg expanded into the callable later on.
+
+    If allow_var_kwargs=True, allows for things that can be expanded into
+    kwargs as long as they aren't a naming collision with var_kwargs or
+    potential positional arguments.
+    """
+    # Filter inference time multimodal processor kwargs provided
+    runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
+        callable,
+        overrides=inference_kwargs,
+        allow_var_kwargs=allow_var_kwargs)
+
+    # Filter init time multimodal processor kwargs provided
+    init_mm_kwargs = get_allowed_kwarg_only_overrides(
+        callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
 
-    return any(param.kind == inspect.Parameter.VAR_KEYWORD
-               for param in params.values())
+    # Merge the final processor kwargs, prioritizing inference
+    # time values over the initialization time values.
+    mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
+    return mm_processor_kwargs
 
 
 def get_allowed_kwarg_only_overrides(
     callable: Callable[..., object],
     overrides: Optional[Dict[str, Any]],
+    allow_var_kwargs: bool = False,
 ) -> Dict[str, Any]:
     """
     Given a callable which has one or more keyword only params and a dict
@@ -1300,7 +1369,9 @@ def get_allowed_kwarg_only_overrides(
 
     Args:
         callable: Callable which takes 0 or more keyword only arguments.
                  If None is provided, all override names are allowed.
         overrides: Potential overrides to be used when invoking the callable.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.
 
     Returns:
         Dictionary containing the kwargs to be leveraged which may be used
@@ -1310,17 +1381,15 @@ def get_allowed_kwarg_only_overrides(
     if not overrides:
         return {}
 
-    allowed_override_names = [
-        name for name, param in inspect.signature(callable).parameters.items()
-        if param.kind == inspect.Parameter.KEYWORD_ONLY
-    ]
-
-    # Drop any mm_processor_kwargs provided by the user that are
-    # not kwarg names accepted by the provided input processor.
+    # Drop any mm_processor_kwargs provided by the user that
+    # are not kwargs, unless they can fit the var_kwargs param
     filtered_overrides = {
         kwarg_name: val
         for kwarg_name, val in overrides.items()
-        if kwarg_name in allowed_override_names
+        if supports_kw(callable,
+                       kwarg_name,
+                       requires_kw_only=True,
+                       allow_var_kwargs=allow_var_kwargs)
     }
 
     # If anything is dropped, log a warning
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index a03c562532179..f67b086796411 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -148,8 +148,9 @@ def build(self) -> ModelInputForCPU:
         )
 
     def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
-                                   computed_len: int):
-        mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                                   computed_len: int,
+                                   mm_processor_kwargs: Dict[str, Any]):
+        mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)
 
         # special processing for mrope position deltas.
         mrope_positions = None
@@ -210,7 +211,8 @@ def _prepare_prompt(
             mrope_positions = None
             if (mm_data := seq_group_metadata.multi_modal_data):
                 mm_kwargs, mrope_positions = self._compute_multi_modal_input(
-                    seq_data, mm_data, computed_len)
+                    seq_data, mm_data, computed_len,
+                    seq_group_metadata.mm_processor_kwargs)
                 multi_modal_inputs_list.append(mm_kwargs)
 
             # Token position ids
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 9784438841980..0bd2958816718 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -640,7 +640,9 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
         if not mm_data:
             return
 
-        mm_kwargs = self.multi_modal_input_mapper(mm_data)
+        mm_kwargs = self.multi_modal_input_mapper(
+            mm_data,
+            mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
         inter_data.multi_modal_inputs = mm_kwargs
 
         # special processing for mrope position deltas.
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index 44d4845a838ef..b8c760c4b5396 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -153,7 +153,10 @@ def _prepare_prompt(
             mm_data = seq_group_metadata.multi_modal_data
             if mm_data:
                 # Process multi-modal data
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                mm_kwargs = self.multi_modal_input_mapper(
+                    mm_data,
+                    mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs,
+                )
                 multi_modal_inputs_list.append(mm_kwargs)
 
         max_seq_len = max(seq_lens)
diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py
index 77ee2eadf29a2..de3088695dfef 100644
--- a/vllm/worker/openvino_model_runner.py
+++ b/vllm/worker/openvino_model_runner.py
@@ -172,7 +172,11 @@ def _prepare_model_input(
                 mm_data = seq_group_metadata.multi_modal_data
                 if mm_data:
-                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                    mm_kwargs = self.multi_modal_input_mapper(
+                        mm_data,
+                        mm_processor_kwargs=seq_group_metadata.
+                        mm_processor_kwargs,
+                    )
                     multi_modal_inputs_list.append(mm_kwargs)
 
                 block_table = seq_group_metadata.block_tables[seq_id]
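
Usage sketch (not part of the diff above): a minimal offline example of the request-level override this patch introduces. The model name, chat template string, and image path below are illustrative assumptions; the point shown is that a per-request "mm_processor_kwargs" entry in the prompt dict takes priority over the value passed when constructing the LLM, which in turn takes priority over the processor defaults.

from PIL import Image

from vllm import LLM, SamplingParams

# Placeholder image; any RGB image works for the sake of illustration.
image = Image.open("example.jpg").convert("RGB")

llm = LLM(
    model="microsoft/Phi-3-vision-128k-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    # Init-time default, as in examples/offline_inference_vision_language.py
    mm_processor_kwargs={"num_crops": 16},
)

prompt = {
    "prompt": "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n",
    "multi_modal_data": {"image": image},
    # Request-level override added by this patch; it wins over num_crops=16
    "mm_processor_kwargs": {"num_crops": 4},
}

outputs = llm.generate(prompt, SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)

The same override can be supplied to LLM.chat() through its new mm_processor_kwargs argument; per the docstring added above, it is only used for offline requests.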