From 7a680db75e2bb432480ae64922353e7790b7f137 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sat, 6 Sep 2025 15:20:55 +0800
Subject: [PATCH] migrate remaining ViT models to use disable_tp

Signed-off-by: Isotr0py
---
 .../models/idefics2_vision_model.py      | 85 +++++--------------
 vllm/model_executor/models/mllama4.py    | 32 +++----
 vllm/model_executor/models/qwen2_5_vl.py | 43 +++++-----
 3 files changed, 54 insertions(+), 106 deletions(-)

diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index 0ca2e9e4bb68..e43f74e769ba 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -31,7 +31,6 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -139,37 +138,22 @@ def __init__(
         assert self.num_heads % tp_size == 0
         self.num_heads_per_partition = self.num_heads // tp_size
 
-        if use_data_parallel:
-            self.q_size = self.num_heads * self.head_dim
-            self.qkv_proj = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.q_size,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv_proj",
-            )
-            self.out_proj = ReplicatedLinear(
-                self.embed_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.out_proj",
-            )
-        else:
-            self.qkv_proj = QKVParallelLinear(
-                self.embed_dim,
-                self.head_dim,
-                self.num_heads,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv_proj",
-            )
-            self.out_proj = RowParallelLinear(
-                self.embed_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.out_proj",
-            )
+        self.qkv_proj = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            self.num_heads,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.out_proj = RowParallelLinear(
+            self.embed_dim,
+            self.embed_dim,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+            disable_tp=use_data_parallel,
+        )
 
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
@@ -198,23 +182,21 @@ def __init__(
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(
+        self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(
+        self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -386,30 +368,6 @@ def forward(
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
 
-    def _consolidate_qkv_weights(
-        self, weights: Iterable[tuple[str, torch.Tensor]]
-    ) -> Iterable[tuple[str, torch.Tensor]]:
-        qkv_idx_mappings = {
-            ".self_attn.q_proj": 0,
-            ".self_attn.k_proj": 1,
-            ".self_attn.v_proj": 2,
-        }
-        qkv_weights = {}
-        for name, loaded_weight in weights:
-            for weight_name, idx in qkv_idx_mappings.items():
-                if weight_name not in name:
-                    continue
-                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
-                if new_name not in qkv_weights:
-                    qkv_weights[new_name] = [None] * 3
-                qkv_weights[new_name][idx] = loaded_weight
-                break
-            else:
-                yield name, loaded_weight
-        for key, weight in qkv_weights.items():
-            qkv_weight = torch.cat(weight, dim=0)
-            yield key, qkv_weight
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -422,9 +380,6 @@ def load_weights(self, weights: Iterable[tuple[str,
         loaded_params: set[str] = set()
         layer_count = len(self.encoder.layers)
 
-        if self.use_data_parallel:
-            weights = self._consolidate_qkv_weights(weights)
-
         for name, loaded_weight in weights:
             # skip pooling header
             if name.startswith("head."):
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index ecbbb5f57bec..6dc77666e158 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -106,22 +106,21 @@ def __init__(
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(
+        self.fc1 = ColumnParallelLinear(
             input_size=input_size,
             output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
-        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
-        self.fc2 = cls_fc2(
+        self.fc2 = RowParallelLinear(
             input_size=intermediate_size,
             output_size=output_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )
         self.activation_fn = nn.GELU()
         self.output_activation = output_activation
@@ -419,20 +418,15 @@ def __init__(
         kernel_size = (kernel_size, kernel_size)
         self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
                                       stride=config.patch_size)
-        params = {
-            "input_size":
-            config.num_channels * kernel_size[0] * kernel_size[1],
-            "output_size": config.hidden_size,
-            "bias": False,
-            "quant_config": quant_config,
-            "prefix": f"{prefix}.linear",
-        }
-        if use_data_parallel:
-            cls = ReplicatedLinear
-        else:
-            cls = ColumnParallelLinear
-            params["gather_output"] = True
-        self.linear = cls(**params)
+        self.linear = ColumnParallelLinear(
+            input_size=config.num_channels * kernel_size[0] * kernel_size[1],
+            output_size=config.hidden_size,
+            bias=False,
+            gather_output=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear",
+            disable_tp=use_data_parallel,
+        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 0a89f86fc738..2c7085dfb7ef 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -49,7 +49,6 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 # yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -468,32 +467,32 @@ def __init__(
         norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.ln_q = norm_layer(context_dim)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.mlp = nn.ModuleList([
-            cls_fc1(self.hidden_size,
-                    self.hidden_size,
-                    bias=True,
-                    quant_config=quant_config,
-                    prefix=f"{prefix}.mlp.0"),
+        self.mlp = nn.Sequential(
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp.0",
+                return_bias=False,
+                disable_tp=use_data_parallel,
+            ),
             nn.GELU(),
-            cls_fc2(self.hidden_size,
-                    d_model,
-                    bias=True,
-                    quant_config=quant_config,
-                    prefix=f"{prefix}.mlp.2"),
-        ])
+            RowParallelLinear(
+                self.hidden_size,
+                d_model,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp.2",
+                return_bias=False,
+                disable_tp=use_data_parallel,
+            ),
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.ln_q(x)
         x = x.view(-1, self.hidden_size)
-
-        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
-        x_parallel, _ = mlp_fc1(x)
-        x_parallel = mlp_act(x_parallel)
-        out, _ = mlp_fc2(x_parallel)
+        out = self.mlp(x)
         return out
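
Note on the qwen2_5_vl hunk: the projector MLP can only be wrapped in nn.Sequential because return_bias=False makes the parallel linear layers return a plain tensor instead of an (output, bias) tuple, which is why the manual unpacking in forward() goes away. Below is a minimal, self-contained sketch of that calling convention using a hypothetical ToyParallelLinear stand-in (plain PyTorch, not vLLM's ColumnParallelLinear/RowParallelLinear):

import torch
import torch.nn as nn


class ToyParallelLinear(nn.Module):
    """Toy stand-in that mirrors the (output, bias) calling convention of
    vLLM's parallel linear layers; it is not the real implementation."""

    def __init__(self, in_features: int, out_features: int,
                 return_bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=True)
        self.return_bias = return_bias

    def forward(self, x: torch.Tensor):
        out = self.linear(x)
        if self.return_bias:
            # Tuple return: call sites must unpack it, e.g. `x, _ = layer(x)`.
            # (The second element stands in for the optional separately
            # returned bias; None here since the bias is already added.)
            return out, None
        # With return_bias=False the layer chains directly inside
        # nn.Sequential, as in the new Qwen2.5-VL patch merger.
        return out


# Old pattern: explicit unpacking, as in the removed qwen2_5_vl forward().
fc1 = ToyParallelLinear(16, 32)
x, _ = fc1(torch.randn(4, 16))

# New pattern: return_bias=False lets nn.Sequential run the whole MLP.
mlp = nn.Sequential(
    ToyParallelLinear(16, 32, return_bias=False),
    nn.GELU(),
    ToyParallelLinear(32, 8, return_bias=False),
)
out = mlp(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 8])

The same idea drives the other hunks: passing disable_tp=use_data_parallel to the parallel linear classes keeps a single construction path, instead of swapping in ReplicatedLinear whenever the vision tower runs data-parallel.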