From 03b94add47164eb1d0609116dcab96ffa3af153d Mon Sep 17 00:00:00 2001 From: Yiwen Chen Date: Fri, 29 Aug 2025 00:52:26 -0700 Subject: [PATCH 1/7] DP support for InternVL vision encoder Signed-off-by: Yiwen Chen --- docs/configuration/optimization.md | 1 + vllm/model_executor/models/intern_vit.py | 105 ++++++++++++++++------- vllm/model_executor/models/internvl.py | 5 +- 3 files changed, 78 insertions(+), 33 deletions(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 2d8cdcc11fa9..6a9caf8a4341 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -178,6 +178,7 @@ Known supported models: - MiniCPM-V-2.5 or above (, ) - Qwen2.5-VL () - Step3 () +- InternVL ## Input Processing diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 58e8163e0b26..fb0b88dcc8bf 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -25,9 +25,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal.utils import run_dp_sharded_vision_model NORM2FN = { 'rms_norm': RMSNorm, @@ -137,6 +139,7 @@ def __init__( *, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -150,8 +153,10 @@ def __init__( f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' f' {self.num_heads}).') - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + self.tp_rank = (0 if use_data_parallel else + get_tensor_model_parallel_rank()) # Additional dummy heads are used to enable TP for common GPU counts. self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim @@ -159,14 +164,23 @@ def __init__( self.tp_size) self.scale = self.head_dim**-0.5 - self.qkv = QKVParallelLinear( - self.embed_dim, - self.head_dim, - num_dummy_heads + self.num_heads, - bias=config.qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv", - ) + if use_data_parallel: + self.qkv = ReplicatedLinear( + self.embed_dim, + 3 * self.head_dim * self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + else: + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + num_dummy_heads + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) self.qk_normalization = config.qk_normalization @@ -178,12 +192,20 @@ def __init__( eps=config.layer_norm_eps, var_hidden_size=self.embed_dim) - self.proj = RowParallelLinear( - self.dummy_dim, - self.embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - ) + if use_data_parallel: + self.proj = ReplicatedLinear( + self.dummy_dim, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + else: + self.proj = RowParallelLinear( + self.dummy_dim, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) @@ -286,21 +308,26 @@ def __init__( config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + self.fc1 = cls_fc1(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.fc2 = cls_fc2(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -319,6 +346,7 @@ def __init__( *, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -329,11 +357,13 @@ def __init__( self.attn = self._init_attn(config, quant_config, num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel) self.mlp = InternMLP(config, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) self.norm2 = NORM2FN[self.norm_type](self.embed_dim, @@ -351,16 +381,18 @@ def _init_attn( *, num_dummy_heads: int, prefix: str = "", + use_data_parallel: bool = False, ): # fallback to sdpa attention if tp unavailable tp_size = get_tensor_model_parallel_world_size() num_heads = config.num_attention_heads - if (num_heads + num_dummy_heads) % tp_size == 0: + if (num_heads + num_dummy_heads) % tp_size == 0 or use_data_parallel: return InternParallelAttention(config, quant_config=quant_config, num_dummy_heads=num_dummy_heads, - prefix=prefix) + prefix=prefix, + use_data_parallel=use_data_parallel) return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads) @@ -387,6 +419,7 @@ def __init__( num_hidden_layers_override: Optional[int] = None, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() @@ -401,7 +434,8 @@ def __init__( InternVisionEncoderLayer(config, quant_config, num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.layers.{layer_idx}") + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(num_hidden_layers) ]) @@ -428,10 +462,12 @@ def __init__( num_hidden_layers_override: Optional[int] = None, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config + self.use_data_parallel = use_data_parallel self.embeddings = InternVisionEmbeddings(config) self.encoder = InternVisionEncoder( @@ -440,6 +476,7 @@ def __init__( num_hidden_layers_override=num_hidden_layers_override, num_dummy_heads=num_dummy_heads, prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel, ) def get_input_embeddings(self): @@ -463,7 +500,11 @@ def forward( raise ValueError( f'wrong pixel_values size: {pixel_values.shape}') - encoder_outputs = self.encoder(inputs_embeds=hidden_states) + if self.use_data_parallel: + encoder_outputs = run_dp_sharded_vision_model( + hidden_states, self.encoder) + else: + encoder_outputs = self.encoder(inputs_embeds=hidden_states) return encoder_outputs diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b09ed7bbe72a..b703e4254801 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -1020,6 +1020,8 @@ def get_video_replacement_internvl(item_idx: int): class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -1038,6 +1040,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.config = config self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self._patch_quant_config(config, quant_config) image_size = config.force_image_size or config.vision_config.image_size @@ -1105,7 +1108,7 @@ def _init_vision_model( quant_config=quant_config, num_hidden_layers_override=num_hidden_layers, prefix=prefix, - ) + use_data_parallel=self.use_data_parallel) else: return InternVisionPatchModel(config.vision_config) From 399c3889a46bede9096521d2e407df073fd24adf Mon Sep 17 00:00:00 2001 From: Yiwen Chen Date: Fri, 29 Aug 2025 00:58:22 -0700 Subject: [PATCH 2/7] add pr reference Signed-off-by: Yiwen Chen --- docs/configuration/optimization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 6a9caf8a4341..b5d7a7d0a123 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -178,7 +178,7 @@ Known supported models: - MiniCPM-V-2.5 or above (, ) - Qwen2.5-VL () - Step3 () -- InternVL +- InternVL () ## Input Processing From 0e94805704a4241a8053ca0d40ac2bc2acdbd4b6 Mon Sep 17 00:00:00 2001 From: Yiwen Chen Date: Fri, 29 Aug 2025 23:10:06 -0700 Subject: [PATCH 3/7] sort models Signed-off-by: Yiwen Chen --- docs/configuration/optimization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index b5d7a7d0a123..5d26095ce89f 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -174,11 +174,11 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u Known supported models: +- InternVL () - Llama4 () - MiniCPM-V-2.5 or above (, ) - Qwen2.5-VL () - Step3 () -- InternVL () ## Input Processing From 25666a292b47f89a77e0eb6cb6d8ac406b4e2daa Mon Sep 17 00:00:00 2001 From: Yiwen Chen Date: Sat, 30 Aug 2025 23:22:47 -0700 Subject: [PATCH 4/7] add logging Signed-off-by: Yiwen Chen --- vllm/model_executor/models/intern_vit.py | 7 +++++-- vllm/model_executor/models/internvl.py | 2 ++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index fb0b88dcc8bf..e16e77042ac9 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -165,6 +165,7 @@ def __init__( self.scale = self.head_dim**-0.5 if use_data_parallel: + print("******using data parallel in InternParallelAttention") self.qkv = ReplicatedLinear( self.embed_dim, 3 * self.head_dim * self.num_heads, @@ -384,10 +385,12 @@ def _init_attn( use_data_parallel: bool = False, ): # fallback to sdpa attention if tp unavailable - tp_size = get_tensor_model_parallel_world_size() + # tp_size = get_tensor_model_parallel_world_size() + tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) num_heads = config.num_attention_heads - if (num_heads + num_dummy_heads) % tp_size == 0 or use_data_parallel: + if (num_heads + num_dummy_heads) % tp_size == 0: return InternParallelAttention(config, quant_config=quant_config, num_dummy_heads=num_dummy_heads, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b703e4254801..01fc7d711d54 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -1040,7 +1040,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.config = config self.multimodal_config = multimodal_config + print(f"******mm_encoder_tp_mode: {multimodal_config.mm_encoder_tp_mode}") self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + print(f"******use_data_parallel: {self.use_data_parallel}") self._patch_quant_config(config, quant_config) image_size = config.force_image_size or config.vision_config.image_size From 3539f14fcdca87608af1260c1be868a62d78d8ab Mon Sep 17 00:00:00 2001 From: Yiwen Chen Date: Sat, 30 Aug 2025 23:49:59 -0700 Subject: [PATCH 5/7] remove logging Signed-off-by: Yiwen Chen --- vllm/model_executor/models/intern_vit.py | 3 +-- vllm/model_executor/models/internvl.py | 2 -- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index e16e77042ac9..0e08cd2b9227 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -165,7 +165,6 @@ def __init__( self.scale = self.head_dim**-0.5 if use_data_parallel: - print("******using data parallel in InternParallelAttention") self.qkv = ReplicatedLinear( self.embed_dim, 3 * self.head_dim * self.num_heads, @@ -387,7 +386,7 @@ def _init_attn( # fallback to sdpa attention if tp unavailable # tp_size = get_tensor_model_parallel_world_size() tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + get_tensor_model_parallel_world_size()) num_heads = config.num_attention_heads if (num_heads + num_dummy_heads) % tp_size == 0: diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 01fc7d711d54..b703e4254801 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -1040,9 +1040,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.config = config self.multimodal_config = multimodal_config - print(f"******mm_encoder_tp_mode: {multimodal_config.mm_encoder_tp_mode}") self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - print(f"******use_data_parallel: {self.use_data_parallel}") self._patch_quant_config(config, quant_config) image_size = config.force_image_size or config.vision_config.image_size From 2544bf6a4c1bbdbb2da299bd00a162edd4d5a6f2 Mon Sep 17 00:00:00 2001 From: Yiwen Chen Date: Tue, 9 Sep 2025 21:46:38 -0700 Subject: [PATCH 6/7] add logging Signed-off-by: Yiwen Chen --- vllm/model_executor/models/intern_vit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 0e08cd2b9227..312656d2e342 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -165,6 +165,7 @@ def __init__( self.scale = self.head_dim**-0.5 if use_data_parallel: + print("*******Using data parallel for attention") self.qkv = ReplicatedLinear( self.embed_dim, 3 * self.head_dim * self.num_heads, From 3951461f4c509c13a30f3c1f54e005395383db4d Mon Sep 17 00:00:00 2001 From: Yiwen Chen Date: Tue, 9 Sep 2025 21:53:50 -0700 Subject: [PATCH 7/7] remove logging Signed-off-by: Yiwen Chen --- vllm/model_executor/models/intern_vit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 312656d2e342..0e08cd2b9227 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -165,7 +165,6 @@ def __init__( self.scale = self.head_dim**-0.5 if use_data_parallel: - print("*******Using data parallel for attention") self.qkv = ReplicatedLinear( self.embed_dim, 3 * self.head_dim * self.num_heads,