From dcd616f73878ee30a0d5e5a9f728553b1f0f1a0f Mon Sep 17 00:00:00 2001 From: yZhen Date: Mon, 19 May 2025 12:47:48 -0700 Subject: [PATCH 1/5] enable data parallel for L4 vision encoder Signed-off-by: yzhen --- vllm/config.py | 3 + vllm/engine/arg_utils.py | 7 + vllm/model_executor/models/mllama4.py | 248 +++++++++++++++++++------- 3 files changed, 189 insertions(+), 69 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index f400e9875910..03e9b0ce425c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1790,6 +1790,9 @@ class is dynamically inherited by the worker class. This is used to inject rank: int = 0 """Global rank in distributed setup.""" + enable_vision_encoder_data_parallel: bool = False + """ Use data parallelism instead of tensor parallelism for vision encoder. Only support LLama4 for now""" + @property def world_size_across_dp(self) -> int: """world_size_across_dp is TPxPPxDP, it is the size of the world diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 55553252630f..c94283327b55 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -423,6 +423,8 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location + enable_vision_encoder_data_parallel: bool = ParallelConfig.enable_vision_encoder_data_parallel + def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a @@ -637,6 +639,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **parallel_kwargs["worker_cls"]) parallel_group.add_argument("--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]) + parallel_group.add_argument( + "--enable-vision-encoder-data-parallel", + **parallel_kwargs["enable_vision_encoder_data_parallel"]) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -1078,6 +1083,8 @@ def create_engine_config( distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, + enable_vision_encoder_data_parallel=self. 
+ enable_vision_encoder_data_parallel, ) speculative_config = self.create_speculative_config( diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 8c98492c0bed..e12ceb304688 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -30,10 +30,13 @@ from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope @@ -84,23 +87,28 @@ class Llama4ImagePatchInputs(TypedDict): class Llama4VisionMLP(nn.Module): - def __init__(self, - input_size: int, - intermediate_size: int, - output_size: int, - bias: bool, - output_activation: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + input_size: int, + intermediate_size: int, + output_size: int, + bias: bool, + output_activation: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() - self.fc1 = ColumnParallelLinear( + cls_fc1 = ReplicatedLinear if use_data_parallel else ColumnParallelLinear + self.fc1 = cls_fc1( input_size=input_size, output_size=intermediate_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.fc1", ) - self.fc2 = RowParallelLinear( + cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear + self.fc2 = cls_fc2( input_size=intermediate_size, output_size=output_size, bias=bias, @@ -155,10 +163,12 @@ def pixel_shuffle(input_tensor, shuffle_ratio): int(channels / shuffle_ratio)) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() - reshaped_tensor = reshaped_tensor.view(batch_size, - int(height * shuffle_ratio), - int(width * shuffle_ratio), - int(channels / (shuffle_ratio**2))) + reshaped_tensor = reshaped_tensor.view( + batch_size, + int(height * shuffle_ratio), + int(width * shuffle_ratio), + int(channels / (shuffle_ratio**2)), + ) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() output_tensor = reshaped_tensor.view(batch_size, -1, @@ -173,6 +183,7 @@ def __init__( config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.pixel_shuffle_ratio = config.pixel_shuffle_ratio @@ -186,7 +197,9 @@ def __init__( bias=config.multi_modal_projector_bias, output_activation=True, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: encoded_patches = pixel_shuffle(encoded_patches, @@ -201,10 +214,12 @@ def __init__( config: Llama4VisionConfig, quant_config: Optional[QuantizationConfig], prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.config = config - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = 
config.hidden_size // self.num_heads @@ -217,22 +232,39 @@ def __init__( self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim, self.scaling) - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=True, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) + + if use_data_parallel: + self.qkv_proj = ReplicatedLinear( + self.embed_dim, + self.q_size + 2 * self.kv_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = ReplicatedLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + else: + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=True, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) self.rotary_emb = get_rope( head_size=self.head_dim, @@ -275,22 +307,29 @@ def __init__( config: Llama4VisionConfig, quant_config: Optional[QuantizationConfig], prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.intermediate_size = config.intermediate_size - self.self_attn = Llama4VisionAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") - self.mlp = Llama4VisionMLP(input_size=config.hidden_size, - intermediate_size=config.intermediate_size, - output_size=config.hidden_size, - bias=True, - output_activation=False, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.self_attn = Llama4VisionAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel, + ) + self.mlp = Llama4VisionMLP( + input_size=config.hidden_size, + intermediate_size=config.intermediate_size, + output_size=config.hidden_size, + bias=True, + output_activation=False, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) self.input_layernorm = nn.LayerNorm(config.hidden_size) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size) @@ -322,6 +361,7 @@ def __init__( config: Llama4VisionConfig, quant_config: Optional[QuantizationConfig], prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.config = config @@ -330,6 +370,7 @@ def __init__( config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, ) for layer_idx in range(config.num_hidden_layers) ]) @@ -357,23 +398,33 @@ def forward( class Llama4UnfoldConvolution(nn.Module): - def __init__(self, - config: Llama4VisionConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() kernel_size = config.patch_size if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) - self.linear = ColumnParallelLinear(config.num_channels * - 
kernel_size[0] * kernel_size[1], - config.hidden_size, - bias=False, - quant_config=quant_config, - gather_output=True, - prefix=f"{prefix}.linear") + params = { + "input_size": + config.num_channels * kernel_size[0] * kernel_size[1], + "output_size": config.hidden_size, + "bias": False, + "quant_config": quant_config, + "prefix": f"{prefix}.linear", + } + if use_data_parallel: + cls = ReplicatedLinear + else: + cls = ColumnParallelLinear + params["gather_output"] = True + self.linear = cls(**params) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.unfold(hidden_states) @@ -389,6 +440,7 @@ def __init__( config: Llama4VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.config = config @@ -403,7 +455,9 @@ def __init__( self.patch_embedding = Llama4UnfoldConvolution( config, quant_config=quant_config, - prefix=f"{prefix}.patch_embedding") + prefix=f"{prefix}.patch_embedding", + use_data_parallel=use_data_parallel, + ) self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) @@ -415,11 +469,18 @@ def __init__( self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5) # encoders - self.model = Llama4VisionEncoder(config, - quant_config=quant_config, - prefix=f"{prefix}.model") + self.model = Llama4VisionEncoder( + config, + quant_config=quant_config, + prefix=f"{prefix}.model", + use_data_parallel=use_data_parallel, + ) self.vision_adapter = Llama4VisionPixelShuffleMLP( - config, quant_config, prefix=f"{prefix}.vision_adapter") + config, + quant_config, + prefix=f"{prefix}.vision_adapter", + use_data_parallel=use_data_parallel, + ) def forward( self, @@ -528,8 +589,9 @@ def _call_hf_processor( vision_config = self.info.get_hf_config().vision_config if processed_outputs.get("pixel_values") is not None: - assert "images" in mm_data, \ - "images expected to be in mm_data when pixel_values is present" + assert ( + "images" in mm_data + ), "images expected to be in mm_data when pixel_values is present" images = mm_data["images"] parsed_images = (self._get_data_parser().parse_mm_data({ @@ -546,8 +608,8 @@ def _call_hf_processor( get_best_fit( (image.size[1], image.size[0]), torch.tensor(possible_resolutions), - resize_to_max_canvas=image_processor.resize_to_max_canvas) - for image in parsed_images + resize_to_max_canvas=image_processor.resize_to_max_canvas, + ) for image in parsed_images ] # TODO tile height/width do not necessarily need to match aspect_ratios = [(image_size[0] // tile_size, @@ -659,13 +721,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = ( + vllm_config.parallel_config.enable_vision_encoder_data_parallel) self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config - self.vision_model = Llama4VisionModel(config.vision_config, - None, - prefix=maybe_prefix( - prefix, "vision_model")) + self.vision_model = Llama4VisionModel( + config.vision_config, + None, + prefix=maybe_prefix(prefix, "vision_model"), + use_data_parallel=self.use_data_parallel, + ) self.multi_modal_projector = Llama4MultiModalProjector( self.config, None, @@ -709,7 +775,24 @@ def _process_image_input( flat_data = image_input["flat_data"] patches_per_image = image_input["patches_per_image"].tolist() - vision_embeddings_flat = 
self.vision_model(flat_data) + # shard image input + if self.use_data_parallel: + num_chunks = flat_data.shape[0] + mp_world_size = get_tensor_model_parallel_world_size() + chunk_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size + pad = (0, 0, 0, 0, 0, 0, 0, + chunk_per_rank * mp_world_size - num_chunks) + flat_data_padded = torch.nn.functional.pad(flat_data, pad) + rank = get_tensor_model_parallel_rank() + data_per_rank = flat_data_padded[rank * chunk_per_rank:(rank + 1) * + chunk_per_rank, ...].clone() + vision_embeddings_flat = self.vision_model(data_per_rank) + vision_embeddings_flat = tensor_model_parallel_all_gather( + vision_embeddings_flat, dim=0) + vision_embeddings_flat = vision_embeddings_flat[:num_chunks, ...] + else: + vision_embeddings_flat = self.vision_model(flat_data) + vision_embeddings_flat = self.multi_modal_projector( vision_embeddings_flat) @@ -796,6 +879,30 @@ def get_other_weights() -> Iterable[tuple[str, torch.Tensor]]: return get_prefix_weights(), get_other_weights() + def _consolidate_qkv_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: + qkv_idx_mappings = { + ".self_attn.q_proj": 0, + ".self_attn.k_proj": 1, + ".self_attn.v_proj": 2, + } + qkv_weights = {} + for name, loaded_weight in weights: + for weight_name, idx in qkv_idx_mappings.items(): + if weight_name not in name: + continue + new_name = name.replace(weight_name, ".self_attn.qkv_proj") + if new_name not in qkv_weights: + qkv_weights[new_name] = [None] * 3 + qkv_weights[new_name][idx] = loaded_weight + break + else: + yield name, loaded_weight + for key, weight in qkv_weights.items(): + qkv_weight = torch.cat(weight, dim=0) + yield key, qkv_weight + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -818,9 +925,12 @@ def load_weights(self, weights: Iterable[tuple[str, assert loaded_language_model_params is not None updated_params.update(loaded_language_model_params) + if self.use_data_parallel: + other_weights = self._consolidate_qkv_weights(other_weights) + for name, loaded_weight in other_weights: for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in name or self.use_data_parallel: continue name = name.replace(weight_name, param_name) param = params_dict[name] From c82c621cc2efb11cc6e7d0f1e7f44a60a7fd9238 Mon Sep 17 00:00:00 2001 From: yZhen Date: Mon, 19 May 2025 17:02:04 -0700 Subject: [PATCH 2/5] Fix linter Signed-off-by: yzhen --- vllm/config.py | 3 ++- vllm/engine/arg_utils.py | 3 ++- vllm/model_executor/models/mllama4.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 03e9b0ce425c..a2ed4d1f5d45 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1791,7 +1791,8 @@ class is dynamically inherited by the worker class. This is used to inject """Global rank in distributed setup.""" enable_vision_encoder_data_parallel: bool = False - """ Use data parallelism instead of tensor parallelism for vision encoder. Only support LLama4 for now""" + """ Use data parallelism instead of tensor parallelism for vision encoder. 
+ Only support LLama4 for now""" @property def world_size_across_dp(self) -> int: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c94283327b55..16db231fdc7a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -423,7 +423,8 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location - enable_vision_encoder_data_parallel: bool = ParallelConfig.enable_vision_encoder_data_parallel + enable_vision_encoder_data_parallel: bool = \ + ParallelConfig.enable_vision_encoder_data_parallel def __post_init__(self): # support `EngineArgs(compilation_config={...})` diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index e12ceb304688..bc237f422ef3 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -99,7 +99,8 @@ def __init__( use_data_parallel: bool = False, ): super().__init__() - cls_fc1 = ReplicatedLinear if use_data_parallel else ColumnParallelLinear + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) self.fc1 = cls_fc1( input_size=input_size, output_size=intermediate_size, From 4313092500068153260ff44e7518d75cf845c555 Mon Sep 17 00:00:00 2001 From: yZhen Date: Fri, 30 May 2025 12:12:43 -0700 Subject: [PATCH 3/5] address comment Signed-off-by: yzhen --- vllm/config.py | 2 +- vllm/engine/arg_utils.py | 12 +++++------ vllm/model_executor/models/mllama4.py | 21 ++++--------------- vllm/multimodal/utils.py | 30 +++++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 24 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a2ed4d1f5d45..d7bae52cb721 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1790,7 +1790,7 @@ class is dynamically inherited by the worker class. This is used to inject rank: int = 0 """Global rank in distributed setup.""" - enable_vision_encoder_data_parallel: bool = False + enable_multimodal_encoder_data_parallel: bool = False """ Use data parallelism instead of tensor parallelism for vision encoder. Only support LLama4 for now""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 16db231fdc7a..299c8347f458 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -423,8 +423,8 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location - enable_vision_encoder_data_parallel: bool = \ - ParallelConfig.enable_vision_encoder_data_parallel + enable_multimodal_encoder_data_parallel: bool = \ + ParallelConfig.enable_multimodal_encoder_data_parallel def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -641,8 +641,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument("--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]) parallel_group.add_argument( - "--enable-vision-encoder-data-parallel", - **parallel_kwargs["enable_vision_encoder_data_parallel"]) + "--enable-multimodal-encoder-data-parallel", + **parallel_kwargs["enable_multimodal_encoder_data_parallel"]) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -1084,8 +1084,8 @@ def create_engine_config( distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, - enable_vision_encoder_data_parallel=self. - enable_vision_encoder_data_parallel, + enable_multimodal_encoder_data_parallel=self. 
+ enable_multimodal_encoder_data_parallel, ) speculative_config = self.create_speculative_config( diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index bc237f422ef3..f91b58a72aec 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -30,9 +30,7 @@ from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -52,6 +50,7 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -723,7 +722,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.use_data_parallel = ( - vllm_config.parallel_config.enable_vision_encoder_data_parallel) + vllm_config.parallel_config.enable_multimodal_encoder_data_parallel) self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config @@ -778,19 +777,7 @@ def _process_image_input( # shard image input if self.use_data_parallel: - num_chunks = flat_data.shape[0] - mp_world_size = get_tensor_model_parallel_world_size() - chunk_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size - pad = (0, 0, 0, 0, 0, 0, 0, - chunk_per_rank * mp_world_size - num_chunks) - flat_data_padded = torch.nn.functional.pad(flat_data, pad) - rank = get_tensor_model_parallel_rank() - data_per_rank = flat_data_padded[rank * chunk_per_rank:(rank + 1) * - chunk_per_rank, ...].clone() - vision_embeddings_flat = self.vision_model(data_per_rank) - vision_embeddings_flat = tensor_model_parallel_all_gather( - vision_embeddings_flat, dim=0) - vision_embeddings_flat = vision_embeddings_flat[:num_chunks, ...] + vision_embeddings_flat = run_dp_sharded_vision_model(flat_data, self.vision_model) else: vision_embeddings_flat = self.vision_model(flat_data) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 9ddba67bff70..51b996aef6df 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -12,6 +12,7 @@ import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection +from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_gather from .audio import AudioMediaIO from .base import MediaIO @@ -390,3 +391,32 @@ def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]: return [ list(group) for _, group in groupby(mm_inputs, key=modality_group_func) ] + +def run_dp_sharded_vision_model(image_input: torch.Tensor, + vision_model: torch.nn.Module) -> torch.Tensor: + """Run a vision model with data parallelism (DP) sharding. The function will + shard the input image tensor on the first dimension and run the vision model + + Args: + image_input (torch.Tensor): Image input tensor. + vision_model (torch.nn.Module): Vision model. 
+ + Returns: + torch.Tensor: Output image embeddings + """ + + num_chunks = image_input.shape[0] + mp_world_size = get_tensor_model_parallel_world_size() + num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size + num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks + pad = (0,) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) + image_input_padded = torch.nn.functional.pad(image_input, pad) + rank = get_tensor_model_parallel_rank() + image_input_per_rank = image_input_padded[rank * num_chunks_per_rank:(rank + 1) * + num_chunks_per_rank, ...].clone() + + vision_embeddings = vision_model(image_input_per_rank) + vision_embeddings = tensor_model_parallel_all_gather( + vision_embeddings, dim=0) + vision_embeddings = vision_embeddings[:num_chunks, ...] + return vision_embeddings From bbdb10c8c046c281ff30a09706a6d47505553da7 Mon Sep 17 00:00:00 2001 From: yZhen Date: Fri, 30 May 2025 13:43:47 -0700 Subject: [PATCH 4/5] fix lint Signed-off-by: yzhen --- vllm/model_executor/models/mllama4.py | 7 ++++--- vllm/multimodal/utils.py | 24 +++++++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index f91b58a72aec..58549b10e966 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -721,8 +721,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - self.use_data_parallel = ( - vllm_config.parallel_config.enable_multimodal_encoder_data_parallel) + self.use_data_parallel = (vllm_config.parallel_config. + enable_multimodal_encoder_data_parallel) self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config @@ -777,7 +777,8 @@ def _process_image_input( # shard image input if self.use_data_parallel: - vision_embeddings_flat = run_dp_sharded_vision_model(flat_data, self.vision_model) + vision_embeddings_flat = run_dp_sharded_vision_model( + flat_data, self.vision_model) else: vision_embeddings_flat = self.vision_model(flat_data) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 51b996aef6df..32c70996d1ee 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -12,7 +12,9 @@ import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection -from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_gather +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from .audio import AudioMediaIO from .base import MediaIO @@ -392,10 +394,12 @@ def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]: list(group) for _, group in groupby(mm_inputs, key=modality_group_func) ] + def run_dp_sharded_vision_model(image_input: torch.Tensor, vision_model: torch.nn.Module) -> torch.Tensor: - """Run a vision model with data parallelism (DP) sharding. The function will - shard the input image tensor on the first dimension and run the vision model + """Run a vision model with data parallelism (DP) sharding. The function + will shard the input image tensor on the first dimension and run the vision + model Args: image_input (torch.Tensor): Image input tensor. 
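
For reference, the sharding arithmetic that run_dp_sharded_vision_model performs can be
reproduced without a distributed setup. The sketch below is a single-process stand-in and
is not part of this series: the name sharded_forward_sketch and the explicit loop over
ranks are illustrative only, with the loop playing the role of the per-rank forward pass
followed by tensor_model_parallel_all_gather in the real helper.

import torch

def sharded_forward_sketch(image_input: torch.Tensor,
                           vision_model: torch.nn.Module,
                           world_size: int) -> torch.Tensor:
    # Pad the chunk (batch) dimension up to a multiple of world_size.
    num_chunks = image_input.shape[0]
    chunks_per_rank = (num_chunks + world_size - 1) // world_size
    num_padded = chunks_per_rank * world_size - num_chunks
    # F.pad consumes (left, right) pairs starting from the last dimension,
    # so only the final pair touches the first (chunk) dimension.
    pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded)
    padded = torch.nn.functional.pad(image_input, pad)

    # Each iteration stands in for one tensor-parallel rank running the fully
    # replicated vision encoder on its contiguous slice of chunks.
    per_rank_outputs = []
    for rank in range(world_size):
        shard = padded[rank * chunks_per_rank:(rank + 1) * chunks_per_rank]
        per_rank_outputs.append(vision_model(shard))

    # all_gather(dim=0) across ranks is equivalent to this concatenation;
    # the padded rows are dropped afterwards.
    return torch.cat(per_rank_outputs, dim=0)[:num_chunks]
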
@@ -409,14 +413,16 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor, mp_world_size = get_tensor_model_parallel_world_size() num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks - pad = (0,) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) + pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) image_input_padded = torch.nn.functional.pad(image_input, pad) rank = get_tensor_model_parallel_rank() - image_input_per_rank = image_input_padded[rank * num_chunks_per_rank:(rank + 1) * - num_chunks_per_rank, ...].clone() - + image_input_per_rank = image_input_padded[rank * + num_chunks_per_rank:(rank + 1) * + num_chunks_per_rank, + ...].clone() + vision_embeddings = vision_model(image_input_per_rank) - vision_embeddings = tensor_model_parallel_all_gather( - vision_embeddings, dim=0) + vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, + dim=0) vision_embeddings = vision_embeddings[:num_chunks, ...] return vision_embeddings From 8349179bfce68867eaef42d86ee065786ae60585 Mon Sep 17 00:00:00 2001 From: yzhen Date: Sun, 1 Jun 2025 22:25:53 -0700 Subject: [PATCH 5/5] remove clone Signed-off-by: yzhen --- vllm/multimodal/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 32c70996d1ee..1d838f66f1de 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -418,8 +418,7 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor, rank = get_tensor_model_parallel_rank() image_input_per_rank = image_input_padded[rank * num_chunks_per_rank:(rank + 1) * - num_chunks_per_rank, - ...].clone() + num_chunks_per_rank, ...] vision_embeddings = vision_model(image_input_per_rank) vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
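
A note on weight loading: with the encoder replicated per rank, qkv_proj becomes a plain
ReplicatedLinear, so the separate q/k/v tensors in the checkpoint are concatenated by the
_consolidate_qkv_weights generator added in PATCH 1 rather than loaded shard-by-shard
through stacked_params_mapping. The following toy illustration of that renaming and
concatenation uses made-up shapes and a hypothetical parameter name, not the real Llama 4
vision dimensions.

import torch

# Hypothetical per-projection weights; shapes are placeholders.
q = torch.randn(1408, 1408)
k = torch.randn(1408, 1408)
v = torch.randn(1408, 1408)

name = "vision_model.model.layers.0.self_attn.q_proj.weight"
new_name = name.replace(".self_attn.q_proj", ".self_attn.qkv_proj")
# Rows are stacked in q, k, v order, matching the qkv_idx_mappings indices.
qkv = torch.cat([q, k, v], dim=0)
print(new_name, tuple(qkv.shape))  # -> (... .qkv_proj.weight, (4224, 1408))

The feature itself is opt-in: after this series it is enabled with the new
--enable-multimodal-encoder-data-parallel flag (for example,
vllm serve <llama-4 checkpoint> --tensor-parallel-size 8 --enable-multimodal-encoder-data-parallel),
and only the Llama 4 vision encoder honors it for now.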