diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 310d3a3719b6..8744bcbd3a2a 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -1,10 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math import pytest import torch +import torch.multiprocessing as mp -from vllm.model_executor.models.vision import resolve_visual_encoder_outputs +from tests.utils import multi_gpu_test +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.models.vision import ( + get_load_balance_assignment, resolve_visual_encoder_outputs, + run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model) +from vllm.platforms import current_platform +from vllm.utils import get_open_port, update_environment_variables @pytest.mark.parametrize( @@ -33,3 +43,415 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers, post_layer_norm=None, max_possible_layers=max_possible_layers) assert torch.equal(torch.tensor(expected_features), output_tensor) + + +class SimpleLinearModel(torch.nn.Module): + """A simple linear vision model for testing.""" + + def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32): + super().__init__() + self.flatten = torch.nn.Flatten() + self.linear = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x: torch.Tensor): + # Flatten the input and apply linear transformation + x = self.flatten(x) + return self.linear(x) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 4, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, + batch_size: int, master_port: int): + """ + Test that run_dp_sharded_vision_model produces the same results as + calling the model directly. + """ + + # Set random seed for reproducibility + current_platform.seed_everything(0) + + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create a test input tensor + image_input = torch.randn(batch_size, 3, 224, 224) + + # Create a simple linear model + vision_model = SimpleLinearModel() + + # Run the model directly on the full input + with torch.inference_mode(): + direct_output = vision_model(image_input) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_vision_model(image_input, vision_model) + + # Check that the world size is set up correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," + "expected_grouped_sizes_per_gpu,test_description", + [ + # Empty input + ([], 2, [], [0, 0], [0, 0], "empty input"), + + # Fewer samples than GPUs + ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 + ], "fewer samples than GPUs"), + + # Single GPU + ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), + + # Balanced assignment + ([100, 100, 100, 100 + ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), + + # Unbalanced sizes - this one is trickier since the algorithm is greedy + ([1000, 100, 200, 50], 2, [0, 2, 1, 3 + ], [1, 3], [1000, 350], "unbalanced sizes"), + ], +) +def test_get_load_balance_assignment_cases(sizes, num_gpus, + expected_shuffle_indices, + expected_gpu_sample_counts, + expected_grouped_sizes_per_gpu, + test_description): + """Test get_load_balance_assignment with various input cases.""" + result = get_load_balance_assignment(sizes, num_gpus=num_gpus) + (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result + + # Common assertions for all cases + assert len(shuffle_indices) == len(sizes) + assert len(gpu_sample_counts) == num_gpus + assert len(grouped_sizes_per_gpu) == num_gpus + assert sum(gpu_sample_counts) == len(sizes) + + assert shuffle_indices == expected_shuffle_indices + + assert gpu_sample_counts == expected_gpu_sample_counts + assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu + + +class SimpleMRopeVisionModel(torch.nn.Module): + """A simple vision model for testing mrope functionality.""" + + def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): + super().__init__() + self.spatial_merge_size = spatial_merge_size + self.out_hidden_size = out_hidden_size + self.linear = torch.nn.Linear(768, out_hidden_size) + + def forward(self, pixel_values: torch.Tensor, + grid_thw_list: list[list[int]]): + """Simple forward pass that simulates spatial merging.""" + # Apply linear transformation + embeddings = self.linear(pixel_values) + + # Simulate spatial merging by reducing the number of patches + merge_factor = self.spatial_merge_size * self.spatial_merge_size + + # Group patches and merge spatially + merged_embeddings = [] + start_idx = 0 + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + end_idx = start_idx + num_patches + + # Get patches for this image + image_patches = embeddings[start_idx:end_idx] + + # Simulate spatial merging by averaging groups of patches + merged_patches = num_patches // merge_factor + if merged_patches > 0: + # Reshape and average to simulate merging + reshaped = image_patches[:merged_patches * merge_factor].view( + merged_patches, merge_factor, -1) + merged = reshaped.mean(dim=1) + merged_embeddings.append(merged) + + start_idx = end_idx + + if merged_embeddings: + return torch.cat(merged_embeddings, dim=0) + else: + return torch.empty((0, self.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 3, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_mrope_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_mrope_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, + world_size: int, + batch_size: int, + master_port: int): + """ + Test that run_dp_sharded_mrope_vision_model produces the same results as + calling the model directly. + """ + # Set random seed for reproducibility + current_platform.seed_everything(0) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test data + grid_thw_list = [] + pixel_values_list = [] + + for i in range(batch_size): + # Varying image sizes for better testing + t, h, w = 1, 4 + i, 4 + i + grid_thw_list.append([t, h, w]) + + num_patches = t * h * w + # Create random pixel values for this image + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + # Concatenate all pixel values + pixel_values = torch.cat(pixel_values_list, dim=0) + + # Create a simple mrope vision model + vision_model = SimpleMRopeVisionModel() + + # Run the model directly on the full input (only on rank 0) + if local_rank == 0: + with torch.inference_mode(): + direct_output = vision_model(pixel_values, grid_thw_list) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + sharded_output = torch.cat(sharded_output, dim=0) + + # Check that the world size is set up correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Compare outputs (only on rank 0) + if local_rank == 0: + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, + sharded_output, + rtol=1e-5, + atol=1e-5) + + +@multi_gpu_test(num_gpus=2) +def test_run_dp_sharded_mrope_vision_model_empty_input(): + world_size = 2 + mp.spawn( + run_dp_sharded_mrope_vision_model_empty_input_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_empty_input_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with empty input.""" + # Set up distributed environment + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create empty inputs + pixel_values = torch.empty((0, 768)) + grid_thw_list: list[list[int]] = [] + + vision_model = SimpleMRopeVisionModel() + + # Should handle empty input gracefully + with torch.inference_mode(): + output = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + + assert len(output) == 0 + + +@multi_gpu_test(num_gpus=4) +def test_run_dp_sharded_mrope_vision_model_uneven_load(): + world_size = 4 + mp.spawn( + run_dp_sharded_mrope_vision_model_uneven_load_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_uneven_load_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" + # Set up distributed environment + current_platform.seed_everything(123) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create images with very different sizes + grid_thw_list = [ + [1, 2, 2], # Small: 4 patches + [1, 8, 8], # Large: 64 patches + [1, 3, 3], # Medium: 9 patches + ] + + pixel_values_list = [] + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel() + + # Should handle uneven distribution without errors + with torch.inference_mode(): + output_tuple = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + + # Verify output shape is reasonable + merge_factor = vision_model.spatial_merge_size**2 + expected_output_patches = list( + math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) + + for i, output in enumerate(output_tuple): + assert output.shape[0] == expected_output_patches[i] + assert output.shape[1] == vision_model.out_hidden_size + + +@pytest.mark.parametrize("spatial_merge_size", [2, 4]) +def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): + """Test SimpleMRopeVisionModel with different spatial merge sizes.""" + device = current_platform.device_type + + grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images + pixel_values_list = [] + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768, device=device) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel( + spatial_merge_size=spatial_merge_size).to(device) + + with torch.inference_mode(): + output = vision_model(pixel_values, grid_thw_list) + + # Verify output dimensions based on spatial merging + total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) + merge_factor = spatial_merge_size**2 + expected_output_patches = total_patches // merge_factor + + assert output.shape[0] == expected_output_patches + assert output.shape[1] == vision_model.out_hidden_size diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index e1e8282dd66d..f36d94ca0155 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -import math import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -10,22 +9,11 @@ import numpy as np import pytest -import torch -import torch.multiprocessing as mp from PIL import Image, ImageChops -from tests.utils import multi_gpu_test -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange -from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions, - get_load_balance_assignment, - run_dp_sharded_mrope_vision_model, - run_dp_sharded_vision_model) -from vllm.platforms import current_platform -from vllm.utils import get_open_port, update_environment_variables +from vllm.multimodal.utils import MediaConnector, argsort_mm_positions if TYPE_CHECKING: from vllm.multimodal.inputs import MultiModalPlaceholderDict @@ -404,415 +392,3 @@ def test_argsort_mm_positions(): modality_idxs = argsort_mm_positions(mm_positions) assert modality_idxs == expected_modality_idxs - - -class SimpleLinearModel(torch.nn.Module): - """A simple linear vision model for testing.""" - - def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32): - super().__init__() - self.flatten = torch.nn.Flatten() - self.linear = torch.nn.Linear(input_dim, output_dim) - - def forward(self, x: torch.Tensor): - # Flatten the input and apply linear transformation - x = self.flatten(x) - return self.linear(x) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "batch_size", - [ - 1, # Single image - 4, # Small batch - 5, # Odd batch size (for testing padding) - ], -) -def test_run_dp_sharded_vision_model(batch_size: int): - world_size = 2 - # Launch processes - mp.spawn( - run_dp_sharded_vision_model_vs_direct, - args=( - world_size, - batch_size, - get_open_port(), - ), - nprocs=world_size, - ) - - -def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, - batch_size: int, master_port: int): - """ - Test that run_dp_sharded_vision_model produces the same results as - calling the model directly. - """ - - # Set random seed for reproducibility - current_platform.seed_everything(0) - - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - # initialize distributed - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create a test input tensor - image_input = torch.randn(batch_size, 3, 224, 224) - - # Create a simple linear model - vision_model = SimpleLinearModel() - - # Run the model directly on the full input - with torch.inference_mode(): - direct_output = vision_model(image_input) - - # Run the model through the sharded function - with torch.inference_mode(): - sharded_output = run_dp_sharded_vision_model(image_input, vision_model) - - # Check that the world size is set up correctly - assert get_tensor_model_parallel_world_size() == world_size - - # Check that the outputs have the same shape - assert direct_output.shape == sharded_output.shape - - # Check that the outputs are close (they should be identical) - assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) - - -@pytest.mark.parametrize( - "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," - "expected_grouped_sizes_per_gpu,test_description", - [ - # Empty input - ([], 2, [], [0, 0], [0, 0], "empty input"), - - # Fewer samples than GPUs - ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 - ], "fewer samples than GPUs"), - - # Single GPU - ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), - - # Balanced assignment - ([100, 100, 100, 100 - ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), - - # Unbalanced sizes - this one is trickier since the algorithm is greedy - ([1000, 100, 200, 50], 2, [0, 2, 1, 3 - ], [1, 3], [1000, 350], "unbalanced sizes"), - ], -) -def test_get_load_balance_assignment_cases(sizes, num_gpus, - expected_shuffle_indices, - expected_gpu_sample_counts, - expected_grouped_sizes_per_gpu, - test_description): - """Test get_load_balance_assignment with various input cases.""" - result = get_load_balance_assignment(sizes, num_gpus=num_gpus) - (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result - - # Common assertions for all cases - assert len(shuffle_indices) == len(sizes) - assert len(gpu_sample_counts) == num_gpus - assert len(grouped_sizes_per_gpu) == num_gpus - assert sum(gpu_sample_counts) == len(sizes) - - assert shuffle_indices == expected_shuffle_indices - - assert gpu_sample_counts == expected_gpu_sample_counts - assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu - - -class SimpleMRopeVisionModel(torch.nn.Module): - """A simple vision model for testing mrope functionality.""" - - def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): - super().__init__() - self.spatial_merge_size = spatial_merge_size - self.out_hidden_size = out_hidden_size - self.linear = torch.nn.Linear(768, out_hidden_size) - - def forward(self, pixel_values: torch.Tensor, - grid_thw_list: list[list[int]]): - """Simple forward pass that simulates spatial merging.""" - # Apply linear transformation - embeddings = self.linear(pixel_values) - - # Simulate spatial merging by reducing the number of patches - merge_factor = self.spatial_merge_size * self.spatial_merge_size - - # Group patches and merge spatially - merged_embeddings = [] - start_idx = 0 - - for grid_thw in grid_thw_list: - num_patches = math.prod(grid_thw) - end_idx = start_idx + num_patches - - # Get patches for this image - image_patches = embeddings[start_idx:end_idx] - - # Simulate spatial merging by averaging groups of patches - merged_patches = num_patches // merge_factor - if merged_patches > 0: - # Reshape and average to simulate merging - reshaped = image_patches[:merged_patches * merge_factor].view( - merged_patches, merge_factor, -1) - merged = reshaped.mean(dim=1) - merged_embeddings.append(merged) - - start_idx = end_idx - - if merged_embeddings: - return torch.cat(merged_embeddings, dim=0) - else: - return torch.empty((0, self.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "batch_size", - [ - 1, # Single image - 3, # Small batch - 5, # Odd batch size (for testing padding) - ], -) -def test_run_dp_sharded_mrope_vision_model(batch_size: int): - world_size = 2 - # Launch processes - mp.spawn( - run_dp_sharded_mrope_vision_model_vs_direct, - args=( - world_size, - batch_size, - get_open_port(), - ), - nprocs=world_size, - ) - - -def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, - world_size: int, - batch_size: int, - master_port: int): - """ - Test that run_dp_sharded_mrope_vision_model produces the same results as - calling the model directly. - """ - # Set random seed for reproducibility - current_platform.seed_everything(0) - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - # initialize distributed - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create test data - grid_thw_list = [] - pixel_values_list = [] - - for i in range(batch_size): - # Varying image sizes for better testing - t, h, w = 1, 4 + i, 4 + i - grid_thw_list.append([t, h, w]) - - num_patches = t * h * w - # Create random pixel values for this image - image_pixels = torch.randn(num_patches, 768) - pixel_values_list.append(image_pixels) - - # Concatenate all pixel values - pixel_values = torch.cat(pixel_values_list, dim=0) - - # Create a simple mrope vision model - vision_model = SimpleMRopeVisionModel() - - # Run the model directly on the full input (only on rank 0) - if local_rank == 0: - with torch.inference_mode(): - direct_output = vision_model(pixel_values, grid_thw_list) - - # Run the model through the sharded function - with torch.inference_mode(): - sharded_output = run_dp_sharded_mrope_vision_model(vision_model, - pixel_values, - grid_thw_list, - rope_type="rope_3d") - sharded_output = torch.cat(sharded_output, dim=0) - - # Check that the world size is set up correctly - assert get_tensor_model_parallel_world_size() == world_size - - # Compare outputs (only on rank 0) - if local_rank == 0: - # Check that the outputs have the same shape - assert direct_output.shape == sharded_output.shape - # Check that the outputs are close (they should be identical) - assert torch.allclose(direct_output, - sharded_output, - rtol=1e-5, - atol=1e-5) - - -@multi_gpu_test(num_gpus=2) -def test_run_dp_sharded_mrope_vision_model_empty_input(): - world_size = 2 - mp.spawn( - run_dp_sharded_mrope_vision_model_empty_input_worker, - args=(world_size, get_open_port()), - nprocs=world_size, - ) - - -def run_dp_sharded_mrope_vision_model_empty_input_worker( - local_rank: int, world_size: int, master_port: int): - """Test run_dp_sharded_mrope_vision_model with empty input.""" - # Set up distributed environment - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create empty inputs - pixel_values = torch.empty((0, 768)) - grid_thw_list: list[list[int]] = [] - - vision_model = SimpleMRopeVisionModel() - - # Should handle empty input gracefully - with torch.inference_mode(): - output = run_dp_sharded_mrope_vision_model(vision_model, - pixel_values, - grid_thw_list, - rope_type="rope_3d") - - assert len(output) == 0 - - -@multi_gpu_test(num_gpus=4) -def test_run_dp_sharded_mrope_vision_model_uneven_load(): - world_size = 4 - mp.spawn( - run_dp_sharded_mrope_vision_model_uneven_load_worker, - args=(world_size, get_open_port()), - nprocs=world_size, - ) - - -def run_dp_sharded_mrope_vision_model_uneven_load_worker( - local_rank: int, world_size: int, master_port: int): - """Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" - # Set up distributed environment - current_platform.seed_everything(123) - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create images with very different sizes - grid_thw_list = [ - [1, 2, 2], # Small: 4 patches - [1, 8, 8], # Large: 64 patches - [1, 3, 3], # Medium: 9 patches - ] - - pixel_values_list = [] - for grid_thw in grid_thw_list: - num_patches = math.prod(grid_thw) - image_pixels = torch.randn(num_patches, 768) - pixel_values_list.append(image_pixels) - - pixel_values = torch.cat(pixel_values_list, dim=0) - vision_model = SimpleMRopeVisionModel() - - # Should handle uneven distribution without errors - with torch.inference_mode(): - output_tuple = run_dp_sharded_mrope_vision_model(vision_model, - pixel_values, - grid_thw_list, - rope_type="rope_3d") - - # Verify output shape is reasonable - merge_factor = vision_model.spatial_merge_size**2 - expected_output_patches = list( - math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) - - for i, output in enumerate(output_tuple): - assert output.shape[0] == expected_output_patches[i] - assert output.shape[1] == vision_model.out_hidden_size - - -@pytest.mark.parametrize("spatial_merge_size", [2, 4]) -def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): - """Test SimpleMRopeVisionModel with different spatial merge sizes.""" - device = current_platform.device_type - - grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images - pixel_values_list = [] - - for grid_thw in grid_thw_list: - num_patches = math.prod(grid_thw) - image_pixels = torch.randn(num_patches, 768, device=device) - pixel_values_list.append(image_pixels) - - pixel_values = torch.cat(pixel_values_list, dim=0) - vision_model = SimpleMRopeVisionModel( - spatial_merge_size=spatial_merge_size).to(device) - - with torch.inference_mode(): - output = vision_model(pixel_values, grid_thw_list) - - # Verify output dimensions based on spatial merging - total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) - merge_factor = spatial_merge_size**2 - expected_output_patches = total_patches // merge_factor - - assert output.shape[0] == expected_output_patches - assert output.shape[1] == vision_model.out_hidden_size diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 56ec63438690..b088e0c0dd24 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -69,7 +69,6 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -83,7 +82,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 76737a442823..2f0c4240413b 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -34,7 +34,8 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import run_dp_sharded_vision_model + +from .vision import run_dp_sharded_vision_model class Idefics2VisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 892188c04722..2c341d283971 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -28,7 +28,8 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import run_dp_sharded_vision_model + +from .vision import run_dp_sharded_vision_model NORM2FN = { 'rms_norm': RMSNorm, diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index f554077935bf..503627865c4a 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -76,13 +76,13 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config from vllm.utils.tensor_schema import TensorSchema, TensorShape from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix +from .vision import run_dp_sharded_mrope_vision_model # For dummy input only diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 131a66b71323..50521b593786 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -50,7 +50,6 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -58,6 +57,7 @@ from .llama4 import Llama4ForCausalLM from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, merge_multimodal_embeddings) +from .vision import run_dp_sharded_vision_model class Llama4ImagePatchInputs(TensorSchema): diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 73b27572a8eb..b740e6d87b74 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -59,7 +59,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig -from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -74,7 +73,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 88813490c0fb..472e8b061a9e 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -66,7 +66,6 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -78,7 +77,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 98d65dea2739..ee6703f7229e 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -83,7 +83,7 @@ from .qwen3 import Qwen3ForCausalLM, Qwen3Model from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -1214,8 +1214,6 @@ def _process_image_input( else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) if self.use_data_parallel: - from vllm.multimodal.utils import ( - run_dp_sharded_mrope_vision_model) return run_dp_sharded_mrope_vision_model(self.visual, pixel_values, grid_thw_list, @@ -1245,8 +1243,6 @@ def _process_video_input( pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) if self.use_data_parallel: - from vllm.multimodal.utils import ( - run_dp_sharded_mrope_vision_model) return run_dp_sharded_mrope_vision_model(self.visual, pixel_values_videos, grid_thw_list, diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index f667266b77bf..5f6ad5885043 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -31,7 +31,6 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -40,6 +39,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) +from .vision import run_dp_sharded_vision_model class Step3VLImagePixelInputs(TypedDict): diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 81f86db7e187..08ad8fbeb424 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +import math from abc import ABC, abstractmethod -from typing import Final, Generic, Optional, Protocol, TypeVar, Union +from typing import Final, Generic, Literal, Optional, Protocol, TypeVar, Union import torch from transformers import PretrainedConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform @@ -123,3 +128,277 @@ def resolve_visual_encoder_outputs( if post_layer_norm is not None and uses_last_layer: hs_pool[-1] = post_layer_norm(encoder_outputs) return torch.cat(hs_pool, dim=-1) + + +def run_dp_sharded_vision_model(image_input: torch.Tensor, + vision_model: torch.nn.Module) -> torch.Tensor: + """Run a vision model with data parallelism (DP) sharding. The function + will shard the input image tensor on the first dimension and run the vision + model + + Args: + image_input (torch.Tensor): Image input tensor. + vision_model (torch.nn.Module): Vision model. + Returns: + torch.Tensor: Output image embeddings + """ + + num_chunks = image_input.shape[0] + mp_world_size = get_tensor_model_parallel_world_size() + num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size + num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks + pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) + image_input_padded = torch.nn.functional.pad(image_input, pad) + rank = get_tensor_model_parallel_rank() + image_input_per_rank = image_input_padded[rank * + num_chunks_per_rank:(rank + 1) * + num_chunks_per_rank, ...] + + vision_embeddings = vision_model(image_input_per_rank) + # Ensure tensor is contiguous before all_gather + vision_embeddings = vision_embeddings.contiguous() + vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, + dim=0) + vision_embeddings = vision_embeddings[:num_chunks, ...] + return vision_embeddings + + +def get_load_balance_assignment( + sizes: list[int], + num_gpus: int = 2, +) -> tuple[list[int], list[int], list[int]]: + """ + Generate load balancing assignment and metadata + for distributing data across GPUs. + The load is determined by the total image sizes, + not the number of images. + + Args: + sizes: The size of each image + num_gpus: Number of GPUs to balance across + + Returns: + shuffle_indices: + Indices to reorder data for balanced loading + gpu_sample_counts: + Number of samples assigned to each GPU + grouped_sizes_per_gpu: + Total size assigned to each GPU + + Example: + ``` + sizes = [1000, 100, 200, 50] + num_gpus=2 + ``` + + """ + + n_samples = len(sizes) + + # Handle edge cases + if n_samples == 0: + return [], [0] * num_gpus, [0] * num_gpus + + # Use greedy algorithm - balance by total size, not sample count + gpu_assignments = [list[int]() for _ in range(num_gpus)] + gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count + + # Sort indices by size (largest first for better load balancing) + # sizes = [1000, 100, 200, 50] + # large_to_small_indices = [0, 2, 1, 3] + large_to_small_indices = sorted(range(n_samples), + key=lambda i: sizes[i], + reverse=True) + + for idx in large_to_small_indices: + # Find GPU with minimum current load (by total size) + min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) + gpu_assignments[min_gpu].append(idx) + gpu_loads[min_gpu] += sizes[idx] + + # Create shuffle indices and counts + shuffle_indices = list[int]() + gpu_sample_counts = list[int]() + for gpu_id in range(num_gpus): + # GPU_0 = [1000] = [0] + # GPU_1 = [200, 100, 50] = [2, 1, 3] + # shuffle_indices = [0, 2, 1, 3] + shuffle_indices.extend(gpu_assignments[gpu_id]) + # GPU_0 = [1] + # GPU_1 = [3] + # gpu_sample_counts = [1, 3] + gpu_sample_counts.append(len(gpu_assignments[gpu_id])) + + return (shuffle_indices, gpu_sample_counts, gpu_loads) + + +def run_dp_sharded_mrope_vision_model( + vision_model: torch.nn.Module, + pixel_values: torch.Tensor, + grid_thw_list: list[list[int]], + *, + rope_type: Literal["rope_3d", "rope_2d"], +) -> tuple[torch.Tensor, ...]: + """Run a vision model with data parallelism (DP) sharding. + The function will shard the input image tensor on the + first dimension and run the vision model. + This function is used to run the vision model with mrope. + + Args: + vision_model (torch.nn.Module): Vision model. + pixel_values (torch.Tensor): Image/Video input tensor. + grid_thw_list: List of grid dimensions for each image + rope_type: Type of rope used in the vision model. + Different rope types have different dimension to do ViT. + "rope_3d" for 3D rope (e.g., Qwen2.5-VL) + "rope_2d" for 2D rope (e.g., Kimi-VL) + Returns: + torch.Tensor: Output image embeddings + + Example: + ``` + vision_model.out_hidden_size = 64 + vision_model.spatial_merge_size = 2 + pixel_values.shape = (1350, channel) + grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] + tp_size=2 + ``` + + """ + tp_size = get_tensor_model_parallel_world_size() + + # GPU_0 tp_rank_local = 0 + # GPU_1 tp_rank_local = 1 + tp_rank_local = get_tensor_model_parallel_rank() + + # patches_per_image = [1000, 100, 200, 50] + patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] + # patches_per_image = [0, 1000, 1100, 1300, 1350] + cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] + + # Get load balancing assignment with all metadata + # image_to_tp_rank = [0, 2, 1, 3] + # gpu_sample_counts = [1, 3] + # grouped_pixel_values_len = [1000, 350] + (image_to_tp_rank, gpu_sample_counts, + grouped_pixel_values_len) = get_load_balance_assignment( + patches_per_image, tp_size) + + # cu_gpu_sample_counts = [0, 1, 4] + cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] + + # GPU_0 image_idxs_local = [0] + # GPU_1 image_idxs_local = [2, 1, 3] + image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: + cum_gpu_sample_counts[tp_rank_local + + 1]] + + # Get the pixel values for the local images based on the image_idxs_local + if len(image_idxs_local) > 0: + pixel_values_local = torch.cat([ + pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] + for i in image_idxs_local + ]) + else: + # Handle case where this rank has no images + pixel_values_local = torch.empty((0, pixel_values.shape[1]), + device=pixel_values.device, + dtype=pixel_values.dtype) + # embed_dim_reduction_factor = 2 * 2 + if rope_type == "rope_2d": + embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] * + vision_model.merge_kernel_size[1]) + else: + embed_dim_reduction_factor = (vision_model.spatial_merge_size * + vision_model.spatial_merge_size) + + # Find the max length across all ranks + # The output embedding of every DP rank has to be + # padded to this length for tensor_model_parallel_all_gather + # to work + max_len_per_rank = max( + grouped_pixel_values_len) // embed_dim_reduction_factor + local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] + + # Run the vision model on the local pixel_values_local + if rope_type == "rope_2d": + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model( + pixel_values_local, torch.tensor(local_grid_thw_list)) + if isinstance(image_embeds_local, list): + image_embeds_local = torch.cat(image_embeds_local, dim=0) + else: + out_dim = getattr(vision_model.config, "hidden_size", None) + image_embeds_local = torch.empty( + (0, embed_dim_reduction_factor, out_dim), + device=pixel_values.device, + dtype=pixel_values.dtype) + else: + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model(pixel_values_local, + local_grid_thw_list) + else: + # Handle empty case + image_embeds_local = torch.empty((0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + # Pad the output based on max_len_per_rank + # for tensor_model_parallel_all_gather to work + current_len = image_embeds_local.shape[0] + if current_len < max_len_per_rank: + padding_size = max_len_per_rank - current_len + if rope_type == "rope_2d": + padding = torch.empty((padding_size, image_embeds_local.shape[1], + image_embeds_local.shape[2]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + else: + padding = torch.empty((padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + image_embeds_local_padded = torch.cat([image_embeds_local, padding], + dim=0) + else: + image_embeds_local_padded = image_embeds_local + + # Do all_gather to collect embeddings from all ranks + gathered_embeds = tensor_model_parallel_all_gather( + image_embeds_local_padded, dim=0) + + # Remove padding and reconstruct per-rank embeddings + rank_embeddings = list[torch.Tensor]() + for rank in range(tp_size): + start_idx = rank * max_len_per_rank + end_idx = start_idx + (grouped_pixel_values_len[rank] // + embed_dim_reduction_factor) + rank_embeddings.append(gathered_embeds[start_idx:end_idx]) + + patches_per_output_image = [(patch_size // embed_dim_reduction_factor) + for patch_size in patches_per_image] + + # Reconstruct embeddings in the original order + original_order_embeddings = [None] * len(grid_thw_list) + current_idx = 0 + for rank in range(tp_size): + count = gpu_sample_counts[rank] + if count > 0: + # Get images assigned to this rank in shuffled order + # GPU_0 = image_idxs_local [0] + # GPU_1 = image_idxs_local [2, 1, 3] + rank_images = image_to_tp_rank[current_idx:current_idx + count] + + rank_embed = rank_embeddings[rank] + # Split rank embeddings back to individual images + embed_start = 0 + for img_idx in rank_images: + img_patches = patches_per_output_image[img_idx] + original_order_embeddings[img_idx] = rank_embed[ + embed_start:embed_start + img_patches] + embed_start += img_patches + current_idx += count + out_embeddings = tuple(embed for embed in original_order_embeddings + if embed is not None) + assert len(out_embeddings) == len( + original_order_embeddings), "Found unassigned embeddings" + return out_embeddings diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index f4e2ed72e2d7..0f8aeceb3944 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -3,13 +3,11 @@ import asyncio import atexit -import itertools -import math from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from itertools import groupby from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from urllib.parse import ParseResult, urlparse from urllib.request import url2pathname @@ -21,9 +19,6 @@ import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) from .audio import AudioMediaIO from .base import MediaIO @@ -33,12 +28,10 @@ _M = TypeVar("_M") if TYPE_CHECKING: - from .inputs import (BatchedTensorInputs, MultiModalKwargs, - MultiModalKwargsItem, MultiModalKwargsItems, - MultiModalPlaceholderDict) + from .inputs import (BatchedTensorInputs, MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalPlaceholderDict) else: BatchedTensorInputs = Any - MultiModalKwargs = Any MultiModalKwargsItem = Any MultiModalKwargsItems = Any MultiModalPlaceholderDict = Any @@ -93,7 +86,7 @@ def _load_data_url( self, url_spec: ParseResult, media_io: MediaIO[_M], - ) -> _M: + ) -> _M: # type: ignore[type-var] data_spec, data = url_spec.path.split(",", 1) media_type, data_type = data_spec.split(";", 1) @@ -107,7 +100,7 @@ def _load_file_url( self, url_spec: ParseResult, media_io: MediaIO[_M], - ) -> _M: + ) -> _M: # type: ignore[type-var] allowed_local_media_path = self.allowed_local_media_path if allowed_local_media_path is None: raise RuntimeError("Cannot load local files without " @@ -127,7 +120,7 @@ def load_from_url( media_io: MediaIO[_M], *, fetch_timeout: Optional[int] = None, - ) -> _M: + ) -> _M: # type: ignore[type-var] url_spec = urlparse(url) if url_spec.scheme.startswith("http"): @@ -434,280 +427,6 @@ def group_mm_kwargs_by_modality( yield modality, len(items_lst), mm_kwargs_group -def run_dp_sharded_vision_model(image_input: torch.Tensor, - vision_model: torch.nn.Module) -> torch.Tensor: - """Run a vision model with data parallelism (DP) sharding. The function - will shard the input image tensor on the first dimension and run the vision - model - - Args: - image_input (torch.Tensor): Image input tensor. - vision_model (torch.nn.Module): Vision model. - Returns: - torch.Tensor: Output image embeddings - """ - - num_chunks = image_input.shape[0] - mp_world_size = get_tensor_model_parallel_world_size() - num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size - num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks - pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) - image_input_padded = torch.nn.functional.pad(image_input, pad) - rank = get_tensor_model_parallel_rank() - image_input_per_rank = image_input_padded[rank * - num_chunks_per_rank:(rank + 1) * - num_chunks_per_rank, ...] - - vision_embeddings = vision_model(image_input_per_rank) - # Ensure tensor is contiguous before all_gather - vision_embeddings = vision_embeddings.contiguous() - vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, - dim=0) - vision_embeddings = vision_embeddings[:num_chunks, ...] - return vision_embeddings - - -def get_load_balance_assignment( - sizes: list[int], - num_gpus: int = 2, -) -> tuple[list[int], list[int], list[int]]: - """ - Generate load balancing assignment and metadata - for distributing data across GPUs. - The load is determined by the total image sizes, - not the number of images. - - Args: - sizes: The size of each image - num_gpus: Number of GPUs to balance across - - Returns: - shuffle_indices: - Indices to reorder data for balanced loading - gpu_sample_counts: - Number of samples assigned to each GPU - grouped_sizes_per_gpu: - Total size assigned to each GPU - - Example: - ``` - sizes = [1000, 100, 200, 50] - num_gpus=2 - ``` - - """ - - n_samples = len(sizes) - - # Handle edge cases - if n_samples == 0: - return [], [0] * num_gpus, [0] * num_gpus - - # Use greedy algorithm - balance by total size, not sample count - gpu_assignments = [list[int]() for _ in range(num_gpus)] - gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count - - # Sort indices by size (largest first for better load balancing) - # sizes = [1000, 100, 200, 50] - # large_to_small_indices = [0, 2, 1, 3] - large_to_small_indices = sorted(range(n_samples), - key=lambda i: sizes[i], - reverse=True) - - for idx in large_to_small_indices: - # Find GPU with minimum current load (by total size) - min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) - gpu_assignments[min_gpu].append(idx) - gpu_loads[min_gpu] += sizes[idx] - - # Create shuffle indices and counts - shuffle_indices = list[int]() - gpu_sample_counts = list[int]() - for gpu_id in range(num_gpus): - # GPU_0 = [1000] = [0] - # GPU_1 = [200, 100, 50] = [2, 1, 3] - # shuffle_indices = [0, 2, 1, 3] - shuffle_indices.extend(gpu_assignments[gpu_id]) - # GPU_0 = [1] - # GPU_1 = [3] - # gpu_sample_counts = [1, 3] - gpu_sample_counts.append(len(gpu_assignments[gpu_id])) - - return (shuffle_indices, gpu_sample_counts, gpu_loads) - - -def run_dp_sharded_mrope_vision_model( - vision_model: torch.nn.Module, - pixel_values: torch.Tensor, - grid_thw_list: list[list[int]], - *, - rope_type: Literal["rope_3d", "rope_2d"], -) -> tuple[torch.Tensor, ...]: - """Run a vision model with data parallelism (DP) sharding. - The function will shard the input image tensor on the - first dimension and run the vision model. - This function is used to run the vision model with mrope. - - Args: - vision_model (torch.nn.Module): Vision model. - pixel_values (torch.Tensor): Image/Video input tensor. - grid_thw_list: List of grid dimensions for each image - rope_type: Type of rope used in the vision model. - Different rope types have different dimension to do ViT. - "rope_3d" for 3D rope (e.g., Qwen2.5-VL) - "rope_2d" for 2D rope (e.g., Kimi-VL) - Returns: - torch.Tensor: Output image embeddings - - Example: - ``` - vision_model.out_hidden_size = 64 - vision_model.spatial_merge_size = 2 - pixel_values.shape = (1350, channel) - grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] - tp_size=2 - ``` - - """ - tp_size = get_tensor_model_parallel_world_size() - - # GPU_0 tp_rank_local = 0 - # GPU_1 tp_rank_local = 1 - tp_rank_local = get_tensor_model_parallel_rank() - - # patches_per_image = [1000, 100, 200, 50] - patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] - # patches_per_image = [0, 1000, 1100, 1300, 1350] - cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] - - # Get load balancing assignment with all metadata - # image_to_tp_rank = [0, 2, 1, 3] - # gpu_sample_counts = [1, 3] - # grouped_pixel_values_len = [1000, 350] - (image_to_tp_rank, gpu_sample_counts, - grouped_pixel_values_len) = get_load_balance_assignment( - patches_per_image, tp_size) - - # cu_gpu_sample_counts = [0, 1, 4] - cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] - - # GPU_0 image_idxs_local = [0] - # GPU_1 image_idxs_local = [2, 1, 3] - image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: - cum_gpu_sample_counts[tp_rank_local + - 1]] - - # Get the pixel values for the local images based on the image_idxs_local - if len(image_idxs_local) > 0: - pixel_values_local = torch.cat([ - pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] - for i in image_idxs_local - ]) - else: - # Handle case where this rank has no images - pixel_values_local = torch.empty((0, pixel_values.shape[1]), - device=pixel_values.device, - dtype=pixel_values.dtype) - # embed_dim_reduction_factor = 2 * 2 - if rope_type == "rope_2d": - embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] * - vision_model.merge_kernel_size[1]) - else: - embed_dim_reduction_factor = (vision_model.spatial_merge_size * - vision_model.spatial_merge_size) - - # Find the max length across all ranks - # The output embedding of every DP rank has to be - # padded to this length for tensor_model_parallel_all_gather - # to work - max_len_per_rank = max( - grouped_pixel_values_len) // embed_dim_reduction_factor - local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] - - # Run the vision model on the local pixel_values_local - if rope_type == "rope_2d": - if pixel_values_local.shape[0] > 0: - image_embeds_local = vision_model( - pixel_values_local, torch.tensor(local_grid_thw_list)) - if isinstance(image_embeds_local, list): - image_embeds_local = torch.cat(image_embeds_local, dim=0) - else: - out_dim = getattr(vision_model.config, "hidden_size", None) - image_embeds_local = torch.empty( - (0, embed_dim_reduction_factor, out_dim), - device=pixel_values.device, - dtype=pixel_values.dtype) - else: - if pixel_values_local.shape[0] > 0: - image_embeds_local = vision_model(pixel_values_local, - local_grid_thw_list) - else: - # Handle empty case - image_embeds_local = torch.empty((0, vision_model.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype) - - # Pad the output based on max_len_per_rank - # for tensor_model_parallel_all_gather to work - current_len = image_embeds_local.shape[0] - if current_len < max_len_per_rank: - padding_size = max_len_per_rank - current_len - if rope_type == "rope_2d": - padding = torch.empty((padding_size, image_embeds_local.shape[1], - image_embeds_local.shape[2]), - dtype=image_embeds_local.dtype, - device=image_embeds_local.device) - else: - padding = torch.empty((padding_size, image_embeds_local.shape[1]), - dtype=image_embeds_local.dtype, - device=image_embeds_local.device) - image_embeds_local_padded = torch.cat([image_embeds_local, padding], - dim=0) - else: - image_embeds_local_padded = image_embeds_local - - # Do all_gather to collect embeddings from all ranks - gathered_embeds = tensor_model_parallel_all_gather( - image_embeds_local_padded, dim=0) - - # Remove padding and reconstruct per-rank embeddings - rank_embeddings = list[torch.Tensor]() - for rank in range(tp_size): - start_idx = rank * max_len_per_rank - end_idx = start_idx + (grouped_pixel_values_len[rank] // - embed_dim_reduction_factor) - rank_embeddings.append(gathered_embeds[start_idx:end_idx]) - - patches_per_output_image = [(patch_size // embed_dim_reduction_factor) - for patch_size in patches_per_image] - - # Reconstruct embeddings in the original order - original_order_embeddings = [None] * len(grid_thw_list) - current_idx = 0 - for rank in range(tp_size): - count = gpu_sample_counts[rank] - if count > 0: - # Get images assigned to this rank in shuffled order - # GPU_0 = image_idxs_local [0] - # GPU_1 = image_idxs_local [2, 1, 3] - rank_images = image_to_tp_rank[current_idx:current_idx + count] - - rank_embed = rank_embeddings[rank] - # Split rank embeddings back to individual images - embed_start = 0 - for img_idx in rank_images: - img_patches = patches_per_output_image[img_idx] - original_order_embeddings[img_idx] = rank_embed[ - embed_start:embed_start + img_patches] - embed_start += img_patches - current_idx += count - out_embeddings = tuple(embed for embed in original_order_embeddings - if embed is not None) - assert len(out_embeddings) == len( - original_order_embeddings), "Found unassigned embeddings" - return out_embeddings - - def fetch_audio( audio_url: str, audio_io_kwargs: Optional[dict[str, Any]] = None,