From 1be55f65c78e30d69b4094c32578f03ec0512a37 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 12 Sep 2024 01:13:25 +0800
Subject: [PATCH 1/4] fix internvl batching

---
 vllm/model_executor/models/internvl.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 81819578a4d8c..40645e287e37f 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -449,11 +449,12 @@ def _parse_and_validate_image_input(
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
-
+            # We need to flatten (B, N, P) to (B*N*P),
+            # so we call flatten_bn twice.
             return InternVLImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True).flatten(0, 1)),
+                    flatten_bn(flatten_bn(pixel_values), concat=True)),
             )
 
         raise AssertionError("This line should be unreachable.")

From 35eec92fa0c760b79d65f233fa5193fbf3b394d7 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 12 Sep 2024 01:30:11 +0800
Subject: [PATCH 2/4] fix internvl multi-images with dynamic size

---
 vllm/model_executor/models/internvl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 40645e287e37f..507d7014714a2 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
         # Add an N dimension for number of images per prompt (currently 1).
         data = data.unsqueeze(0)
     elif is_list_of(data, Image.Image):
+        # we can't stack here because the images may have different num_patches
         data = [
             image_to_pixel_values(img,
                                   image_size,
@@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
                                   max_num,
                                   use_thumbnail=use_thumbnail) for img in data
         ]
-        data = torch.stack(data)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)

From df69d6138515de452cbf54d3314285da21b61042 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 12 Sep 2024 12:38:02 +0800
Subject: [PATCH 3/4] add test

---
 tests/models/test_internvl.py | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py
index fa3369dc53345..2049a5f8d53e4 100644
--- a/tests/models/test_internvl.py
+++ b/tests/models/test_internvl.py
@@ -331,6 +331,41 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
     )
 
 
+@pytest.mark.parametrize("model", ["OpenGVLab/InternVL2-2B"])
+@pytest.mark.parametrize("size_factors", [[0.5, 1.0]])
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+@torch.inference_mode()
+def test_different_num_patches(hf_runner, vllm_runner, image_assets, model,
+                               size_factors, dtype: str, max_tokens: int,
+                               num_logprobs: int) -> None:
+    images = [asset.pil_image.resize((1000, 1000)) for asset in image_assets]
+
+    inputs_batching = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+    inputs_multi_images = [
+        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+         [[rescale_image_size(image, factor) for image in images]
+          for factor in size_factors])
+    ]
+    for inputs in [inputs_batching, inputs_multi_images]:
+        run_test(
+            hf_runner,
+            vllm_runner,
+            inputs,
+            model,
+            dtype=dtype,
+            max_tokens=max_tokens,
+            num_logprobs=num_logprobs,
+            mm_limit=2,
+            tensor_parallel_size=1,
+        )
+
+
 @pytest.mark.parametrize(
     "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
 @pytest.mark.parametrize(

From 580daa02fbf272856f0e0aa53703bfd9d8a96658 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 12 Sep 2024 13:33:26 +0800
Subject: [PATCH 4/4] reduce image size

---
 tests/models/test_internvl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py
index 2049a5f8d53e4..881068b3afe41 100644
--- a/tests/models/test_internvl.py
+++ b/tests/models/test_internvl.py
@@ -340,7 +340,7 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
 def test_different_num_patches(hf_runner, vllm_runner, image_assets, model,
                                size_factors, dtype: str, max_tokens: int,
                                num_logprobs: int) -> None:
-    images = [asset.pil_image.resize((1000, 1000)) for asset in image_assets]
+    images = [asset.pil_image.resize((896, 896)) for asset in image_assets]
 
     inputs_batching = [(
         [prompt for _ in size_factors],
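
Note: the sketch below is a minimal, self-contained illustration of the batching behavior the patches above rely on. The flatten_bn defined here is a simplified stand-in written only for this note (its semantics are assumed, it is not vLLM's actual helper), and the patch counts and 448x448 patch size are made-up example values. It shows why patch 1 flattens (B, N, P) to (B*N*P) by calling flatten_bn twice, and why patch 2 drops the torch.stack call: images tiled into different numbers of patches cannot be stacked into a single tensor.

    import torch


    def flatten_bn(x, *, concat: bool = False):
        # Simplified stand-in (assumed semantics, not vLLM's helper):
        # flatten one level of list nesting (batch B over per-prompt
        # images N); with concat=True, also concatenate the remaining
        # per-image tensors along the patch dimension.
        if isinstance(x, torch.Tensor):
            return x.flatten(0, 1)
        flat = []
        for item in x:
            if isinstance(item, list):
                flat.extend(item)
            else:
                flat.append(item)
        return torch.cat(flat, dim=0) if concat else flat


    # Two prompts: the first carries images tiled into 3 and 5 patches,
    # the second a single image tiled into 7 patches. With dynamic
    # resolution the patch counts differ, so torch.stack(...) on these
    # per-image tensors would raise a shape mismatch error.
    pixel_values = [
        [torch.randn(3, 3, 448, 448), torch.randn(5, 3, 448, 448)],
        [torch.randn(7, 3, 448, 448)],
    ]

    per_image = flatten_bn(pixel_values)              # flat list of 3 tensors
    all_patches = flatten_bn(per_image, concat=True)  # (15, 3, 448, 448)
    assert all_patches.shape == (3 + 5 + 7, 3, 448, 448)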