diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 41dba312cb42..f1f38c01b784 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -33,6 +34,7 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -659,20 +662,42 @@ def __init__(self, self.scale = self.head_dim**-0.5 - tp_size = get_tensor_model_parallel_world_size() + tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size - self.qkv_proj = QKVParallelLinear(self.embed_dim, - self.head_dim, - self.total_num_heads, - bias=True, - quant_config=quant_config, - prefix=prefix) - self.out_proj = RowParallelLinear(self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=prefix) + + self.q_size = self.num_heads * self.head_dim + + if use_data_parallel: + self.qkv_proj = ReplicatedLinear( + self.embed_dim, + 3 * self.q_size, + bias=True, + quant_config=quant_config, + prefix=prefix, + ) + self.out_proj = ReplicatedLinear( + self.total_num_heads * self.head_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=prefix, + ) + else: + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + prefix=prefix, + ) + self.out_proj = RowParallelLinear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=prefix) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, @@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=prefix) - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=prefix) + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + self.fc1 = cls_fc1(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=prefix) + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.fc2 = cls_fc2(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=prefix) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() + self.use_data_parallel = use_data_parallel self.embed_dim = config.hidden_size - self.self_attn = Step3VisionAttention(config, - quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = Step3VisionAttention( + config, + quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=self.use_data_parallel) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp") + self.mlp = Step3VisionMLP(config, + quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=self.use_data_parallel) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config + self.use_data_parallel = use_data_parallel self.layers = nn.ModuleList([ Step3VisionEncoderLayer(config, quant_config, - prefix=f"{prefix}.layers.{i}") + prefix=f"{prefix}.layers.{i}", + use_data_parallel=self.use_data_parallel) for i in range(config.num_hidden_layers) ]) @@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config + self.use_data_parallel = use_data_parallel self.image_size = config.image_size self.embeddings = Step3VisionEmbeddings(config) - self.transformer = Step3VisionEncoder(config, - quant_config, - prefix=f"{prefix}.transformer") + self.transformer = Step3VisionEncoder( + config, + quant_config, + prefix=f"{prefix}.transformer", + use_data_parallel=self.use_data_parallel) def forward( self, pixel_values: torch.Tensor, ): hidden_states = self.embeddings(pixel_values) - hidden_states = self.transformer(inputs_embeds=hidden_states) + if self.use_data_parallel: + hidden_states = run_dp_sharded_vision_model( + hidden_states, self.transformer) + else: + hidden_states = self.transformer(inputs_embeds=hidden_states) return hidden_states @@ -836,13 +884,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.config = config self.multimodal_config = multimodal_config + self.use_data_parallel = (vllm_config.parallel_config. + enable_multimodal_encoder_data_parallel) if multimodal_config.get_limit_per_prompt("image"): - self.vision_model = Step3VisionTransformer(config.vision_config, - None, - prefix=maybe_prefix( - prefix, - "vision_model")) + self.vision_model = Step3VisionTransformer( + config.vision_config, + None, + prefix=maybe_prefix(prefix, "vision_model"), + use_data_parallel=self.use_data_parallel) self.vit_downsampler = nn.Conv2d( config.vision_config.hidden_size, config.vision_config.output_hidden_size,