From bafac136317c26f0f593d47ba78ea903744fe085 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 11 Sep 2024 10:11:01 +0800 Subject: [PATCH] [Bugfix] Fix InternVL2 vision embeddings process with pipeline parallel (#8299) --- tests/distributed/test_pipeline_parallel.py | 10 ++++++++-- vllm/model_executor/models/internvl.py | 3 ++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 637d2b30f6b1f..d2219eed988e1 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -32,7 +32,9 @@ (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"), + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"), + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "ray"), + (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"), ], ) @fork_new_process_for_each_test @@ -46,6 +48,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, # use half precision for speed and memory savings in CI environment "--dtype", "float16", + "--max-model-len", + "8192", "--pipeline-parallel-size", str(PP_SIZE), "--tensor-parallel-size", @@ -62,7 +66,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, tp_args = [ # use half precision for speed and memory savings in CI environment "--dtype", - "bfloat16", + "float16", + "--max-model-len", + "8192", "--tensor-parallel-size", str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI. "--distributed-executor-backend", diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 0cf63d9e1fb22..81819578a4d8c 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,6 +17,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -480,7 +481,7 @@ def forward( **kwargs: object, ) -> SamplerOutput: image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: + if image_input is not None and get_pp_group().is_first_rank: inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) vision_embeddings = self._process_image_input(image_input)