From 4dd5028ab3fb110ca9fe3b540b12e232de0147d9 Mon Sep 17 00:00:00 2001 From: zzh142857 Date: Mon, 11 Aug 2025 20:09:21 -0700 Subject: [PATCH 1/2] add DP to step3 vision encoder Signed-off-by: zzh142857 --- vllm/model_executor/models/step3_vl.py | 99 +++++++++++++++++++------- 1 file changed, 72 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 41dba312cb42..0053f7f77c27 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -15,13 +15,14 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from transformers import BatchFeature, PretrainedConfig, TensorType - +from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, - RowParallelLinear) + RowParallelLinear, + ReplicatedLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -650,7 +651,8 @@ class Step3VisionAttention(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -659,20 +661,44 @@ def __init__(self, self.scale = self.head_dim**-0.5 - tp_size = get_tensor_model_parallel_world_size() + tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size - self.qkv_proj = QKVParallelLinear(self.embed_dim, - self.head_dim, - self.total_num_heads, - bias=True, - quant_config=quant_config, - prefix=prefix) - self.out_proj = RowParallelLinear(self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=prefix) + + self.q_size = self.num_heads * self.head_dim + + if use_data_parallel: + self.qkv_proj = ReplicatedLinear( + self.embed_dim, + 3*self.q_size, + bias=True, + quant_config=quant_config, + prefix=prefix, + ) + self.out_proj = ReplicatedLinear( + self.total_num_heads * self.head_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=prefix, + ) + else: + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + prefix=prefix, + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=prefix + ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, @@ -712,16 +738,21 @@ class Step3VisionMLP(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + self.fc1 = cls_fc1(config.hidden_size, config.intermediate_size, bias=True, quant_config=quant_config, prefix=prefix) - self.fc2 = RowParallelLinear(config.intermediate_size, + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.fc2 = cls_fc2(config.intermediate_size, config.hidden_size, bias=True, quant_config=quant_config, @@ -739,15 +770,17 @@ class Step3VisionEncoderLayer(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() + self.use_data_parallel = use_data_parallel self.embed_dim = config.hidden_size self.self_attn = Step3VisionAttention(config, quant_config, - prefix=f"{prefix}.self_attn") + prefix=f"{prefix}.self_attn", use_data_parallel=self.use_data_parallel) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp") + self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp", use_data_parallel=self.use_data_parallel) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -767,13 +800,16 @@ class Step3VisionEncoder(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config + self.use_data_parallel = use_data_parallel self.layers = nn.ModuleList([ Step3VisionEncoderLayer(config, quant_config, - prefix=f"{prefix}.layers.{i}") + prefix=f"{prefix}.layers.{i}", + use_data_parallel=self.use_data_parallel) for i in range(config.num_hidden_layers) ]) @@ -792,21 +828,27 @@ class Step3VisionTransformer(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config + self.use_data_parallel = use_data_parallel self.image_size = config.image_size self.embeddings = Step3VisionEmbeddings(config) self.transformer = Step3VisionEncoder(config, quant_config, - prefix=f"{prefix}.transformer") + prefix=f"{prefix}.transformer", + use_data_parallel=self.use_data_parallel) def forward( self, pixel_values: torch.Tensor, ): hidden_states = self.embeddings(pixel_values) - hidden_states = self.transformer(inputs_embeds=hidden_states) + if self.use_data_parallel: + hidden_states = run_dp_sharded_vision_model(hidden_states, self.transformer) + else: + hidden_states = self.transformer(inputs_embeds=hidden_states) return hidden_states @@ -836,13 +878,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.config = config self.multimodal_config = multimodal_config + self.use_data_parallel = (vllm_config.parallel_config. + enable_multimodal_encoder_data_parallel) if multimodal_config.get_limit_per_prompt("image"): self.vision_model = Step3VisionTransformer(config.vision_config, None, prefix=maybe_prefix( prefix, - "vision_model")) + "vision_model"), + use_data_parallel=self.use_data_parallel) self.vit_downsampler = nn.Conv2d( config.vision_config.hidden_size, config.vision_config.output_hidden_size, From c3163863604513284669da8c2a25ed9071ba56cf Mon Sep 17 00:00:00 2001 From: zzh142857 Date: Tue, 12 Aug 2025 06:47:19 -0700 Subject: [PATCH 2/2] fix pre-commit Signed-off-by: zzh142857 --- vllm/model_executor/models/step3_vl.py | 87 ++++++++++++++------------ 1 file changed, 46 insertions(+), 41 deletions(-) diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 0053f7f77c27..f1f38c01b784 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -15,14 +15,14 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from transformers import BatchFeature, PretrainedConfig, TensorType -from vllm.multimodal.utils import run_dp_sharded_vision_model + from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, - RowParallelLinear, - ReplicatedLinear) + ReplicatedLinear, + RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -34,6 +34,7 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -651,7 +652,7 @@ class Step3VisionAttention(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + prefix: str = "", use_data_parallel: bool = False): super().__init__() self.config = config @@ -662,16 +663,16 @@ def __init__(self, self.scale = self.head_dim**-0.5 tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + get_tensor_model_parallel_world_size()) assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.q_size = self.num_heads * self.head_dim - + if use_data_parallel: self.qkv_proj = ReplicatedLinear( self.embed_dim, - 3*self.q_size, + 3 * self.q_size, bias=True, quant_config=quant_config, prefix=prefix, @@ -692,13 +693,11 @@ def __init__(self, quant_config=quant_config, prefix=prefix, ) - self.out_proj = RowParallelLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=prefix - ) + self.out_proj = RowParallelLinear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=prefix) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, @@ -738,7 +737,7 @@ class Step3VisionMLP(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + prefix: str = "", use_data_parallel: bool = False): super().__init__() self.config = config @@ -746,17 +745,17 @@ def __init__(self, cls_fc1 = (ReplicatedLinear if use_data_parallel else ColumnParallelLinear) self.fc1 = cls_fc1(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=prefix) + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=prefix) cls_fc2 = (ReplicatedLinear if use_data_parallel else RowParallelLinear) self.fc2 = cls_fc2(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=prefix) + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=prefix) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -770,17 +769,22 @@ class Step3VisionEncoderLayer(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + prefix: str = "", use_data_parallel: bool = False): super().__init__() self.use_data_parallel = use_data_parallel self.embed_dim = config.hidden_size - self.self_attn = Step3VisionAttention(config, - quant_config, - prefix=f"{prefix}.self_attn", use_data_parallel=self.use_data_parallel) + self.self_attn = Step3VisionAttention( + config, + quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=self.use_data_parallel) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp", use_data_parallel=self.use_data_parallel) + self.mlp = Step3VisionMLP(config, + quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=self.use_data_parallel) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -808,7 +812,7 @@ def __init__(self, self.layers = nn.ModuleList([ Step3VisionEncoderLayer(config, quant_config, - prefix=f"{prefix}.layers.{i}", + prefix=f"{prefix}.layers.{i}", use_data_parallel=self.use_data_parallel) for i in range(config.num_hidden_layers) ]) @@ -828,17 +832,18 @@ class Step3VisionTransformer(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + prefix: str = "", use_data_parallel: bool = False): super().__init__() self.config = config self.use_data_parallel = use_data_parallel self.image_size = config.image_size self.embeddings = Step3VisionEmbeddings(config) - self.transformer = Step3VisionEncoder(config, - quant_config, - prefix=f"{prefix}.transformer", - use_data_parallel=self.use_data_parallel) + self.transformer = Step3VisionEncoder( + config, + quant_config, + prefix=f"{prefix}.transformer", + use_data_parallel=self.use_data_parallel) def forward( self, @@ -846,7 +851,8 @@ def forward( ): hidden_states = self.embeddings(pixel_values) if self.use_data_parallel: - hidden_states = run_dp_sharded_vision_model(hidden_states, self.transformer) + hidden_states = run_dp_sharded_vision_model( + hidden_states, self.transformer) else: hidden_states = self.transformer(inputs_embeds=hidden_states) return hidden_states @@ -882,12 +888,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: enable_multimodal_encoder_data_parallel) if multimodal_config.get_limit_per_prompt("image"): - self.vision_model = Step3VisionTransformer(config.vision_config, - None, - prefix=maybe_prefix( - prefix, - "vision_model"), - use_data_parallel=self.use_data_parallel) + self.vision_model = Step3VisionTransformer( + config.vision_config, + None, + prefix=maybe_prefix(prefix, "vision_model"), + use_data_parallel=self.use_data_parallel) self.vit_downsampler = nn.Conv2d( config.vision_config.hidden_size, config.vision_config.output_hidden_size,