diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6cbc25b4b3bf..105eca371ff3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -412,6 +412,7 @@ steps: - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_aot_compile.py + - pytest -v -s compile/test_compile_ranges.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py new file mode 100644 index 000000000000..0d1ec49e3f41 --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -0,0 +1,1281 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark for FlashInfer fused collective operations vs standard operations. + +This benchmark compares: +1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) +2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations + +Usage with torchrun: + torchrun --nproc_per_node=2 benchmark_fused_collective.py + +""" + +import argparse +import itertools +import os +import time + +import torch # type: ignore +import torch.distributed as dist # type: ignore + +from vllm.distributed import ( + get_tp_group, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import ( + graph_capture, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm # noqa +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 # noqa +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape # noqa +from vllm.platforms import current_platform # noqa + +RMS_NORM_OP = torch.ops._C.rms_norm +FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm +RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant +FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant +) +SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant + +logger = init_logger(__name__) + +# Try to import FlashInfer +try: + import flashinfer.comm as flashinfer_comm # type: ignore + + if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): + flashinfer_comm = None + logger.warning( + "FlashInfer comm module found but missing trtllm_allreduce_fusion" + ) +except ImportError: + flashinfer_comm = None + logger.warning("FlashInfer not found, only benchmarking standard operations") + +# Constants +FP8_DTYPE = current_platform.fp8_dtype() +MiB = 1024 * 1024 + +# FlashInfer max sizes per world size +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes +# use --disable-oneshot to disable oneshot mode for very large input sizes +_FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 64 * MiB, # 64MB + 8: 64 * MiB, # 64MB +} + +# Global workspace tensor for FlashInfer +_FI_WORKSPACE_TENSOR = None + + +def setup_flashinfer_workspace( + world_size: int, + rank: int, + hidden_dim: int, + max_token_num: int, + use_fp32_lamport: bool = False, +): + """Setup FlashInfer workspace for fused allreduce operations.""" + global _FI_WORKSPACE_TENSOR + + if flashinfer_comm is None: + return None, None + + if world_size not in _FI_MAX_SIZES: + logger.warning("FlashInfer not supported for world size %s", world_size) + return None, None + + try: + # Create IPC workspace + ipc_handles, 
workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=rank, + tp_size=world_size, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + group=get_tp_group().device_group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + _FI_WORKSPACE_TENSOR = workspace_tensor + return ipc_handles, workspace_tensor + except Exception as e: + logger.error("Failed to setup FlashInfer workspace: %s", e) + return None, None + + +def cleanup_flashinfer_workspace(ipc_handles): + """Cleanup FlashInfer workspace.""" + if flashinfer_comm is None or ipc_handles is None: + return + + try: + group = get_tp_group().device_group + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group) + except Exception as e: + logger.error("Failed to cleanup FlashInfer workspace: %s", e) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def __init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.max_token_num = max_token_num + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + } + + +def flashinfer_fused_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + allreduce_params: "FlashInferFusedAllReduceParams", + use_oneshot: bool, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm operation.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + allreduce_out=None, + quant_out=None, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=None, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + use_oneshot: bool = True, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + 
residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + quant_out: torch.Tensor, + use_oneshot: bool, + output_scale: torch.Tensor, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=output_scale, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=input_global_scale, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def standard_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + norm_out: torch.Tensor | None = None, +): + """Standard allreduce + rmsnorm operations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Then RMS norm + if residual is not None: + # Fused add + RMS norm + FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) + else: + # Just RMS norm + if norm_out is None: + norm_out = torch.empty_like(allreduce_out) + RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) + + +def standard_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, +): + """Standard allreduce + rmsnorm + FP8 quantization.""" + if quant_out is None: + quant_out = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then fused RMS norm + FP8 quantization + if residual is not None: + FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP( + quant_out, allreduce_out, residual, rms_gamma, scale_factor, rms_eps + ) + return quant_out, residual + else: + RMS_NORM_STATIC_FP8_QUANT_OP( + quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps + ) + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: 
torch.Tensor | None = None, +): + """Standard allreduce + rmsnorm + FP4 quantization.""" + + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then RMS norm + if residual is not None: + FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) + quant_input = allreduce_out + residual_out = residual + else: + if norm_out is None: + norm_out = torch.empty_like(allreduce_out) + RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) + quant_input = norm_out + residual_out = allreduce_out + + # Finally FP4 quantization + SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) + if residual is not None: + return quant_out, residual_out, output_scale + else: + return quant_out, norm_out + + +def standard_allreduce_rmsnorm_native( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rmsnorm_layer: RMSNorm, + norm_out: torch.Tensor | None = None, +): + """Standard allreduce + rmsnorm operations using native RMSNorm forward.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Apply native RMSNorm + if residual is not None: + result = rmsnorm_layer.forward_native(allreduce_out, residual) + return result # Returns (norm_out, residual_out) + else: + result = rmsnorm_layer.forward_native(allreduce_out) + return result # Returns norm_out + + +def standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rmsnorm_layer: RMSNorm, + quant_fp8_layer: QuantFP8, + scale_factor: torch.Tensor, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, +): + """Standard allreduce + rmsnorm + FP8 quantization using native implementations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + residual_out = allreduce_out + + # Apply native FP8 quantization + quant_out, _ = quant_fp8_layer.forward_native(norm_out, scale=scale_factor) + + if residual is not None: + return quant_out, residual_out + else: + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: torch.Tensor | None = None, +): + """Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + quant_input = norm_out + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + quant_input = norm_out + residual_out = allreduce_out + + # Apply FP4 quantization (still using fused CUDA op as there's no native FP4) + SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) + + if residual is not None: + return quant_out, residual_out, output_scale + else: + return quant_out, norm_out + + +# Compiled versions of native functions +@torch.compile +def standard_allreduce_rmsnorm_native_compiled( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rmsnorm_layer: RMSNorm, + norm_out: torch.Tensor | None = None, +): + """Compiled version of standard allreduce + 
rmsnorm.""" + return standard_allreduce_rmsnorm_native( + input_tensor, residual, rmsnorm_layer, norm_out + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp8_quant_native_compiled( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rmsnorm_layer: RMSNorm, + quant_fp8_layer: QuantFP8, + scale_factor: torch.Tensor, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, +): + """Compiled version of standard allreduce + rmsnorm + FP8 quantization.""" + return standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor, + residual, + rmsnorm_layer, + quant_fp8_layer, + scale_factor, + norm_out, + quant_out, + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp4_quant_native_compiled( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: torch.Tensor | None = None, +): + """Compiled version of standard allreduce + rmsnorm + FP4 quantization.""" + return standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor, + residual, + rmsnorm_layer, + input_global_scale, + quant_out, + output_scale, + norm_out, + ) + + +def create_test_tensors( + seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True +): + """Create test tensors for benchmarking.""" + input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype) + residual = ( + torch.randn_like(input_tensor) + if use_residual + else torch.zeros_like(input_tensor) + ) + rms_gamma = torch.ones(hidden_dim, dtype=dtype) + norm_out = None if use_residual else torch.empty_like(input_tensor) + + # Quantization scales + scale_fp8 = torch.tensor(1.0, dtype=torch.float32) + scale_fp4 = torch.tensor(1.0, dtype=torch.float32) + quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) + fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8) + fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + + return ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) + + +def benchmark_operation( + operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs +): + """Benchmark a single operation using CUDA graphs.""" + # Warmup before graph capture + for _ in range(warmup): + operation_func(*args, **kwargs) + torch.cuda.synchronize() + + # Create CUDA graph + graph = torch.cuda.CUDAGraph() + num_op_per_cudagraph = 10 + + # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe + device = torch.device(f"cuda:{torch.cuda.current_device()}") + with graph_capture(device=device), torch.cuda.graph(graph): + for _ in range(num_op_per_cudagraph): + operation_func(*args, **kwargs) + + # Graph warmup + torch.cuda.synchronize() + for _ in range(warmup): + graph.replay() + + # Benchmark with CUDA graph + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(trials // num_op_per_cudagraph): + # operation_func(*args, **kwargs) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter() + + avg_time_ms = ((end_time - start_time) / trials) * 1000 + return avg_time_ms + + +def run_benchmarks( + seq_len: int, + hidden_dim: int, + dtype: torch.dtype, + use_residual: bool, + allreduce_params: FlashInferFusedAllReduceParams | None, + quant_mode: str = "all", + disable_oneshot: bool = False, +): + """Run all benchmarks 
for given configuration. + + Args: + quant_mode: "none", "fp8_only", "fp4_only", or "all" + """ + ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual) + + rms_eps = 1e-6 + results = {} + + # Create RMSNorm and QuantFP8 layers once for native benchmarks + rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps, dtype=dtype) + rmsnorm_layer.weight.data = rms_gamma + quant_fp8_layer = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) + + if quant_mode in ["all", "none"]: + # Standard AllReduce + RMSNorm + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + ) + results["standard_allreduce_rmsnorm"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results["standard_allreduce_rmsnorm"] = float("inf") + + # Standard AllReduce + RMSNorm Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms + except Exception as e: + logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf") + + if quant_mode in ["all", "fp8_only"]: + # Standard AllReduce + RMSNorm + FP8 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + quant_fp8_layer=quant_fp8_layer, + scale_factor=scale_fp8, + norm_out=norm_out, + quant_out=quant_out_fp8, + ) + 
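# Illustrative note, not part of the original patch: benchmark_operation()
# (defined above) returns the average latency per call in milliseconds. It
# captures num_op_per_cudagraph = 10 back-to-back calls into one CUDA graph
# under vLLM's graph_capture() so that tensor_model_parallel_all_reduce is
# graph-safe, then replays the graph trials // 10 times, e.g. with the
# default trials=20:
#
#   avg_time_ms = ((end_time - start_time) / trials) * 1000
#   # 2 replays x 10 captured ops = 20 measured calls in total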
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float( + "inf" + ) + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float( + "inf" + ) + + if quant_mode in ["all", "fp4_only"]: + # Standard AllReduce + RMSNorm + FP4 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + e, + ) + 
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot + if flashinfer_comm is not None and allreduce_params is not None: + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( + "inf" + ) + + return results + + +def prepare_results_with_speedups(results_dict): + """Prepare results with speedup calculations based on dynamic baseline selection.""" + prepared_results = [] + + # Determine the fastest baseline for each operation type + def get_fastest_baseline(op_name, results_dict): + """Get the fastest baseline between standard and native_compiled versions.""" + if "fp8_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp8_quant", + "standard_allreduce_rmsnorm_fp8_quant_native_compiled", + ] + elif "fp4_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp4_quant", + "standard_allreduce_rmsnorm_fp4_quant_native_compiled", + ] + else: + candidates = [ + "standard_allreduce_rmsnorm", + "standard_allreduce_rmsnorm_native_compiled", + ] + + # Find the fastest among available candidates + fastest_time = float("inf") + fastest_baseline = None + + for candidate in candidates: + if ( + candidate in results_dict + and results_dict[candidate] != float("inf") + and results_dict[candidate] < fastest_time + ): + fastest_time = results_dict[candidate] + fastest_baseline = candidate + + return fastest_baseline + + # Create dynamic baseline mapping + dynamic_baseline_mapping = {} + for op_name in results_dict: + if ( + op_name.startswith("flashinfer_") + or op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + dynamic_baseline_mapping[op_name] = get_fastest_baseline( + op_name, results_dict + ) + + for op_name, time_ms in results_dict.items(): + if time_ms == float("inf"): + speedup_str = "FAILED" + time_str = "FAILED" + else: + time_str = f"{time_ms:.3f}" + # Find the appropriate baseline for this operation + baseline_op = dynamic_baseline_mapping.get(op_name) + if baseline_op and baseline_op in results_dict: + baseline_time = results_dict[baseline_op] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + # For baseline operations, determine if this is the fastest baseline + if op_name.endswith("_native_compiled") or ( + op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + fastest_baseline = get_fastest_baseline(op_name, results_dict) + if fastest_baseline == op_name: + speedup_str = "baseline" + else: + if fastest_baseline and fastest_baseline in results_dict: + baseline_time = results_dict[fastest_baseline] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + + prepared_results.append( + { + 
"operation": op_name, + "time_ms": time_ms, + "time_str": time_str, + "speedup_str": speedup_str, + } + ) + + return prepared_results + + +def print_results( + results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode, input_size_mb +): + """Print benchmark results in a formatted table.""" + print(f"\n{'=' * 80}") + print( + f"Results: seq_len={seq_len}, hidden_dim={hidden_dim} " + f"(input size: {input_size_mb:.2f} MB)" + ) + print( + f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " + f"quant_mode={quant_mode}" + ) + print(f"{'=' * 80}") + print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") + print(f"{'-' * 80}") + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + if result["time_ms"] == float("inf"): + time_display = result["time_str"] + else: + time_display = f"{result['time_ms']:.3f}" + + print( + f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" + ) + + +def format_results_markdown( + all_results: list[dict], world_size: int, args: argparse.Namespace +) -> str: + """Format all benchmark results as markdown.""" + markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results + +**World Size:** {world_size} +**Hidden Dimension:** {args.hidden_dim} +**Warmup Iterations:** {args.warmup} +**Benchmark Trials:** {args.trials} +**Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"} + +--- + +""" + + for result in all_results: + seq_len = result["seq_len"] + dtype = result["dtype"] + use_residual = result["use_residual"] + results_dict = result["results"] + input_size_mb = result["input_size_mb"] + residual_str = "with residual" if use_residual else "no residual" + + markdown += f""" +## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} +**Input Size:** {input_size_mb:.2f} MB + +| Operation | Time (ms) | Speedup | +|-----------|-----------|---------| +""" + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + # Format operation name for better readability + formatted_op_name = result["operation"].replace("_", " ").title() + markdown += f"| {formatted_op_name} | {result['time_str']} |" + markdown += f"{result['speedup_str']} |\n" + + markdown += "\n" + + return markdown + + +def save_results_to_file( + all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int +): + """Save benchmark results to markdown file (only on rank 0).""" + if rank != 0: + return + + if not all_results: + logger.warning("No results to save") + return + + output_path = args.output_file + + try: + markdown_content = format_results_markdown(all_results, world_size, args) + + with open(output_path, "w") as f: + f.write(markdown_content) + + except Exception as e: + logger.error("Failed to save results to file: %s", e) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark fused collective operations" + ) + parser.add_argument( + "--seq-lens", + type=int, + nargs="+", + default=[128, 512, 1024, 2048], + help="Sequence lengths to test", + ) + parser.add_argument( + "--hidden-dim", type=int, default=8192, help="Hidden dimension size" + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["bfloat16"], + choices=["float16", "bfloat16", "float32"], + help="Data types to test", + ) + parser.add_argument( + "--no-residual", + action="store_true", + help="Skip residual 
connection tests", + ) + + # Quantization mode options (mutually exclusive with --no-quant) + quant_group = parser.add_mutually_exclusive_group() + quant_group.add_argument( + "--no-quant", action="store_true", help="Skip all quantization tests" + ) + quant_group.add_argument( + "--quant-fp8", action="store_true", help="Only run FP8 quantization tests" + ) + quant_group.add_argument( + "--quant-fp4", action="store_true", help="Only run FP4 quantization tests" + ) + quant_group.add_argument( + "--quant-all", + action="store_true", + help="Run all quantization tests (default)", + ) + + parser.add_argument( + "--disable-oneshot", + action="store_true", + help="Disable oneshot mode for FlashInfer operations", + ) + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--trials", type=int, default=20, help="Number of benchmark trials" + ) + parser.add_argument( + "--output-file", + type=str, + help="""Output file path for markdown results + (default: benchmark_results_.md) + """, + ) + + args = parser.parse_args() + + # Check if running with torchrun (required for collective operations) + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + raise RuntimeError( + "Must run with torchrun for distributed benchmarking. " + "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" + ) + + # Initialize distributed environment + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Validate world size (must be > 1 for collective operations) + if world_size <= 1: + raise ValueError( + "World size must be > 1 for collective operations benchmarking. " + f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." 
+ ) + + # Determine quantization mode + if args.no_quant: + quant_mode = "none" + elif args.quant_fp8: + quant_mode = "fp8_only" + elif args.quant_fp4: + quant_mode = "fp4_only" + else: # args.quant_all or default + quant_mode = "all" + + if rank == 0: + logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) + logger.info("Quantization mode: %s", quant_mode) + if flashinfer_comm is not None: + oneshot_status = "enabled" if not args.disable_oneshot else "disabled" + logger.info( + "FlashInfer available - will benchmark fused operations (oneshot: %s)", + oneshot_status, + ) + else: + logger.info( + "FlashInfer not available - only benchmarking standard operations" + ) + + # Convert dtype strings to torch dtypes + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtypes = [dtype_map[dt] for dt in args.dtypes] + + # Test configurations + residual_options = [True] if not args.no_residual else [False] + if not args.no_residual: + residual_options.append(False) + + configs = list(itertools.product(args.seq_lens, dtypes, residual_options)) + + # Setup FlashInfer workspace if available + ipc_handles = None + allreduce_params = None + + if flashinfer_comm is not None: + # Use the largest hidden dimension for workspace setup + max_num_token = _FI_MAX_SIZES.get(world_size) // ( + args.hidden_dim * world_size * 2 + ) + + ipc_handles, workspace_tensor = setup_flashinfer_workspace( + world_size, rank, args.hidden_dim, max_num_token + ) + + if workspace_tensor is not None: + allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=world_size, + max_token_num=max_num_token, + ) + + # Collect all results for markdown export + all_results = [] + + try: + # Run benchmarks + for seq_len, dtype, use_residual in configs: + if rank == 0: + logger.info( + "\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s", + seq_len, + args.hidden_dim, + dtype, + use_residual, + ) + + results = run_benchmarks( + seq_len, + args.hidden_dim, + dtype, + use_residual, + allreduce_params, + quant_mode=quant_mode, + disable_oneshot=args.disable_oneshot, + ) + + # Store results for markdown export + if rank == 0: + # Calculate input size in MB + input_size_mb = ( + seq_len * args.hidden_dim * torch.finfo(dtype).bits + ) / (8 * 1024 * 1024) + all_results.append( + { + "seq_len": seq_len, + "hidden_dim": args.hidden_dim, + "dtype": str(dtype).replace("torch.", ""), + "use_residual": use_residual, + "quant_mode": quant_mode, + "input_size_mb": input_size_mb, + "results": results, + } + ) + + print_results( + results, + seq_len, + args.hidden_dim, + dtype, + use_residual, + quant_mode, + input_size_mb, + ) + + # Save results to markdown file + if args.output_file and rank == 0: + save_results_to_file(all_results, world_size, args, rank) + + finally: + # Cleanup + if ipc_handles is not None: + cleanup_flashinfer_workspace(ipc_handles) + + dist.barrier() + + +if __name__ == "__main__": + main() diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py new file mode 100644 index 000000000000..68389ccfbe14 --- /dev/null +++ b/tests/compile/test_compile_ranges.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + 
CompilationConfig, + CompilationLevel, + VllmConfig, + set_current_vllm_config, +) +from vllm.forward_context import set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 64 +MLP_SIZE = 128 + + +def silly_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@support_torch_compile +class TestModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): + with set_forward_context({}, vllm_config=vllm_config): + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + for batch_size in batch_sizes: + model(torch.randn(batch_size, MLP_SIZE).cuda()) + + +def test_compile_ranges(): + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + compile_ranges_split_points=[8, 32], + ) + ) + + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() + batch_sizes = [1, 16, 48] + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 53fd5e74dc0a..45d2c7f267d7 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -80,7 +80,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[int | None, int, str], Any] = dict() + self.cache: dict[tuple[tuple[int, int] | None, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -89,11 +89,11 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, runtime_shape: int | None = None): + def compile_context(self, compile_range: tuple[int, int] | None = None): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. 
partition rules, pass context).""" - with pass_context(runtime_shape): + with pass_context(compile_range): if self.compilation_config.use_inductor_graph_partition: inductor_partition_ops = resolve_defined_ops( self.compilation_config.splitting_ops @@ -150,26 +150,28 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Callable | None: - if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + handle = self.cache[(compile_range, graph_index, self.compiler.name)] compiled_graph = self.compiler.load( - handle, graph, example_inputs, graph_index, runtime_shape + handle, graph, example_inputs, graph_index, compile_range ) - if runtime_shape is None: + if compile_range is None: logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via handle %s", + "Directly load the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Directly load the %s-th graph for shape %s from %s via handle %s", + "Directly load the %s-th graph for compile range %s" + "from %s via handle %s", graph_index, - str(runtime_shape), + str(compile_range), self.compiler.name, handle, ) @@ -183,7 +185,7 @@ def compile( compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -195,15 +197,14 @@ def compile( compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) + compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. # there can be multiple graphs due to piecewise compilation. 
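# Illustrative sketch, not part of the original patch: the compiler cache is
# now keyed by (compile_range, graph_index, compiler_name), where
# compile_range is tuple[int, int] | None. None still means "fully dynamic
# shape", and a degenerate range such as (4, 4) stands for a single concrete
# size (see the compile_range[0] == compile_range[1] check in
# compiler_interface.py further down). A minimal example of the keying
# scheme, with hypothetical handle values:
#
#   cache: dict[tuple[tuple[int, int] | None, int, str], str] = {
#       (None, 0, "inductor"): "handle_dynamic",       # dynamic shape
#       ((1, 8), 0, "inductor"): "handle_range_1_8",   # a range of sizes
#       ((4, 4), 1, "inductor"): "handle_size_4",      # one concrete size
#   }
#   handle = cache.get(((1, 8), 0, "inductor"))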
now = time.time() elapsed = now - compilation_start_time - compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " "from the cache, took %.3f s", @@ -211,9 +212,9 @@ def compile( ) else: logger.info( - "Directly load the compiled graph(s) for shape %s " + "Directly load the compiled graph(s) for compile range %s " "from the cache, took %.3f s", - str(runtime_shape), + str(compile_range), elapsed, ) return compiled_graph @@ -224,14 +225,18 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" - - with self.compile_context(runtime_shape): + maybe_key = "artifact_compile_range_" + if compile_range is None: + maybe_key += "dynamic_shape" + else: + maybe_key += f"{compile_range[0]}_{compile_range[1]}" + maybe_key += f"_subgraph_{graph_index}" + with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( graph, example_inputs, additional_inductor_config, - runtime_shape, + compile_range, maybe_key, ) @@ -239,33 +244,34 @@ def compile( # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + self.cache[(compile_range, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if runtime_shape is None: + if compile_range is None: logger.info_once( "Cache the graph for dynamic shape for later use", scope="local" ) else: logger.info_once( - "Cache the graph of shape %s for later use", - str(runtime_shape), + "Cache the graph of compile range %s for later use", + str(compile_range), scope="local", ) - if runtime_shape is None: + if compile_range is None: logger.debug( - "Store the %s-th graph for dynamic shape from %s via handle %s", + "Store the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Store the %s-th graph for shape %s from %s via handle %s", + "Store the %s-th graph for compile range%s from %s via handle %s", graph_index, - str(runtime_shape), + str(compile_range), self.compiler.name, handle, ) @@ -275,16 +281,16 @@ def compile( now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info_once( - "Compiling a graph for dynamic shape takes %.2f s", + "Compiling a graph for dynamic compile range takes %.2f s", elapsed, scope="local", ) else: logger.info_once( - "Compiling a graph for shape %s takes %.2f s", - runtime_shape, + "Compiling a graph for compile range %s takes %.2f s", + str(compile_range), elapsed, scope="local", ) @@ -407,19 +413,7 @@ def call_module( sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] - global compilation_start_time - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - runtime_shape=None, - ) - ) # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend @@ -429,7 +423,6 @@ def call_module( index, len(self.compile_submod_names), 
sym_shape_indices, - compiled_graph_for_dynamic_shape, self.vllm_backend, ) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7294ddce64ba..6e23f63ad6d4 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -9,7 +9,6 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -432,7 +431,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. @@ -442,7 +441,9 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range is not None and ( + compile_range[0] == compile_range[1] and compile_range[1] % tp_size == 0 + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): @@ -454,31 +455,21 @@ def __call__(self, graph: fx.Graph): _FI_WORKSPACE_TENSOR = None MiB = 1024 * 1024 - # Max size of the input tensor per world size - # to use flashinfer fused allreduce - _FI_MAX_SIZES = { - 2: 64 * MiB, # 64MB - 4: MiB, # 1MB - 6: MiB // 2, # 512KB - 8: MiB // 2, # 512KB + # Max size of the input tensor per world size per device capability + # to use flashinfer one shot fused allreduce + _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = { + "9.0": { + 2: 32 * MiB, # 32MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 32 * MiB, # 32MB + 4: 4 * MiB, # 4MB + 8: 1 * MiB, # 1MB + }, } - try: - _FI_MAX_SIZES.update( - { - int(k): int(float(v) * MiB) - for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() - } - ) - except Exception as e: - raise ValueError( - "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e) - ) from e - - # opt for a more conservative default value - # when world size is not in _FI_MAX_SIZES - _DEFAULT_FI_MAX_SIZE = MiB // 2 - def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, residual: torch.Tensor, @@ -491,7 +482,6 @@ def call_trtllm_fused_allreduce_norm( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -500,88 +490,60 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - max_fusion_size = max_token_num * hidden_size * element_size - use_flashinfer = current_tensor_size <= min( - _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), - max_fusion_size, + max_tensor_size = max_token_num * hidden_size * element_size + assert current_tensor_size <= max_tensor_size, ( + f"Current tensor size {current_tensor_size} is larger than " + f"max token num {max_token_num} * hidden size {hidden_size} * " + f"element size {element_size}" ) - if use_flashinfer: - assert _FI_WORKSPACE_TENSOR is not None, ( - 
"Flashinfer must be enabled when using flashinfer" - ) - if norm_out is None: - norm_out = allreduce_in - residual_out = residual - else: - # return residual_out as allreduce_out with zeroed residual_in - # as flashinfer does not support rms_norm - # and allreduce_out together - residual_out = allreduce_in - # For the sizes that are smaller than the max size, - # we only use flashinfer one shot allreduce - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - token_num=allreduce_in.shape[0], - residual_in=residual, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - world_rank=world_rank, - world_size=world_size, - hidden_dim=allreduce_in.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - launch_with_pdl=launch_with_pdl, - use_oneshot=True, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=None, - quant_out=quant_out, - scale_out=scale_out, - # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, - scale_factor=scale_factor, - ) + device_capability = current_platform.get_device_capability().as_version_str() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES.get( + device_capability, {} + ).get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = ( + max_one_shot_size is None or current_tensor_size <= max_one_shot_size + ) + + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) + if norm_out is None: + norm_out = allreduce_in + residual_out = residual else: - allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None and fuse_rms_quant: - # Do fused rms norm static fp8 quant fused op - if norm_out is None: - torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, - allreduce_out, - residual, - rms_gamma, - scale_factor, - rms_eps, - ) - else: - torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps - ) - else: - if norm_out is None: - torch.ops._C.fused_add_rms_norm( - allreduce_out, residual, rms_gamma, rms_eps - ) - norm_out = allreduce_out - else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None: - if scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - else: - torch.ops._C.static_scaled_fp8_quant( - quant_out, norm_out, scale_factor - ) - if scale_factor is None or norm_out is not None: - # we need to return allreduce output - # in cases of non quant fused AR + RMS norm - # and fused AR + RMS norm + quant without fused add - allreduce_in.copy_(allreduce_out) + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + 
use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) def call_trtllm_fused_allreduce_norm_fake( allreduce_in: torch.Tensor, @@ -595,7 +557,6 @@ def call_trtllm_fused_allreduce_norm_fake( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -629,7 +590,6 @@ def __init__( world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, - fuse_rms_quant: bool = False, ): self.rank = rank self.world_size = world_size @@ -637,9 +597,7 @@ def __init__( self.trigger_completion_at_end = True self.launch_with_pdl = True self.fp32_acc = True - self.use_oneshot = False self.max_token_num = max_token_num - self.fuse_rms_quant = fuse_rms_quant def get_trtllm_fused_allreduce_kwargs(self): return { @@ -649,7 +607,6 @@ def get_trtllm_fused_allreduce_kwargs(self): "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, - "fuse_rms_quant": self.fuse_rms_quant, } @@ -1119,23 +1076,29 @@ def __init__(self, config: VllmConfig): "skipping allreduce fusion pass" ) return - # Check if the world size is supported - if self.tp_size not in _FI_MAX_SIZES: + max_size = config.compilation_config.pass_config.flashinfer_max_size( + self.tp_size + ) + if max_size is None: + # Flashinfer doesn't support current world size logger.warning( "Flashinfer allreduce fusion is not supported for world size %s", self.tp_size, ) return - max_num_token = min( - _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) - // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), - config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num, + element_size = 4 if use_fp32_lamport else 2 + self.max_token_num = max_size // (self.hidden_dim * element_size) + # take the min to save workspace size and we'll never use more + # than max_num_batched_tokens anyways + self.max_token_num = min( + self.max_token_num, config.scheduler_config.max_num_batched_tokens ) + self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_num_token, + max_token_num=self.max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1148,10 +1111,7 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_num_token, - # fuse rms norm static fp8 quant fused op - # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, + max_token_num=self.max_token_num, ) self.register_patterns() @@ -1204,6 +1164,11 @@ def register_patterns(self): self.disabled = False + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + if compile_range is None: + return False + return compile_range[1] - 1 <= self.max_token_num + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 0a3f0769db94..d069769fe76f 100644 --- 
a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -63,16 +63,17 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, - with a runtime shape. If the `runtime_shape` is None, it means + with a range. If the `compile_range` is None, it means the `example_inputs` have a dynamic shape. Otherwise, the - `runtime_shape` specifies the shape of the inputs. Right now we only - support one variable shape for all inputs, which is the batchsize - (number of tokens) during inference. + `compile_range` specifies the range of the inputs, + it could be concrete size, e.g. (4, 4). + Right now we only support one variable range of shapes for all inputs, + which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -98,7 +99,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: """ Load the compiled function from the handle. @@ -192,18 +193,21 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(runtime_shape, int): - dynamic_shapes = "from_example_inputs" + if isinstance(compile_range, tuple): + if compile_range[0] == compile_range[1]: + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_graph" else: dynamic_shapes = "from_tracing_context" @@ -230,7 +234,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -294,7 +298,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -308,7 +312,7 @@ def compile( current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() # inductor can inplace modify the graph, so we need to copy it @@ -493,7 +497,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -589,9 +593,9 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() -def set_inductor_config(config, runtime_shape): - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters +def 
set_inductor_config(config, compile_range): + if isinstance(compile_range, tuple) and compile_range[0] == compile_range[1]: + # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( @@ -611,7 +615,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 9af635a929b4..599fa776b6c0 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -28,8 +28,8 @@ class PassContext: - def __init__(self, runtime_shape: int | None): - self.runtime_shape = runtime_shape + def __init__(self, compile_range: tuple[int, int] | None): + self.compile_range = compile_range def get_pass_context() -> PassContext: @@ -39,13 +39,13 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(runtime_shape: int | None): +def pass_context(compile_range: tuple[int, int] | None): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ global _pass_context prev_context = _pass_context - _pass_context = PassContext(runtime_shape) + _pass_context = PassContext(compile_range) try: yield finally: @@ -96,7 +96,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable(self, shape: int | None): + def is_applicable_for_range(self, compile_range: tuple[int, int] | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 3bc35a8f7198..08002dc862f6 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -69,13 +69,13 @@ def __init__(self): def __call__(self, graph: fx.Graph): VllmInductorPass.dump_prefix = 0 # reset dump index - shape = get_pass_context().runtime_shape + compile_range = get_pass_context().compile_range for pass_ in self.passes: - if pass_.is_applicable(shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) VllmInductorPass.dump_prefix += 1 else: - logger.debug("Skipping %s with shape %s", pass_, shape) + logger.debug("Skipping %s with compile range %s", pass_, compile_range) # post-cleanup goes before fix_functionalization # because it requires a functional graph diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 2931580afbbb..7a10fed1d237 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -7,7 +7,6 @@ import torch.fx as fx -import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig @@ -17,8 +16,8 @@ @dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int +class RangeEntry: + compile_range: tuple[int, int] compiled: bool = False runnable: Callable = None # type: ignore @@ -31,7 +30,6 @@ def __init__( piecewise_compile_index: int, total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend, ): """ @@ -55,67 +53,98 @@ def __init__( self.is_full_graph = total_piecewise_compiles == 
1 - self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) + self.compile_ranges = self.compilation_config.get_compile_ranges() + log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" + logger.debug_once(log_string) - self.first_run_finished = False + self.compile_sizes = self.compilation_config.compile_sizes + log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" + logger.debug_once(log_string) - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + self.is_in_range = ( + lambda x, range: range[0] <= x < range[1] + if range[0] < range[1] + else x == range[0] + ) - self.sym_shape_indices = sym_shape_indices + self.first_run_finished = False - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + self.sym_shape_indices = sym_shape_indices # the entries for different shapes that we need to compile - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + + # the entries for the ranges and exact sizes that we need to compile + self.range_entries: dict[tuple[int, int], RangeEntry] = {} - # to_be_compiled_sizes tracks the remaining sizes to compile, + # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. - for shape in self.compile_sizes: - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - runnable=self.compiled_graph_for_general_shape, + for size in self.compile_sizes: + range = (size, size) + self.range_entries[range] = RangeEntry( + compile_range=range, + ) + + for range in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, + ) def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: + if self.is_last_graph and not self.to_be_compiled_ranges: # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: + if not range_entry.compiled: + range_entry.compiled = True + self.to_be_compiled_ranges.remove(range_entry.compile_range) - runtime_shape = args[self.sym_shape_indices[0]] - - if runtime_shape not in self.concrete_size_entries: - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - - entry = self.concrete_size_entries[runtime_shape] - - if not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( + range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, + compile_range=range_entry.compile_range, ) # finished compilations for all required shapes - if self.is_last_graph and not 
self.to_be_compiled_sizes: - self.check_for_ending_compilation() + self.check_for_ending_compilation() + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + + # The role of the general graph is taken by the last (largest) range graph + range_entry = self.range_entries[self.compile_ranges[-1]] + self._maybe_compile_for_range_entry(range_entry, args) + return range_entry.runnable(*args) runtime_shape = args[self.sym_shape_indices[0]] - return entry.runnable(*args) + range_found = False + if runtime_shape in self.compile_sizes: + range_entry = self.range_entries[(runtime_shape, runtime_shape)] + range_found = True + else: + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + range_found = True + break + assert range_found, ( + f"Shape out of considered range: {runtime_shape} " + "[1, max_num_batched_tokens]" + ) + + self._maybe_compile_for_range_entry(range_entry, args) + + return range_entry.runnable(*args) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 31624a8fdcc0..cf47adb4670a 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -482,7 +482,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -502,7 +502,11 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return ( + compile_range is not None + and (compile_range[0] == compile_range[1]) + and (compile_range[1] % tp_size == 0) + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c24a94091be4..260b41b6e5ad 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -110,11 +110,42 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 16384 - """Max number of tokens to used in flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_size_mb: dict[int, float] = field(default_factory=dict) + """The thresholds of the communicated tensor sizes under which + vLLM should use flashinfer fused allreduce. Specified as a + dictionary mapping each world size to the threshold in MB: + { world_size: max_size_mb } + Unspecified world sizes will fall back to + _FI_ALLREDUCE_MAX_INPUT_SIZES = { + "9.0": { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + "10.0": { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, + }, where the key is the device capability.""" # TODO(luka) better pass enabling system. + def flashinfer_max_size(self, world_size: int) -> int | None: + """ + Returns the max communication size in bytes for flashinfer + allreduce fusion for the given world size. Returns None if the world + size is absent from the config, i.e. not supported by flashinfer. 
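Referring back to the PiecewiseBackend changes earlier in this diff: each exact compile size becomes a degenerate range (s, s), true ranges are half-open [lo, hi), and an exact-size entry wins over the range that contains it. A minimal, runnable sketch of that lookup (the RangeEntry shape follows the diff; find_entry and the toy sizes below are hypothetical, for illustration only):

from dataclasses import dataclass
from typing import Callable


@dataclass
class RangeEntry:
    # (lo, hi) is half-open [lo, hi); lo == hi means "exactly lo"
    compile_range: tuple[int, int]
    compiled: bool = False
    runnable: Callable | None = None


def in_range(x: int, r: tuple[int, int]) -> bool:
    # degenerate range (s, s) matches only x == s, otherwise half-open [lo, hi)
    return x == r[0] if r[0] == r[1] else r[0] <= x < r[1]


def find_entry(
    num_tokens: int,
    entries: dict[tuple[int, int], RangeEntry],
    compile_sizes: list[int],
    compile_ranges: list[tuple[int, int]],
) -> RangeEntry:
    # an exact compile size takes precedence over the enclosing range
    if num_tokens in compile_sizes:
        return entries[(num_tokens, num_tokens)]
    for r in compile_ranges:
        if in_range(num_tokens, r):
            return entries[r]
    raise AssertionError(f"shape {num_tokens} is outside all compile ranges")


# toy setup: one exact size plus two ranges derived from split points [512, 8193]
compile_sizes = [8]
compile_ranges = [(1, 512), (512, 8193)]
entries = {(s, s): RangeEntry((s, s)) for s in compile_sizes}
entries.update({r: RangeEntry(r) for r in compile_ranges})

assert find_entry(8, entries, compile_sizes, compile_ranges).compile_range == (8, 8)
assert find_entry(300, entries, compile_sizes, compile_ranges).compile_range == (1, 512)
assert find_entry(4096, entries, compile_sizes, compile_ranges).compile_range == (512, 8193)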
+ """ + + MiB = 1024 * 1024 + max_sizes = { + k: int(v * MiB) for k, v in self.fi_allreduce_fusion_max_size_mb.items() + } + + # return None if world size is not supported by flashinfer + return max_sizes.get(world_size) + def uuid(self): """ Produces a hash unique to the pass configuration. @@ -135,6 +166,35 @@ def __post_init__(self) -> None: "Fusion enabled but reshape elimination disabled. " "Attention + quant (fp8) fusion might not work" ) + if self.enable_fi_allreduce_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "Allreduce + rms norm + quant (fp8) fusion might not work" + ) + + # import here to avoid circular dependencies + from vllm.platforms import current_platform + + # Default tuned max size of the input tensor + # per world size and device capability + # to use flashinfer fused allreduce + fi_allreduce_fusion_max_size_mb = { + "9.0": { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + "10.0": { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, + } + device_capability = current_platform.get_device_capability().as_version_str() + + max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) + max_sizes.update(self.fi_allreduce_fusion_max_size_mb) + self.fi_allreduce_fusion_max_size_mb = max_sizes @config @@ -164,6 +224,8 @@ class CompilationConfig: - Inductor compilation: - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`compile_ranges_split_points`] + [vllm.config.CompilationConfig.compile_ranges_split_points] - [`inductor_compile_config`] [vllm.config.CompilationConfig.inductor_compile_config] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] @@ -281,6 +343,16 @@ class CompilationConfig: """Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" + compile_ranges_split_points: list[int] | None = None + """Split points that represent compile ranges for inductor. + The compile ranges are + [1, split_points[0]), + [split_points[0], split_points[1]), ..., + [split_points[-1], max_num_batched_tokens + 1). + Compile sizes are also used as single-element ranges: + [compile_sizes[i], compile_sizes[i] + 1). + """ + inductor_compile_config: dict = field(default_factory=dict) """Additional configurations for inductor. - None: use default configurations.""" @@ -864,3 +936,14 @@ def custom_op_log_check(self): enable_str, op, ) + + def get_compile_ranges(self) -> list[tuple[int, int]]: + """Get the compile ranges for the compilation config.""" + split_points = self.compile_ranges_split_points + compile_ranges = [] + for i, s in enumerate(split_points): + if i == 0: + compile_ranges.append((1, s)) + else: + compile_ranges.append((split_points[i - 1], s)) + return compile_ranges diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 916f258d6586..fd38992e374b 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -426,6 +426,8 @@ def __post_init__(self): "correctness and to realize prefill savings. " ) + self._set_compile_ranges() + disable_chunked_prefill_reasons: list[str] = [] if self.model_config: @@ -796,6 +798,49 @@ def _set_cudagraph_sizes(self): # complete the remaining process. self.compilation_config.post_init_cudagraph_sizes() + def _set_compile_ranges(self): + """ + Set the compile ranges for the compilation config. 
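A sketch of how the per-world-size thresholds and defaults above resolve to a byte limit (the default table is copied from this diff; resolve_flashinfer_max_size is a hypothetical helper, not a vLLM API, and the capability strings are assumed to come from get_device_capability().as_version_str()):

MiB = 1024 * 1024

# defaults per device capability, values in MB (copied from the diff above)
_DEFAULT_MAX_SIZES_MB = {
    "9.0": {2: 64, 4: 2, 8: 1},
    "10.0": {2: 64, 4: 32, 8: 1},
}


def resolve_flashinfer_max_size(
    world_size: int,
    device_capability: str,
    user_overrides_mb: dict[int, float] | None = None,
) -> int | None:
    """Max fused-allreduce input size in bytes, or None if unsupported."""
    table = dict(_DEFAULT_MAX_SIZES_MB.get(device_capability, {}))
    table.update(user_overrides_mb or {})
    mb = table.get(world_size)
    return int(mb * MiB) if mb is not None else None


assert resolve_flashinfer_max_size(4, "10.0") == 32 * MiB
assert resolve_flashinfer_max_size(4, "9.0", {4: 8}) == 8 * MiB
assert resolve_flashinfer_max_size(16, "9.0") is None  # world size not supported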
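And a quick usage sketch of the compile_ranges_split_points semantics documented above (the standalone helper below is hypothetical; it mirrors get_compile_ranges and assumes the last split point is already max_num_batched_tokens + 1):

def split_points_to_ranges(split_points: list[int]) -> list[tuple[int, int]]:
    # [1, p0), [p0, p1), ..., [p_{n-2}, p_{n-1})
    ranges: list[tuple[int, int]] = []
    prev = 1
    for p in sorted(split_points):
        ranges.append((prev, p))
        prev = p
    return ranges


# e.g. split points [512, 8193] (with max_num_batched_tokens = 8192) yield
# two half-open compile ranges:
assert split_points_to_ranges([512, 8193]) == [(1, 512), (512, 8193)]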
+ """ + compilation_config = self.compilation_config + computed_compile_ranges_split_points = [] + + # The upper bound of the compile ranges is the max_num_batched_tokens + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if max_num_batched_tokens is not None: + # We add 1 because the bounds checks in the compiler are exclusive + # and we want to include the max_num_batched_tokens + # in the compile range + computed_compile_ranges_split_points.append(max_num_batched_tokens + 1) + + # Add the compile ranges for flashinfer + if compilation_config.pass_config.enable_fi_allreduce_fusion: + tp_size = self.parallel_config.tensor_parallel_size + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None: + max_token_num = max_size // ( + self.model_config.get_hidden_size() + * self.model_config.dtype.itemsize + ) + # We add 1 because the bounds checks in the compiler are + # exclusive and we want to include the max_token_num in the + # compile range + computed_compile_ranges_split_points.append(max_token_num + 1) + + if compilation_config.compile_ranges_split_points is not None: + for x in compilation_config.compile_ranges_split_points: + assert isinstance(x, int) + assert x > 0, f"Invalid compile range split point: {x}" + if ( + max_num_batched_tokens is not None + and x < max_num_batched_tokens + and x > 1 + ): + computed_compile_ranges_split_points.append(x) + compilation_config.compile_ranges_split_points = sorted( + computed_compile_ranges_split_points + ) # type: ignore + def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config model_config = self.model_config diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py index 06dd336bf9cf..d083ece892d5 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json -import re import uuid from collections.abc import Sequence from typing import Any +import regex as re + from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, DeltaFunctionCall, diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 2031ade64b5f..cd73e8a249b5 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -2,9 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch -import triton -import triton.language as tl +from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c144aa23e46e..4c9da724cacd 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2184,33 +2184,59 @@ def forward_native( mode="constant", value=0.0, ) + do_naive_dispatch_combine: bool = ( + self.dp_size > 1 and not self.quant_method.using_modular_kernel + ) - if self.shared_experts is None: + def reduce_output( + states: torch.Tensor, do_combine: bool = True + ) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states, self.is_sequence_parallel) + + if ( + not 
self.is_sequence_parallel + and not self.use_dp_chunking + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) + return states + + if self.shared_experts is not None: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. - fused_output = self.forward_impl(hidden_states, router_logits) - assert not isinstance(fused_output, tuple) + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits + ) else: - fused_output = torch.ops.vllm.moe_forward( + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( hidden_states, router_logits, self.layer_name ) - return fused_output[..., :og_hidden_states] + return ( + reduce_output(shared_output[..., :og_hidden_states], do_combine=False), + reduce_output(fused_output[..., :og_hidden_states]), + ) else: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. - shared_output, fused_output = self.forward_impl( - hidden_states, router_logits - ) + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) else: - shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + fused_output = torch.ops.vllm.moe_forward( hidden_states, router_logits, self.layer_name ) - return ( - shared_output[..., :og_hidden_states], - fused_output[..., :og_hidden_states], - ) + if self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(fused_output, tuple) + fused_output, zero_expert_result = fused_output + return ( + reduce_output(fused_output[..., :og_hidden_states]) + + zero_expert_result + ) + else: + return reduce_output(fused_output[..., :og_hidden_states]) def forward_cuda( self, @@ -2492,35 +2518,7 @@ def forward_impl( shared_output, final_hidden_states, ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, tuple) - final_hidden_states, zero_expert_result = final_hidden_states - - def reduce_output( - states: torch.Tensor, do_combine: bool = True - ) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: - states = get_ep_group().combine(states, self.is_sequence_parallel) - - if ( - not self.is_sequence_parallel - and self.reduce_results - and (self.tp_size > 1 or self.ep_size > 1) - ): - states = self.maybe_all_reduce_tensor_model_parallel(states) - - return states - - if self.shared_experts is not None: - return ( - reduce_output(final_hidden_states[0], do_combine=False), - reduce_output(final_hidden_states[1]), - ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, torch.Tensor) - return reduce_output(final_hidden_states) + zero_expert_result - else: - return reduce_output(final_hidden_states) + return final_hidden_states @classmethod def make_expert_params_mapping( diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 92baf0cb7136..ef953dd2051e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -330,7 +330,7 @@ def is_residual_scattered_for_sp( The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled. 
- This follows the same logic as SequenceParallelismPass.is_applicable(): + This follows the same logic as SequenceParallelismPass.is_applicable_for_range(): - In full-graph compilation mode (no splitting ops or using inductor graph partition), SP is always applied - Otherwise, SP is only applied for specific shapes in compile_sizes
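The logic described here can be summarized in a small predicate: outside of full-graph compilation, sequence parallelism (and hence the scattered residual) only applies to a compile range that pins a single size divisible by the TP world size. A hedged sketch (standalone function, not the actual SequenceParallelismPass method; the full-graph early-return is omitted):

def sp_applicable_for_range(compile_range: tuple[int, int] | None, tp_size: int) -> bool:
    # only a pinned size (lo == hi) that divides evenly across TP ranks qualifies;
    # a true range stays symbolic and cannot be split along the sequence dimension
    return (
        compile_range is not None
        and compile_range[0] == compile_range[1]
        and compile_range[1] % tp_size == 0
    )


assert sp_applicable_for_range((128, 128), tp_size=4)      # exact size, divisible
assert not sp_applicable_for_range((1, 512), tp_size=4)    # true range: symbolic shape
assert not sp_applicable_for_range((130, 130), tp_size=4)  # not divisible by tp_size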
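Finally, tying _set_compile_ranges (in the vllm/config/vllm.py hunk above) together: the flashinfer byte threshold is converted into a token-count split point so that compilation can specialize the range where the fused allreduce is eligible. A rough sketch with assumed parameters (hidden size 4096, bf16, world size 4 on a "9.0" device, max_num_batched_tokens 8192):

MiB = 1024 * 1024

hidden_size = 4096            # assumed model hidden size
dtype_itemsize = 2            # bf16
max_num_batched_tokens = 8192
fi_max_size_bytes = 2 * MiB   # "9.0" default for world size 4 (2 MB)

split_points = []
# upper bound: +1 because the right end of each compile range is exclusive
split_points.append(max_num_batched_tokens + 1)
# flashinfer threshold expressed in tokens, again +1 to keep it inside the range
max_token_num = fi_max_size_bytes // (hidden_size * dtype_itemsize)
split_points.append(max_token_num + 1)
split_points.sort()

print(split_points)  # [257, 8193] -> compile ranges [1, 257) and [257, 8193)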