# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from functools import reduce

import pytest
import torch
import torch.multiprocessing as mp

from tests.utils import multi_gpu_test
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables


def get_open_port():
    """Get an available port for distributed testing."""
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def run_parallel_op_test_worker(local_rank: int, world_size: int,
                                master_port: int, test_config: dict, fn):
    """Worker function that runs on each GPU process."""
    # Set up distributed environment
    device = f"cuda:{local_rank}"
    current_platform.set_device(device)
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # Initialize distributed state and tensor model parallelism
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Seed for reproducibility, then enable batch-invariant kernels
    current_platform.seed_everything(42)
    init_batch_invariance()

    # Run the test function for the specific op under test
    fn(local_rank, world_size, test_config)


class ULPChecker:
    """Bit-exact comparison helper measuring the distance between
    floating-point tensors in ULPs (units in the last place)."""

    FP_SPECS = {
        torch.float8_e4m3fn: {
            'mantissa_bits': 3,
            'exponent_bits': 4,
            'total_bits': 8,
            'int_dtype': torch.uint8
        },
        torch.float8_e5m2: {
            'mantissa_bits': 2,
            'exponent_bits': 5,
            'total_bits': 8,
            'int_dtype': torch.uint8
        },
        torch.bfloat16: {
            'mantissa_bits': 7,
            'exponent_bits': 8,
            'total_bits': 16,
            'int_dtype': torch.int16
        },
        torch.float16: {
            'mantissa_bits': 10,
            'exponent_bits': 5,
            'total_bits': 16,
            'int_dtype': torch.int16
        },
        torch.float32: {
            'mantissa_bits': 23,
            'exponent_bits': 8,
            'total_bits': 32,
            'int_dtype': torch.int32
        },
        torch.float64: {
            'mantissa_bits': 52,
            'exponent_bits': 11,
            'total_bits': 64,
            'int_dtype': torch.int64
        },
    }

    @staticmethod
    def to_int_bits(tensor: torch.Tensor) -> torch.Tensor:
        """Reinterpret the tensor's raw bits as a same-width integer."""
        dtype = tensor.dtype
        if dtype not in ULPChecker.FP_SPECS:
            raise ValueError(f"Unsupported dtype: {dtype}")

        spec = ULPChecker.FP_SPECS[dtype]
        int_dtype = spec['int_dtype']

        return tensor.view(int_dtype)

    @staticmethod
    def ulp_distance_int(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        if a.dtype != b.dtype:
            raise ValueError(f"Dtype mismatch: {a.dtype} vs {b.dtype}")

        if a.shape != b.shape:
            raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")

        spec = ULPChecker.FP_SPECS[a.dtype]
        total_bits = spec['total_bits']

        # Widen to int64 and mask down to the unsigned bit pattern so the
        # ordering arithmetic below cannot overflow the narrow int dtypes
        # (e.g. sign_bit = 1 << 15 does not fit in int16).
        bit_mask = (1 << total_bits) - 1
        a_int = ULPChecker.to_int_bits(a).to(torch.int64) & bit_mask
        b_int = ULPChecker.to_int_bits(b).to(torch.int64) & bit_mask

        sign_bit = 1 << (total_bits - 1)

        # Map raw bit patterns onto a monotonic integer line so adjacent
        # representable floats differ by exactly 1.
        a_ordered = torch.where(
            (a_int & sign_bit) != 0,
            sign_bit - (a_int & ~sign_bit),  # Negative: flip magnitude bits
            a_int + sign_bit  # Positive: offset by sign bit
        )
        b_ordered = torch.where((b_int & sign_bit) != 0,
                                sign_bit - (b_int & ~sign_bit),
                                b_int + sign_bit)

        ulp_dist = torch.abs(a_ordered - b_ordered)
        return ulp_dist


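# Illustrative sanity check, added as a sketch: incrementing the raw bit
# pattern of a positive bfloat16 value yields the adjacent representable
# value, which should sit exactly 1 ULP away, while identical tensors should
# be 0 ULPs apart. Runs on CPU, independent of the multi-GPU harness below.
def test_ulp_checker_adjacent_values():
    x = torch.tensor([1.0], dtype=torch.bfloat16)
    y = (x.view(torch.int16) + 1).view(torch.bfloat16)
    assert ULPChecker.ulp_distance_int(x, x).item() == 0
    assert ULPChecker.ulp_distance_int(x, y).item() == 1

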
def create_needle_tensor(
        batch_size: int,
        shape: list[int],
        device: torch.device,
        dtype: torch.dtype,
        needle_idx: int = 0) -> torch.Tensor:
    """Build a random batch with one deterministic "needle" row.

    The needle row is identical regardless of batch size, so any divergence
    in its output exposes batch-size dependence in the op under test.
    """
    input_tensor = torch.randn(batch_size, *shape, device=device, dtype=dtype)

    numel = reduce(lambda x, y: x * y, shape)
    needle_pattern = torch.sin(
        torch.arange(numel, device=device).float().view(*shape) *
        0.1).to(dtype)

    assert needle_idx < input_tensor.shape[0]
    input_tensor[needle_idx] = needle_pattern

    return input_tensor


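# Quick self-check, added as a sketch: the needle row must be bitwise
# identical no matter which batch it is embedded in, so any divergence seen
# later can only come from the op under test. CPU-only.
def test_needle_rows_identical_across_batch_sizes():
    a = create_needle_tensor(4, [8], torch.device("cpu"), torch.float32, 1)
    b = create_needle_tensor(16, [8], torch.device("cpu"), torch.float32, 3)
    assert torch.equal(a[1], b[3])

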
def verify(outputs: list[torch.Tensor],
           needle_idxs: list[int]) -> bool:
    """Check that the needle rows of all outputs are bitwise identical."""
    if len(outputs) < 2:
        return True

    needle_outputs = []
    for output, needle_idx in zip(outputs, needle_idxs):
        needle_outputs.append(output[needle_idx])

    reference = needle_outputs[0]
    for i, needle_output in enumerate(needle_outputs[1:], 1):
        dist_t = ULPChecker.ulp_distance_int(reference, needle_output)
        if torch.max(dist_t) != 0:
            print(f"Needle consistency failed at batch size comparison {i}")
            print(f"Max difference (ULP): {torch.max(dist_t)}")
            print(f"Max difference: {torch.max(reference - needle_output)}")
            return False

    return True


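# Small self-check, added as a sketch: verify() should accept bitwise
# identical needle rows and flag even a single-ULP perturbation. CPU-only.
def test_verify_flags_single_ulp_divergence():
    base = torch.ones(4, 8, dtype=torch.float32)
    perturbed = base.clone()
    perturbed[0] = (perturbed[0].view(torch.int32) + 1).view(torch.float32)
    assert verify([base, base.clone()], [0, 0])
    assert not verify([base, perturbed], [0, 0])

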
def validate(func, batch_sizes, shape, device, dtype):
    """Run func over several batch sizes and assert needle consistency."""
    random.seed(123)
    outputs = []
    needle_idxs = []

    for batch_size in batch_sizes:
        needle_idx = random.randint(0, batch_size - 1)
        input_tensor = create_needle_tensor(batch_size, shape, device, dtype,
                                            needle_idx)

        with torch.no_grad():
            output = func(input_tensor)
            assert isinstance(output, torch.Tensor)
            outputs.append(output)
            needle_idxs.append(needle_idx)

    assert verify(outputs, needle_idxs), \
        "Needle consistency failed"


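# End-to-end smoke test of the harness, added as a sketch: elementwise
# scaling is trivially batch-invariant, so validate() should pass on CPU
# without any distributed setup.
def test_validate_harness_on_elementwise_op():
    validate(lambda x: x * 2.0, [1, 8, 32], (16, ), torch.device("cpu"),
             torch.float32)

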
def _test_column_parallel_linear(local_rank: int, world_size: int,
                                 config: dict):
    """Test ColumnParallelLinear with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    batch_sizes = [1, 8, 32]
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    seq_len = 4096
    input_size = hidden_size
    output_size = hidden_size * 2
    layer = ColumnParallelLinear(
        input_size=input_size,
        output_size=output_size,
        bias=True,
        gather_output=False,
        params_dtype=dtype,
    )
    layer = layer.to(device)
    validate(lambda x: layer(x)[0], batch_sizes, (seq_len, hidden_size),
             device, dtype)


def _test_row_parallel_linear(local_rank: int, world_size: int, config: dict):
    """Test RowParallelLinear with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    batch_sizes = [1, 8, 32]
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    seq_len = 4096
    input_size = hidden_size * 2
    output_size = hidden_size
    layer = RowParallelLinear(
        input_size=input_size,
        output_size=output_size,
        bias=True,
        reduce_results=True,
        params_dtype=dtype,
    )
    layer = layer.to(device)
    # Each rank consumes its own shard of the input dimension.
    validate(lambda x: layer(x)[0], batch_sizes,
             (seq_len, input_size // world_size), device, dtype)


def _test_rms_norm(local_rank: int, world_size: int, config: dict):
    """Test RMSNorm with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    batch_sizes = [1, 32, 1024]

    layer = RMSNorm(hidden_size, eps=1e-6)
    layer = layer.to(device).to(dtype)
    validate(layer, batch_sizes, (hidden_size, ), device, dtype)


def _test_fused_rms_norm(local_rank: int, world_size: int, config: dict):
    """Test RMSNorm with a fused residual add, with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    batch_sizes = [1, 32, 1024]

    layer = RMSNorm(hidden_size, eps=1e-6)
    layer = layer.to(device).to(dtype)
    # Passing a residual selects the fused add + norm path; compare the
    # normalized output (index 0).
    validate(lambda x: layer(x, x)[0], batch_sizes, (hidden_size, ), device,
             dtype)


def _test_fused_moe(local_rank: int, world_size: int, config: dict):
    """Test FusedMoE with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    batch_sizes = [1, 8, 32]

    # MoE configuration parameters
    num_experts = 8
    top_k = 2
    intermediate_size = hidden_size * 4

    from vllm.config import VllmConfig
    from vllm.forward_context import get_forward_context, set_forward_context
    from vllm.model_executor.layers.fused_moe import FusedMoE

    vllm_config = VllmConfig()

    # Create FusedMoE layer similar to how it's used in models
    layer = FusedMoE(
        num_experts=num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        params_dtype=dtype,
        reduce_results=True,
        renormalize=True,
        use_grouped_topk=False,
    )
    layer = layer.to(device)

    # Fixed router projection: deriving logits from the hidden states
    # (rather than sampling fresh ones per call) keeps the needle row's
    # routing identical across batch sizes.
    router_weight = torch.randn(num_experts,
                                hidden_size,
                                device=device,
                                dtype=dtype)

    # Test function that takes hidden states and derives router logits
    def test_func(hidden_states):
        # hidden_states: (num_tokens, hidden_size)
        router_logits = hidden_states @ router_weight.t()

        # Set forward context with minimal required parameters
        # attn_metadata can be None for testing purposes
        with set_forward_context(attn_metadata=None,
                                 vllm_config=vllm_config,
                                 num_tokens=hidden_states.shape[0]):
            fwdctx = get_forward_context()
            # Register the layer under its (empty) prefix so the MoE
            # forward can look itself up in the forward context.
            fwdctx.no_compile_layers[''] = layer
            return layer(hidden_states, router_logits)

    validate(test_func, batch_sizes, (hidden_size, ), device, dtype)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("reduction_size", [1, 5, 1024, 1024 + 1])
@pytest.mark.parametrize("func", [
    _test_column_parallel_linear,
    _test_row_parallel_linear,
    _test_rms_norm,
    _test_fused_rms_norm,
    _test_fused_moe,
])
def test_parallel_reduction_batch_invariance(world_size: int,
                                             dtype: torch.dtype,
                                             reduction_size: int, func):
    """Test batch invariance of parallel operators on 2 GPUs."""
    test_config = {
        "dtype": dtype,
        "reduction_size": reduction_size,
    }

    mp.spawn(run_parallel_op_test_worker,
             args=(world_size, get_open_port(), test_config, func),
             nprocs=world_size)