Merged
1 change: 1 addition & 0 deletions docs/configuration/optimization.md
@@ -171,6 +171,7 @@ The availability of batch-level DP depends on the model implementation.
Currently, the following models support `mm_encoder_tp_mode="data"`:

- Llama4 (<gh-pr:18368>)
- MiniCPM-V-4 (<gh-pr:23327>)
- Qwen2.5-VL (<gh-pr:22742>)
- Step3 (<gh-pr:22697>)
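
For reference, enabling this mode is a single engine argument. A minimal sketch (the model name is illustrative; `mm_encoder_tp_mode` is forwarded to the multimodal config as shown in this PR):

    from vllm import LLM

    # Keep the language model TP-sharded, but replicate the vision encoder
    # and split image batches across the TP ranks instead.
    llm = LLM(
        model="openbmb/MiniCPM-V-4",  # any model from the list above
        trust_remote_code=True,
        tensor_parallel_size=2,
        mm_encoder_tp_mode="data",
    )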

123 changes: 96 additions & 27 deletions vllm/model_executor/models/idefics2_vision_model.py
@@ -27,13 +27,15 @@
Idefics2Config, Idefics2VisionConfig)

from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import run_dp_sharded_vision_model


class Idefics2VisionEmbeddings(nn.Module):
Expand Down Expand Up @@ -118,6 +120,7 @@ def __init__(
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.config = config
@@ -130,22 +133,43 @@ def __init__(
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
assert self.num_heads % tp_size == 0
self.num_heads_per_partition = self.num_heads // tp_size

if use_data_parallel:
self.q_size = self.num_heads * self.head_dim
self.qkv_proj = ReplicatedLinear(
self.embed_dim,
3 * self.q_size,
[Review thread on `3 * self.q_size`]
Contributor: Sorry to bother you, but I was wondering why it's 3 * self.q_size instead of self.head_dim?
Contributor (@CarrotShoo, Aug 29, 2025): Perhaps you misunderstood the meaning of the parameters? In ReplicatedLinear, 3 * self.q_size is the output_size; self.head_dim in QKVParallelLinear means head_size. Hope my answer has been helpful to you.
Contributor: Thanks for explaining! Got it now.
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = ReplicatedLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
else:
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
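
To make the review thread's answer concrete, a quick shape check with illustrative Idefics2-style sizes (not taken from this PR):

    embed_dim, num_heads = 1152, 16    # hypothetical vision-tower sizes
    head_dim = embed_dim // num_heads  # 72
    q_size = num_heads * head_dim      # 1152; ReplicatedLinear does no TP sharding,
                                       # so the per-rank q_size equals the full q_size
    qkv_out = 3 * q_size               # 3456: Q, K and V fused into one projection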

@@ -169,18 +193,23 @@ def __init__(
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.fc2 = cls_fc2(
config.intermediate_size,
config.hidden_size,
bias=True,
@@ -202,17 +231,21 @@ def __init__(
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.self_attn = Idefics2VisionAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel)
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)

@@ -254,6 +287,7 @@ def __init__(
*,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()

@@ -267,7 +301,8 @@ def __init__(
self.layers = nn.ModuleList([
Idefics2EncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(num_hidden_layers)
])

@@ -301,17 +336,20 @@ def __init__(
num_hidden_layers_override: Optional[int] = None,
require_post_norm: bool = True,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()

embed_dim = config.hidden_size
self.config = config
self.use_data_parallel = use_data_parallel
self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder")
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel)

num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
@@ -340,10 +378,38 @@ def forward(
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes,
)
encoder_outputs = self.encoder(hidden_states)
if self.use_data_parallel:
encoder_outputs = run_dp_sharded_vision_model(
hidden_states, self.encoder)
else:
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
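
For context, run_dp_sharded_vision_model (patched in vllm/multimodal/utils.py below) shards the batch rather than the weights. A rough sketch of its control flow, as assumed here rather than quoted from the implementation:

    # 1. pad the image batch so it splits evenly across tp_size ranks
    # 2. slice out this rank's chunk along dim 0
    # 3. run the replicated vision encoder on the local chunk
    # 4. all-gather the per-rank embeddings along dim 0 and drop the padding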

def _consolidate_qkv_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
qkv_idx_mappings = {
".self_attn.q_proj": 0,
".self_attn.k_proj": 1,
".self_attn.v_proj": 2,
}
qkv_weights = {}
for name, loaded_weight in weights:
for weight_name, idx in qkv_idx_mappings.items():
if weight_name not in name:
continue
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
if new_name not in qkv_weights:
qkv_weights[new_name] = [None] * 3
qkv_weights[new_name][idx] = loaded_weight
break
else:
yield name, loaded_weight
for key, weight in qkv_weights.items():
qkv_weight = torch.cat(weight, dim=0)
yield key, qkv_weight
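
As an illustration of what _consolidate_qkv_weights produces (layer name and shapes are made up):

    # Checkpoint entries consumed, each of shape (embed_dim, embed_dim):
    #   encoder.layers.0.self_attn.q_proj.weight
    #   encoder.layers.0.self_attn.k_proj.weight
    #   encoder.layers.0.self_attn.v_proj.weight
    # Single fused entry yielded, of shape (3 * embed_dim, embed_dim):
    #   encoder.layers.0.self_attn.qkv_proj.weight
    # which matches the ReplicatedLinear(embed_dim, 3 * self.q_size) defined above.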

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@@ -356,6 +422,9 @@ def load_weights(self, weights: Iterable[tuple[str,
loaded_params: set[str] = set()
layer_count = len(self.encoder.layers)

if self.use_data_parallel:
weights = self._consolidate_qkv_weights(weights)

for name, loaded_weight in weights:
# skip the pooling head
if name.startswith("head."):
@@ -373,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str,
continue

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
if weight_name not in name or self.use_data_parallel:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
9 changes: 6 additions & 3 deletions vllm/model_executor/models/minicpmv.py
@@ -778,6 +778,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# and config class
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

self.version = get_version_by_config(self.config)
self.llm = self.init_llm(vllm_config=vllm_config,
@@ -1325,9 +1326,11 @@ def init_vision_module(
prefix: str = "",
) -> nn.Module:
quant_config = self._maybe_ignore_quant_config(quant_config)
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config,
prefix=prefix)
model = Idefics2VisionTransformer(
config.vision_config,
quant_config=quant_config,
prefix=prefix,
use_data_parallel=self.use_data_parallel)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
2 changes: 2 additions & 0 deletions vllm/multimodal/utils.py
@@ -461,6 +461,8 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
num_chunks_per_rank, ...]

vision_embeddings = vision_model(image_input_per_rank)
# Ensure tensor is contiguous before all_gather
vision_embeddings = vision_embeddings.contiguous()
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
dim=0)
vision_embeddings = vision_embeddings[:num_chunks, ...]
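
A note on why the .contiguous() call is needed: collective ops expect contiguous buffers, and a vision tower that ends in a transpose or permute can hand back a non-contiguous view. A minimal demonstration of the property being guarded against:

    import torch

    x = torch.randn(2, 3, 4).transpose(0, 1)  # a view over the original storage
    assert not x.is_contiguous()
    assert x.contiguous().is_contiguous()      # what the fix guarantees before the all-gather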