[Core][Frontend] Add Support for Inference Time mm_processor_kwargs (#9131)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
alex-jw-brooks authored Oct 8, 2024
1 parent 8c74622 commit a3691b6
Showing 21 changed files with 443 additions and 121 deletions.
1 change: 1 addition & 0 deletions examples/offline_inference_vision_language.py
@@ -105,6 +105,7 @@ def run_phi3v(question: str, modality: str):
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"num_crops": 16},
)
stop_token_ids = None
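
The comment added above notes that mm_processor_kwargs can now also be supplied per request. A minimal usage sketch of both levels follows, assuming the offline API at this commit and the prompt-dict form that LLMInputs gains in this change; the model name, prompt format, image path, and num_crops values are illustrative:

    from PIL import Image
    from vllm import LLM, SamplingParams

    image = Image.open("example.jpg")  # any local image; the path is illustrative

    # Init-time default: applies to every request unless overridden.
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        mm_processor_kwargs={"num_crops": 16},
    )

    # Per-request override: carried inside the prompt dict, mirroring the
    # mm_processor_kwargs field added to LLMInputs in this diff.
    outputs = llm.generate(
        {
            "prompt": "<|user|>\n<|image_1|>\nWhat is in this image?<|end|>\n<|assistant|>\n",
            "multi_modal_data": {"image": image},
            "mm_processor_kwargs": {"num_crops": 4},
        },
        SamplingParams(max_tokens=64),
    )
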
110 changes: 70 additions & 40 deletions tests/multimodal/test_processor_kwargs.py
@@ -74,11 +74,11 @@ def mm_model_cls():
# lambda whose signature matches max token calcs extra & mapper + extra kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
"num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
"pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
}


-### Test for default processor logic & mm_processor_kwargs wrapping
+### Tests for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
"""Ensure that by default, there is no processor override."""
dummy_registry = InputRegistry()
@@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
assert proc_inputs is proc_outputs


@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_processor_default_kwargs(use_processor_mock, num_crops):
"""Ensure input processors can use processor kwargs."""
dummy_registry = InputRegistry()
def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
"""Get the init / inference kwargs and expected num_crops for this test."""
# If we have a value for num_crops, pass the override value and make
# sure we get that value as a return-value from out mock processor,
# otherwise fall back to the default value
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
init_kwargs = None if init_num_crops is None else {
"num_crops": init_num_crops
}
expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
inference_kwargs = None if inference_num_crops is None else {
"num_crops": inference_num_crops
}
if inference_num_crops is not None:
expected_seq_count = inference_num_crops
elif init_num_crops is not None:
expected_seq_count = init_num_crops
else:
expected_seq_count = DEFAULT_NUM_CROPS
return init_kwargs, inference_kwargs, expected_seq_count


@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
(None, None),
(NUM_CROPS_OVERRIDE, None),
(DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
])
def test_input_processor_kwargs(use_processor_mock, init_num_crops,
inference_num_crops):
"""Ensure input processors can use processor kwargs."""
dummy_registry = InputRegistry()

init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
init_num_crops, inference_num_crops)

num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == expected_num_crops
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(
LLMInputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
assert num_crops_val == expected_seq_count


@pytest.mark.parametrize(
@@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
mm_processor_kwargs):
"""Ensure that input processors filter out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
+    # Should filter out the init time kwargs
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)

    processor = dummy_registry.create_input_processor(ctx.model_config)
-    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
+    # Should filter out the inference time kwargs
+    num_crops_val = processor(
+        LLMInputs(prompt_token_ids=[],
+                  prompt="",
+                  mm_processor_kwargs=mm_processor_kwargs))
assert num_crops_val == DEFAULT_NUM_CROPS


@@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1


@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
(None, None),
(NUM_CROPS_OVERRIDE, None),
(DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
])
def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
inference_num_crops):
"""Ensure custom mappers can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
init_num_crops, inference_num_crops)

ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
mm_processor_kwargs=init_kwargs,
limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}

with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
mm_model_cls())
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
inference_kwargs)

assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1

@@ -316,24 +346,24 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
mm_processor_kwargs):
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
# Should filter out the init time kwargs
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}

with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
mm_model_cls())
# Should filter out the inference time kwargs
mapped_inputs = mm_registry.map_input(
ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)

assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
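
For reference, the contract these mapper tests rely on: only keyword-only parameters with defaults (like num_crops in the custom_mapper lambda above) are overridable through mm_processor_kwargs, and unrecognized keys are dropped before the call. A hypothetical mapper written to that contract might look like this; the shape and default value are illustrative, mirroring the test lambdas:

    import torch

    # num_crops is keyword-only with a default, so per-request
    # mm_processor_kwargs can override it; unsupported keys are filtered out.
    def phi3v_style_mapper(ctx, data, *, num_crops=16):
        # Echo num_crops back through the tensor shape, as the tests assert on it.
        return {"pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))}
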
26 changes: 26 additions & 0 deletions tests/test_inputs.py
@@ -2,6 +2,7 @@

import pytest

from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_and_batch_prompt

STRING_INPUTS = [
@@ -51,3 +52,28 @@ def test_parse_single_batch_token_consistent(token_input: List[int]):
def test_parse_single_batch_string_slice(inputs_slice: slice):
assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
== parse_and_batch_prompt(STRING_INPUTS[inputs_slice])


# yapf: disable
@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
(None, [{}, {}]),
({}, [{}, {}]),
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
])
# yapf: enable
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
encoder_prompts = ['An encoder prompt', 'Another encoder prompt']
decoder_prompts = ['A decoder prompt', 'Another decoder prompt']
zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts,
mm_processor_kwargs)
assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts,
expected_mm_kwargs,
zipped_prompts):
assert isinstance(zipped, dict)
assert len(zipped.keys()) == 3
assert zipped['encoder_prompt'] == enc
assert zipped['decoder_prompt'] == dec
assert zipped['mm_processor_kwargs'] == exp_kwargs
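
A sketch of the zipping behavior these cases pin down, assuming only what the test asserts (this is not necessarily the actual vllm.inputs implementation): None and {} both normalize to one empty dict per pair, a single dict is broadcast, and a list is zipped element-wise.

    from typing import Any, Dict, List, Optional, Union

    def zip_enc_dec_prompts_sketch(
        enc_prompts: List[str],
        dec_prompts: List[str],
        mm_processor_kwargs: Optional[Union[Dict[str, Any],
                                            List[Dict[str, Any]]]] = None,
    ) -> List[Dict[str, Any]]:
        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}
        if isinstance(mm_processor_kwargs, dict):
            # Broadcast a single dict (possibly empty) across all prompt pairs.
            mm_processor_kwargs = [mm_processor_kwargs] * len(enc_prompts)
        return [{
            "encoder_prompt": enc,
            "decoder_prompt": dec,
            "mm_processor_kwargs": kwargs,
        } for enc, dec, kwargs in zip(enc_prompts, dec_prompts,
                                      mm_processor_kwargs)]
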
32 changes: 31 additions & 1 deletion tests/test_utils.py
@@ -7,7 +7,7 @@
import pytest

from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
-                        get_open_port, merge_async_iterators)
+                        get_open_port, merge_async_iterators, supports_kw)

from .utils import error_on_warning

@@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config):
with pytest.raises(ValueError):
parser_with_config.parse_args(
['serve', '--config', './data/test_config.yaml'])


# yapf: enable
@pytest.mark.parametrize(
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
[
# Tests for positional argument support
(lambda foo: None, "foo", True, True, False),
(lambda foo: None, "foo", False, True, True),
# Tests for positional or keyword / keyword only
(lambda foo=100: None, "foo", True, True, False),
(lambda *, foo: None, "foo", False, True, True),
# Tests to make sure the names of variadic params are NOT supported
(lambda *args: None, "args", False, True, False),
(lambda **kwargs: None, "kwargs", False, True, False),
# Tests for if we allow var kwargs to add support
(lambda foo: None, "something_else", False, True, False),
(lambda foo, **kwargs: None, "something_else", False, True, True),
(lambda foo, **kwargs: None, "kwargs", True, True, False),
(lambda foo, **kwargs: None, "foo", True, True, False),
])
# yapf: disable
def test_supports_kw(callable,kw_name,requires_kw_only,
allow_var_kwargs,is_supported):
assert supports_kw(
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs
) == is_supported
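
The parametrize table above fully determines the semantics of supports_kw. A sketch consistent with every case follows (not the actual vllm.utils implementation); the parameter named callable deliberately mirrors the test's keyword arguments:

    import inspect

    def supports_kw_sketch(callable, kw_name, requires_kw_only=False,
                           allow_var_kwargs=True):
        # Mirrors the cases above: a keyword-only param always matches; a
        # positional-or-keyword param matches only when requires_kw_only is
        # False; variadic params never match by name.
        params = inspect.signature(callable).parameters
        param = params.get(kw_name)
        if param is not None:
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                return True
            return (param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
                    and not requires_kw_only)
        # Unknown names fall through to a **kwargs catch-all when allowed.
        return allow_var_kwargs and any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
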
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
@@ -1309,6 +1309,7 @@ def schedule(
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
7 changes: 7 additions & 0 deletions vllm/engine/llm_engine.py
@@ -811,6 +811,13 @@ def add_request(
)
processed_inputs = self.input_processor(preprocessed_inputs)

# This is a bit of a hack - copy the mm_processor_kwargs that were
# used in the input processor to the processed output, since these
# kwargs are presumed to be immutable and the values should be aligned
# between the input processor (here) and the input mapper.
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")

self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
9 changes: 9 additions & 0 deletions vllm/entrypoints/llm.py
@@ -472,6 +472,7 @@ def chat(
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[List[Dict[str, Any]]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> List[RequestOutput]:
"""
Generate responses for a chat conversation.
@@ -501,6 +502,8 @@ def chat(
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns:
A list of ``RequestOutput`` objects containing the generated
@@ -522,6 +525,9 @@ def chat(
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()

# NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation, mm_data = parse_chat_messages(
msgs, model_config, tokenizer)

@@ -554,6 +560,9 @@ def chat(
if mm_data is not None:
prompt["multi_modal_data"] = mm_data

if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs

prompts.append(prompt)

return self.generate(
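
A sketch of the new chat() parameter in use, assuming the offline API at this commit; the model name and num_crops value are illustrative:

    from vllm import LLM

    llm = LLM(model="microsoft/Phi-3-vision-128k-instruct",
              trust_remote_code=True)
    outputs = llm.chat(
        [{"role": "user", "content": "Describe this image."}],
        # Applies only to this chat request and, per the docstring added
        # above, only to offline requests.
        mm_processor_kwargs={"num_crops": 16},
    )
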
(Diffs for the remaining changed files are not shown.)
