From 32e8ed0c2ee6b02a1b48b424853f36804f4eada1 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 14 Oct 2025 17:33:07 -0700 Subject: [PATCH 1/6] [compile] Enable sequence parallelism matching w/o custom ops enabled Signed-off-by: angelayi --- tests/compile/test_sequence_parallelism.py | 50 ++-- vllm/compilation/sequence_parallelism.py | 276 ++++++--------------- 2 files changed, 106 insertions(+), 220 deletions(-) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index e909cf7393ad..f0266f56952f 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -99,7 +99,6 @@ def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.vllm_config = get_current_vllm_config() self.gate_proj = torch.nn.Parameter( torch.empty((intermediate_size, hidden_size)), requires_grad=False ) @@ -152,16 +151,12 @@ def forward(self, hidden_states, residual): def ops_in_model_before(self): ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP # The following are only removed if fusion happens - if ( - self.vllm_config - and self.vllm_config.compilation_config.pass_config.enable_fusion - ): - ops_to_remove.extend( - [ - torch.ops._C.fused_add_rms_norm.default, - torch.ops._C.static_scaled_fp8_quant.default, - ] - ) + config = get_current_vllm_config() + if config.compilation_config.pass_config.enable_fusion: + ops_to_remove.append(torch.ops._C.fused_add_rms_norm.default) + # Only check for static_scaled_fp8_quant if custom quant_fp8 is enabled + if "+quant_fp8" in config.compilation_config.custom_ops: + ops_to_remove.append(torch.ops._C.static_scaled_fp8_quant.default) return ops_to_remove def ops_in_model_after(self): @@ -169,24 +164,23 @@ def ops_in_model_after(self): torch.ops.vllm.reduce_scatter.default, torch.ops.vllm.all_gather.default, ] - # The following is only added if fusion happens + # The following is only added if fusion happens and custom quant_fp8 is enabled + config = get_current_vllm_config() if ( - self.vllm_config - and self.vllm_config.compilation_config.pass_config.enable_fusion + config.compilation_config.pass_config.enable_fusion + and "+quant_fp8" in config.compilation_config.custom_ops ): ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) return ops_to_add def ops_in_model(self): - if ( - self.vllm_config - and self.vllm_config.compilation_config.pass_config.enable_fusion - ): - # If fusion happens, the fused op is the one + config = get_current_vllm_config() + if config.compilation_config.pass_config.enable_fusion: + # If fusion happens with custom quant_fp8, the fused op is the one # we check for (de)functionalization return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] else: - # If no fusion, the original ops are checked + # If no fusion or using native quant, the original ops are checked return [ torch.ops._C.fused_add_rms_norm.default, # TODO functionalization pass does not handle this yet @@ -195,7 +189,14 @@ def ops_in_model(self): @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel]) +@pytest.mark.parametrize( + "test_model_cls, custom_ops", + [ + (TestModel, ""), + (TestQuantModel, "+quant_fp8"), + (TestQuantModel, "-quant_fp8"), + ], +) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @@ -204,6 
+205,7 @@ def ops_in_model(self): @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") def test_sequence_parallelism_pass( test_model_cls: type[torch.nn.Module], + custom_ops: str, batch_size: int, seq_len: int, hidden_size: int, @@ -220,6 +222,7 @@ def run_torch_spawn(fn, nprocs): args=( num_processes, test_model_cls, + custom_ops, batch_size, seq_len, hidden_size, @@ -236,6 +239,7 @@ def sequence_parallelism_pass_on_test_model( local_rank: int, world_size: int, test_model_cls: type[torch.nn.Module], + custom_ops: str, batch_size: int, seq_len: int, hidden_size: int, @@ -264,12 +268,14 @@ def sequence_parallelism_pass_on_test_model( initialize_model_parallel(tensor_model_parallel_size=world_size) # configure vllm config for SequenceParallelismPass + custom_ops_list = custom_ops.split(",") if custom_ops else [] compilation_config = CompilationConfig( + custom_ops=custom_ops_list, pass_config=PassConfig( enable_sequence_parallelism=True, enable_fusion=enable_fusion, enable_noop=True, - ) + ), ) # NoOp needed for fusion device_config = DeviceConfig(device=torch.device("cuda")) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 31624a8fdcc0..caca08c394b5 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -10,109 +10,30 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) -class _RMSNormAndQuantOpHelper: - """Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" +class _SequenceParallelPatternHelper: + """Helper for sequence parallelism patterns.""" def __init__( self, epsilon: float, dtype: torch.dtype, device: str, - quant_op: torch._ops.OpOverload | None = None, - **kwargs, ): self.epsilon = epsilon self.dtype = dtype self.device = device - self.quant_op = quant_op - - def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor): - return torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=result_buffer, - input=input_tensor, - weight=weight_tensor, - epsilon=self.epsilon, - ) - - def _functional_fused_add_rmsnorm( - self, input_tensor, residual_tensor, weight_tensor - ): - return torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=input_tensor, - residual=residual_tensor, - weight=weight_tensor, - epsilon=self.epsilon, - ) - - def _functional_rmsnorm_then_quant( - self, - rmsnorm_result_buffer, - quant_result_buffer, - input_tensor, - weight_tensor, - scale_tensor, - ): - if self.quant_op is None: - raise RuntimeError( - "_RMSNormAndQuantOpHelper was not initialized with a quant_op." 
- ) - rmsnorm_out_tuple = self._functional_rmsnorm( - rmsnorm_result_buffer, input_tensor, weight_tensor - ) - quant_out_tuple = torch.ops.higher_order.auto_functionalized( - self.quant_op, - result=quant_result_buffer, - input=rmsnorm_out_tuple[1], - scale=scale_tensor, - ) - return quant_out_tuple - - def _functional_fused_add_rmsnorm_then_quant( - self, - quant_result_buffer, - input_tensor, - residual_tensor, - weight_tensor, - scale_tensor, - ): - if self.quant_op is None: - raise RuntimeError( - "_RMSNormAndQuantOpHelper was not initialized with a quant_op." - ) - fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm( - input_tensor, residual_tensor, weight_tensor - ) - quant_out_tuple = torch.ops.higher_order.auto_functionalized( - self.quant_op, - result=quant_result_buffer, - input=fused_add_rmsnorm_out_tuple[1], - scale=scale_tensor, - ) - return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] - - -class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): - """Helper for sequence parallelism patterns.""" - - def __init__( - self, - epsilon: float, - dtype: torch.dtype, - device: str, - quant_op: torch._ops.OpOverload | None = None, - **kwargs, - ): - super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) self.tp_group = get_tp_group() self.tp_size = get_tensor_model_parallel_world_size() @@ -131,36 +52,34 @@ def _all_gather(self, x: torch.Tensor) -> torch.Tensor: class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + def get_inputs(self): input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) - return [input, permute, arg3_1] + return [input, arg3_1] def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - permute: torch.Tensor, arg3_1: torch.Tensor, ): all_reduce = self._all_reduce(input) - rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1) + rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1) - return rmsnorm[1], all_reduce + return rmsnorm, all_reduce def replacement( input: torch.Tensor, - permute: torch.Tensor, arg3_1: torch.Tensor, ): reduce_scatter = self._reduce_scatter(input) - rmsnorm_result = torch.empty_like(reduce_scatter) - rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1) - - all_gather = self._all_gather(rmsnorm[1]) - + rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1) + all_gather = self._all_gather(rmsnorm) return all_gather, reduce_scatter pm.register_replacement( @@ -169,6 +88,10 @@ def replacement( class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -188,10 +111,8 @@ def pattern( rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights - ) - return rmsnorm[1], rmsnorm[2] + rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual) + return rmsnorm[0], rmsnorm[1] def replacement( residual: torch.Tensor, @@ -199,11 
+120,9 @@ def replacement( rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - reduce_scatter, residual, rms_norm_weights - ) - all_gather = self._all_gather(rmsnorm[1]) - return all_gather, rmsnorm[2] + rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) + all_gather = self._all_gather(rmsnorm[0]) + return all_gather, rmsnorm[1] pm.register_replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass @@ -211,9 +130,12 @@ def replacement( class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -228,23 +150,19 @@ def pattern( residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: all_reduce = self._all_reduce(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights - ) - return rmsnorm[1] + rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual) + return rmsnorm[0] def replacement( residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: reduce_scatter = self._reduce_scatter(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - reduce_scatter, residual, rms_norm_weights - ) - normalized = self._all_gather(rmsnorm[1]) + rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) + normalized = self._all_gather(rmsnorm[0]) return normalized pm.register_replacement( @@ -257,52 +175,41 @@ def replacement( class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): def __init__( - self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload + self, + epsilon: float, + dtype: torch.dtype, + device: str, ): - super().__init__(epsilon, dtype, device, quant_op=op) + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self): input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, rmsnorm_result, quant_result, weight, scale] + return [input, weight, scale] def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): all_reduce = self._all_reduce(input) - static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, all_reduce, weight, scale - ) - return static_fp8[1], all_reduce + rms = self.rmsnorm_matcher(all_reduce, weight) + quant, _ = self.quant_matcher(rms, scale) + return quant, all_reduce def replacement( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: 
torch.Tensor, scale: torch.Tensor, ): reduce_scatter = self._reduce_scatter(input) - - rmsnorm_result = torch.empty_like( - reduce_scatter, dtype=rmsnorm_result.dtype - ) - quant_result = torch.empty_like( - rmsnorm_result, # Output of RMSNorm - dtype=quant_result.dtype, - ) - static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, reduce_scatter, weight, scale - ) - all_gather = self._all_gather(static_fp8[1]) + rms = self.rmsnorm_matcher(reduce_scatter, weight) + quant, _ = self.quant_matcher(rms, scale) + all_gather = self._all_gather(quant) return all_gather, reduce_scatter @@ -312,59 +219,46 @@ def replacement( class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - def __init__( - self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload - ): - super().__init__(epsilon, dtype, device, quant_op=op) + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) - result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - return [ - result, - residual, - mm_1, - rms_norm_weights, - scale, - ] + return [residual, mm_1, rms_norm_weights, scale] def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) - static_fp8, rmsnorm_residual_out = ( - self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - result, all_reduce, residual, rms_norm_weights, scale - ) + rms, residual_out = self.rmsnorm_matcher( + all_reduce, rms_norm_weights, residual ) - return static_fp8[1], rmsnorm_residual_out + quant, _ = self.quant_matcher(rms, scale) + return quant, residual_out def replacement( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) - static_fp8, rmsnorm_residual_out = ( - self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale - ) + rms, residual_out = self.rmsnorm_matcher( + reduce_scatter, rms_norm_weights, residual ) - all_gather = self._all_gather(static_fp8[1]) - return all_gather, rmsnorm_residual_out + quant, _ = self.quant_matcher(rms, scale) + all_gather = self._all_gather(quant) + return all_gather, residual_out pm.register_replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass @@ -372,54 +266,41 @@ def replacement( class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - def __init__( - self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload - ): - super().__init__(epsilon, dtype, device, quant_op=op) + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + 
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) - result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - return [ - result, - residual, - mm_1, - rms_norm_weights, - scale, - ] + return [residual, mm_1, rms_norm_weights, scale] def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: all_reduce = self._all_reduce(mm_1) - static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - result, all_reduce, residual, rms_norm_weights, scale - ) - return static_fp8[1] + rms, _ = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual) + quant, _ = self.quant_matcher(rms, scale) + return quant def replacement( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) - static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale - ) - normalized = self._all_gather(static_fp8[1]) + rms, _ = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) + quant, _ = self.quant_matcher(rms, scale) + normalized = self._all_gather(quant) return normalized pm.register_replacement( @@ -457,15 +338,14 @@ def __init__(self, config: VllmConfig): for epsilon in [1e-5, 1e-6]: # RMSNorm + Static FP8 quantization patterns - fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default FirstAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, fp8_quant_op + epsilon, self.model_dtype, self.device ).register(self.patterns) MiddleAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, fp8_quant_op + epsilon, self.model_dtype, self.device ).register(self.patterns) LastAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, fp8_quant_op + epsilon, self.model_dtype, self.device ).register(self.patterns) # Normal RMSNorm patterns From dc60743f5ac155b650847394bb24a5c74c04d819 Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 17 Oct 2025 13:01:05 -0700 Subject: [PATCH 2/6] [compile] Fix rmsnorm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: angelayi Co-authored-by: Luka Govedič --- tests/compile/test_sequence_parallelism.py | 255 ++++++++++----------- vllm/compilation/sequence_parallelism.py | 143 +++++------- vllm/config/vllm.py | 2 - 3 files changed, 179 insertions(+), 221 deletions(-) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index f0266f56952f..98c213f2e382 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -5,9 +5,8 @@ import torch import vllm.envs as envs -from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fusion import RMSNormQuantFusionPass -from vllm.compilation.fx_utils import find_auto_fn, 
find_auto_fn_maybe, is_func +from vllm.compilation.fx_utils import find_auto_fn from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass @@ -27,6 +26,7 @@ initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables @@ -43,158 +43,139 @@ ] -class TestModel(torch.nn.Module): - def __init__(self, hidden_size=16, intermediate_size=32): +class TestAllReduceRMSNormModel(torch.nn.Module): + def __init__(self, hidden_size=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size)) - ) - self.norm = RMSNorm(intermediate_size, 1e-05) - # Initialize weights - torch.nn.init.normal_(self.gate_proj, std=0.02) + self.eps = eps + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] - def forward(self, hidden_states, residual): - """ - Forward pass implementing the operations in the FX graph + def forward(self, x): + z = torch.relu(x) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) - Args: - hidden_states: Input tensor - residual: Residual tensor from previous layer + z2 = torch.mm(y, self.w[0]) + x2 = tensor_model_parallel_all_reduce(z2) - Returns: - Tuple containing the output tensor - """ - # Reshape input - view = hidden_states.reshape(-1, self.hidden_size) + y2, resid = self.norm[1](x2, resid) - # matrix multiplication - permute = self.gate_proj.permute(1, 0) - mm = torch.mm(view, permute) + z3 = torch.mm(y2, self.w[1]) + x3 = tensor_model_parallel_all_reduce(z3) - # Tensor parallel all-reduce - all_reduce = tensor_model_parallel_all_reduce(mm) + y3, resid = self.norm[2](x3, resid) - # layer normalization - norm_output, residual_output = self.norm(all_reduce, residual) + z4 = torch.mm(y3, self.w[2]) + x4 = tensor_model_parallel_all_reduce(z4) - return norm_output, residual_output + y4, resid = self.norm[3](x4, resid) + return y4 def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] def ops_in_model_after(self): return [ - torch.ops.vllm.reduce_scatter.default, torch.ops.vllm.all_gather.default, + torch.ops.vllm.reduce_scatter.default, ] def ops_in_model(self): - return [torch.ops._C.fused_add_rms_norm.default] + if RMSNorm.enabled(): + return [ + torch.ops._C.rms_norm.default, + torch.ops._C.fused_add_rms_norm.default, + ] + else: + return [] -class TestQuantModel(torch.nn.Module): - def __init__(self, hidden_size=16, intermediate_size=32): +class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): + def __init__(self, hidden_size=16, eps=1e-6): super().__init__() + self.vllm_config = get_current_vllm_config() self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size)), requires_grad=False + self.eps = eps + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.w = [ + torch.rand(hidden_size, hidden_size) + 
.to(dtype=current_platform.fp8_dtype()) + .t() + for _ in range(3) + ] + + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, ) - self.norm = RMSNorm(intermediate_size, 1e-05) - # Initialize weights - torch.nn.init.normal_(self.gate_proj, std=0.02) - - self.fp8_linear = Fp8LinearOp(act_quant_static=True) - - self.scale = torch.rand(1, dtype=torch.float32) - # Create a weight that is compatible with torch._scaled_mm, - # which expects a column-major layout. - self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() - self.wscale = torch.rand(1, dtype=torch.float32) - - def forward(self, hidden_states, residual): - """ - Forward pass implementing the operations in the FX graph - - Args: - hidden_states: Input tensor - residual: Residual tensor from previous layer - - Returns: - Tuple containing the output tensor - """ - # Reshape input - view = hidden_states.reshape(-1, self.hidden_size) - - # matrix multiplication - permute = self.gate_proj.permute(1, 0) - mm = torch.mm(view, permute) - - # Tensor parallel all-reduce - all_reduce = tensor_model_parallel_all_reduce(mm) - - # layer normalization - norm_output, residual_output = self.norm(all_reduce, residual) - - # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply( - norm_output, - self.w, - self.wscale, - input_scale=self.scale.to(norm_output.device), + + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + + def forward(self, hidden_states): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + z2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] ) - return fp8_linear_result, residual_output + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) - def ops_in_model_before(self): - ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP - # The following are only removed if fusion happens - config = get_current_vllm_config() - if config.compilation_config.pass_config.enable_fusion: - ops_to_remove.append(torch.ops._C.fused_add_rms_norm.default) - # Only check for static_scaled_fp8_quant if custom quant_fp8 is enabled - if "+quant_fp8" in config.compilation_config.custom_ops: - ops_to_remove.append(torch.ops._C.static_scaled_fp8_quant.default) - return ops_to_remove + z3 = self.fp8_linear.apply( + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] + ) + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid here + + z4 = self.fp8_linear.apply( + y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + ) + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): - ops_to_add = [ - torch.ops.vllm.reduce_scatter.default, + return [ torch.ops.vllm.all_gather.default, + torch.ops.vllm.reduce_scatter.default, + ] + + def ops_in_model_before(self): + return [ + torch.ops.vllm.all_reduce.default, ] - # The following is only added if fusion happens and custom quant_fp8 is enabled - config = get_current_vllm_config() - if ( - config.compilation_config.pass_config.enable_fusion - and "+quant_fp8" in config.compilation_config.custom_ops - ): - ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) - return ops_to_add def ops_in_model(self): - config = get_current_vllm_config() - if 
config.compilation_config.pass_config.enable_fusion: - # If fusion happens with custom quant_fp8, the fused op is the one - # we check for (de)functionalization + if self.vllm_config.compilation_config.pass_config.enable_fusion: return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] - else: - # If no fusion or using native quant, the original ops are checked + elif RMSNorm.enabled(): return [ torch.ops._C.fused_add_rms_norm.default, - # TODO functionalization pass does not handle this yet - # torch.ops._C.static_scaled_fp8_quant.default, ] + elif self.fp8_linear.quant_fp8.enabled(): + return [ + torch.ops._C.static_scaled_fp8_quant.default, + ] + else: + return [] @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( "test_model_cls, custom_ops", [ - (TestModel, ""), - (TestQuantModel, "+quant_fp8"), - (TestQuantModel, "-quant_fp8"), + (TestAllReduceRMSNormModel, "+rms_norm"), + (TestAllReduceRMSNormModel, "-rms_norm"), + (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"), + (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"), + (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"), + (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"), ], ) @pytest.mark.parametrize("batch_size", [8]) @@ -202,6 +183,7 @@ def ops_in_model(self): @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("enable_fusion", [True, False]) +@pytest.mark.parametrize("dynamic", [False, True]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") def test_sequence_parallelism_pass( test_model_cls: type[torch.nn.Module], @@ -211,6 +193,7 @@ def test_sequence_parallelism_pass( hidden_size: int, dtype: torch.dtype, enable_fusion: bool, + dynamic: bool, ): num_processes = 2 @@ -228,6 +211,7 @@ def run_torch_spawn(fn, nprocs): hidden_size, dtype, enable_fusion, + dynamic, ), nprocs=nprocs, ) @@ -245,6 +229,7 @@ def sequence_parallelism_pass_on_test_model( hidden_size: int, dtype: torch.dtype, enable_fusion: bool, + dynamic: bool, ): current_platform.seed_everything(0) @@ -295,7 +280,6 @@ def sequence_parallelism_pass_on_test_model( with set_current_vllm_config(vllm_config): noop_pass = NoOpEliminationPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) assert ( sequence_parallelism_pass.compilation_config.splitting_ops @@ -316,38 +300,41 @@ def sequence_parallelism_pass_on_test_model( passes_for_backend.append(cleanup_pass) - backend_no_func = TestBackend(*passes_for_backend) - backend_func = TestBackend(*passes_for_backend, func_pass) + backend = TestBackend(*passes_for_backend) - model = test_model_cls(hidden_size, hidden_size * 2) + model = test_model_cls(hidden_size) hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - compiled_model_no_func = torch.compile(model, backend=backend_no_func) - compiled_model_no_func(hidden_states, residual) - compiled_model_func = torch.compile(model, backend=backend_func) - compiled_model_func(hidden_states, residual) + if dynamic: + torch._dynamo.mark_dynamic(hidden_states, 0) - assert sequence_parallelism_pass.matched_count == 1 + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states) + + assert sequence_parallelism_pass.matched_count == 4 # In 
pre-nodes, all reduce should be there, # reduce scatter and all gather should not - backend_no_func.check_before_ops(model.ops_in_model_before()) + pre_ops = [ + node.target + for node in backend.graph_pre_pass.nodes + if node.op == "call_function" + ] + for op in model.ops_in_model_before(): + num_op = len([pre_op for pre_op in pre_ops if pre_op == op]) + assert num_op == 4 # In post-nodes, reduce scatter and all gather should be there, # all reduce should not - backend_no_func.check_after_ops(model.ops_in_model_after()) + post_ops = [ + node.target + for node in backend.graph_post_pass.nodes + if node.op == "call_function" + ] + for op in model.ops_in_model_after(): + num_op = len([post_op for post_op in post_ops if post_op == op]) + assert num_op == 4 - # check if the functionalization pass is applied for op in model.ops_in_model(): - find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None - - # make sure the ops were all de-functionalized - found = dict() - for node in backend_func.graph_post_pass.nodes: - for op in model.ops_in_model(): - if is_func(node, op): - found[op] = True - assert all(found[op] for op in model.ops_in_model()) + find_auto_fn(backend.graph_post_pass.nodes, op) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index caca08c394b5..bb4dcf12d865 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools + import torch import torch._inductor.pattern_matcher as pm import torch.fx as fx @@ -17,11 +19,20 @@ from .inductor_pass import enable_fake_mode from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm +from .noop_elimination import NoOpEliminationPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) +def get_first_out_wrapper(fn): + @functools.wraps(fn) + def wrapper(*args): + return fn(*args)[0] + + return wrapper + + class _SequenceParallelPatternHelper: """Helper for sequence parallelism patterns.""" @@ -119,54 +130,26 @@ def replacement( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + # pattern matcher replaces from top-to-bottom, + # so residual is still the full size here. + # once the seqpar pattern with the previous rmsnorm is replaced reduce_scatter = self._reduce_scatter(mm_1) + residual = residual[0 : reduce_scatter.size(0), ...] 
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) all_gather = self._all_gather(rmsnorm[0]) + # shape of residual changes but that's fine, + # next node is already slicing it, now becomes a noop return all_gather, rmsnorm[1] pm.register_replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) - - -class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def __init__(self, epsilon: float, dtype: torch.dtype, device: str): - super().__init__(epsilon, dtype, device) - self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) - - def get_inputs(self): - mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) - - return [ - residual, - mm_1, - rms_norm_weights, - ] - - def register(self, pm_pass: PatternMatcherPass): - def pattern( - residual: torch.Tensor, - mm_1: torch.Tensor, - rms_norm_weights: torch.Tensor, - ) -> torch.Tensor: - all_reduce = self._all_reduce(mm_1) - rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual) - return rmsnorm[0] - - def replacement( - residual: torch.Tensor, - mm_1: torch.Tensor, - rms_norm_weights: torch.Tensor, - ) -> torch.Tensor: - reduce_scatter = self._reduce_scatter(mm_1) - rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) - normalized = self._all_gather(rmsnorm[0]) - return normalized - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + get_first_out_wrapper(pattern), + get_first_out_wrapper(replacement), + self.get_inputs(), + pm.fwd_only, + pm_pass, ) @@ -252,59 +235,31 @@ def replacement( rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + # pattern matcher replaces from top-to-bottom, + # so residual is still the full size here. + # add a temporary slice which will become a noop + # once the seqpar pattern with the previous rmsnorm is replaced reduce_scatter = self._reduce_scatter(mm_1) + residual = residual[0 : reduce_scatter.size(0), ...] 
rms, residual_out = self.rmsnorm_matcher( reduce_scatter, rms_norm_weights, residual ) quant, _ = self.quant_matcher(rms, scale) all_gather = self._all_gather(quant) + # shape of residual changes but that's fine, + # next node is already slicing it, now becomes a noop return all_gather, residual_out pm.register_replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) - -class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - def __init__(self, epsilon: float, dtype: torch.dtype, device: str): - super().__init__(epsilon, dtype, device) - self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) - self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) - - def get_inputs(self): - mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) - scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - - return [residual, mm_1, rms_norm_weights, scale] - - def register(self, pm_pass: PatternMatcherPass): - def pattern( - residual: torch.Tensor, - mm_1: torch.Tensor, - rms_norm_weights: torch.Tensor, - scale: torch.Tensor, - ) -> torch.Tensor: - all_reduce = self._all_reduce(mm_1) - rms, _ = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual) - quant, _ = self.quant_matcher(rms, scale) - return quant - - def replacement( - residual: torch.Tensor, - mm_1: torch.Tensor, - rms_norm_weights: torch.Tensor, - scale: torch.Tensor, - ) -> torch.Tensor: - reduce_scatter = self._reduce_scatter(mm_1) - rms, _ = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) - quant, _ = self.quant_matcher(rms, scale) - normalized = self._all_gather(quant) - return normalized - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + get_first_out_wrapper(pattern), + get_first_out_wrapper(replacement), + self.get_inputs(), + pm.fwd_only, + pm_pass, ) @@ -326,12 +281,34 @@ class SequenceParallelismPass(VllmPatternMatcherPass): GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can significantly reduce communication overhead and improve overall model performance. + + + This pass splits up the residual tensor across TP ranks and hence divides its size. + Because the pattern matcher starts at the end of the graph, the replacement + contains a slice that temporarily conforms the input residual to the correct size. + After all patterns have been matched, we use a NoOpEliminationPass to clean up + what have now become no-op slices. + + Note that an older version of the pass did not need this as it operated only on + custom rms_norm and fused_rms_norm_add custom ops which did not complain about + mismatched shapes during replacement. So this approach has the same assumption that + correctness is only maintained if all rms_norm operations are split across ranks. + + Correctness-wise, this is approach strictly better than before - before, + the graph was incorrect semantically and shape-wise during the pass. + With this approach there's only semantic incorrectness during the pass. + Both approaches restore a correct graph once all patterns are matched. 
""" @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) + # Used to cleanup redundant views created temporarily + # to circumvent residual shape change issues + self.noop_cleanup = NoOpEliminationPass(config) + self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}" + self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="sequence_parallelism_pass" ) @@ -344,9 +321,6 @@ def __init__(self, config: VllmConfig): MiddleAllReduceRMSNormStaticFP8Pattern( epsilon, self.model_dtype, self.device ).register(self.patterns) - LastAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device - ).register(self.patterns) # Normal RMSNorm patterns FirstAllReduceRMSNormPattern( @@ -357,9 +331,6 @@ def __init__(self, config: VllmConfig): epsilon, self.model_dtype, self.device ).register(self.patterns) - LastAllReduceRMSNormPattern( - epsilon, self.model_dtype, self.device - ).register(self.patterns) self.dump_patterns(config, self.patterns) def is_applicable(self, shape: int | None) -> bool: @@ -388,3 +359,5 @@ def is_applicable(self, shape: int | None) -> bool: def __call__(self, graph: fx.Graph): self.matched_count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) + # Clean up reshape nodes + self.noop_cleanup(graph) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index a7f7f3b45abe..911372bd9c17 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -364,8 +364,6 @@ def __post_init__(self): # and requires it to be enabled. if self.compilation_config.pass_config.enable_async_tp: self.compilation_config.pass_config.enable_sequence_parallelism = True - if self.compilation_config.pass_config.enable_sequence_parallelism: - self.compilation_config.custom_ops.append("+rms_norm") if current_platform.support_static_graph_mode(): # if cudagraph_mode is not explicitly set by users, set default From 88cdee0a1ff766b51fbc0b0cedc6c3e0265ebc2e Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 21 Oct 2025 10:52:40 -0700 Subject: [PATCH 3/6] Add e2e tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: angelayi --- tests/compile/test_fusions_e2e.py | 183 ++++++++++++++++---- tests/distributed/test_sequence_parallel.py | 1 - vllm/config/vllm.py | 7 + 3 files changed, 158 insertions(+), 33 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index d66c60ccb5b2..68119e2de148 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -21,12 +21,18 @@ from ..utils import flat_product, multi_gpu_test +class Matches(NamedTuple): + attention_fusion: int = 0 + allreduce_fusion: int = 0 + sequence_parallel: int = 0 + async_tp: int = 0 + + class ModelBackendTestCase(NamedTuple): model_name: str model_kwargs: dict[str, Any] backend: _Backend - attention_fusions: int - allreduce_fusions: int | None = None + matches: Matches MODELS_FP8: list[ModelBackendTestCase] = [] @@ -40,15 +46,23 @@ class ModelBackendTestCase(NamedTuple): model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, - attention_fusions=32, - allreduce_fusions=65, + matches=Matches( + attention_fusion=32, + allreduce_fusion=65, + sequence_parallel=65, + async_tp=128, + ), ), ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), 
backend=_Backend.FLASHINFER, - attention_fusions=48, - allreduce_fusions=96, + matches=Matches( + attention_fusion=48, + allreduce_fusion=96, + sequence_parallel=96, + async_tp=190, + ), ), ] @@ -57,8 +71,12 @@ class ModelBackendTestCase(NamedTuple): model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), backend=_Backend.FLASHINFER, - attention_fusions=48, - allreduce_fusions=96, + matches=Matches( + attention_fusion=48, + allreduce_fusion=96, + sequence_parallel=96, + async_tp=190, + ), ), ] @@ -68,8 +86,12 @@ class ModelBackendTestCase(NamedTuple): model_name="meta-llama/Llama-3.1-8B-Instruct", model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, - attention_fusions=0, - allreduce_fusions=65, + matches=Matches( + attention_fusion=0, + allreduce_fusion=65, + sequence_parallel=65, + async_tp=128, + ), ), ] @@ -79,19 +101,19 @@ class ModelBackendTestCase(NamedTuple): model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, - attention_fusions=32, + matches=Matches(attention_fusion=32), ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), backend=_Backend.ROCM_ATTN, - attention_fusions=32, + matches=Matches(attention_fusion=32), ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), backend=_Backend.ROCM_AITER_UNIFIED_ATTN, - attention_fusions=32, + matches=Matches(attention_fusion=32), ), ] @@ -100,8 +122,7 @@ class ModelBackendTestCase(NamedTuple): @pytest.mark.parametrize( - "model_name, model_kwargs, backend, " - "attention_fusions, allreduce_fusions, custom_ops", + "model_name, model_kwargs, backend, matches, custom_ops", # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) # quant_fp4 only has the custom impl @@ -112,8 +133,7 @@ def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], backend: _Backend, - attention_fusions: int, - allreduce_fusions: int, + matches: Matches, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, @@ -163,12 +183,12 @@ def test_attn_quant( with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model(compilation_config, model_name, **model_kwargs) - matches = re.findall( + log_matches = re.findall( r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - assert len(matches) == 1, log_holder.text - assert int(matches[0]) == attention_fusions + assert len(log_matches) == 1, log_holder.text + assert int(log_matches[0]) == matches.attention_fusion # TODO(luka) test both in nightly @@ -182,8 +202,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "model_name, model_kwargs, backend, " - "attention_fusions, allreduce_fusions, custom_ops", + "model_name, model_kwargs, backend, matches, custom_ops", # Toggle RMSNorm and QuantFP8 for FP8 models list( flat_product( @@ -204,8 +223,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( model_name: str, model_kwargs: dict, backend: _Backend, - attention_fusions: int, - allreduce_fusions: int, + matches: Matches, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, @@ -253,23 +271,124 @@ def test_tp2_attn_quant_allreduce_rmsnorm( run_model( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) - matches = re.findall( + log_matches = re.findall( 
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - assert len(matches) == 2, log_holder.text + assert len(log_matches) == 2, log_holder.text + + assert int(log_matches[0]) == matches.attention_fusion + assert int(log_matches[1]) == matches.attention_fusion + + log_matches = re.findall( + r"collective_fusion.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text + + assert int(log_matches[0]) == matches.allreduce_fusion + assert int(log_matches[1]) == matches.allreduce_fusion + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, matches, custom_ops", + # Toggle RMSNorm and QuantFP8 for FP8 models + list( + flat_product( + MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) + ) + ) + # Toggle RMSNorm for FP4 models and unquant models + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), +) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="sequence parallel only tested on CUDA", +) +def test_tp2_attn_quant_async_tp( + model_name: str, + model_kwargs: dict, + backend: _Backend, + matches: Matches, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + if current_platform.is_device_capability((10, 0)): + # TODO: https://github.com/vllm-project/vllm/issues/27893 + pytest.skip("Blackwell is not supported for AsyncTP pass") + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: list[str] | None = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. 
+ monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + custom_ops=custom_ops_list, + splitting_ops=splitting_ops, + # Common + level=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig( + enable_attn_fusion=True, + enable_noop=True, + enable_sequence_parallelism=True, + enable_async_tp=True, + ), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model( + compilation_config, model_name, tensor_parallel_size=2, **model_kwargs + ) + log_matches = re.findall( + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text + + assert int(log_matches[0]) == matches.attention_fusion + assert int(log_matches[1]) == matches.attention_fusion + + log_matches = re.findall( + r"sequence_parallelism.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text - assert int(matches[0]) == attention_fusions - assert int(matches[1]) == attention_fusions + assert int(log_matches[0]) == matches.sequence_parallel + assert int(log_matches[1]) == matches.sequence_parallel - matches = re.findall( + log_matches = re.findall( r"collective_fusion.py:\d+] Replaced (\d+) patterns", log_holder.text, ) - assert len(matches) == 2, log_holder.text + assert len(log_matches) == 2, log_holder.text - assert int(matches[0]) == allreduce_fusions - assert int(matches[1]) == allreduce_fusions + assert int(log_matches[0]) == matches.async_tp + assert int(log_matches[1]) == matches.async_tp def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs): diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 94b2b51211a6..7b6204f296d2 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -244,7 +244,6 @@ def _compare_sp( compilation_config = { "mode": CompilationMode.VLLM_COMPILE, - "custom_ops": ["+rms_norm"], "compile_sizes": [4, 8], "pass_config": { "enable_sequence_parallelism": True, diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 911372bd9c17..64b50669425f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -365,6 +365,13 @@ def __post_init__(self): if self.compilation_config.pass_config.enable_async_tp: self.compilation_config.pass_config.enable_sequence_parallelism = True + if ( + self.compilation_config.pass_config.enable_sequence_parallelism + and self.parallel_config.pipeline_parallel_size > 1 + ): + # TODO: https://github.com/vllm-project/vllm/issues/27894 + self.compilation_config.custom_ops.append("+rms_norm") + if current_platform.support_static_graph_mode(): # if cudagraph_mode is not explicitly set by users, set default # value From aadc6b2e3d2f2e2d177f617a9718283fddf11192 Mon Sep 17 00:00:00 2001 From: angelayi Date: Mon, 3 Nov 2025 13:25:12 -0800 Subject: [PATCH 4/6] Fix test_sequence_parallel + FP8 Signed-off-by: angelayi --- .buildkite/test-pipeline.yaml | 1 + tests/distributed/test_sequence_parallel.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d556073cd104..323cd06780e1 100644 --- 
a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -1198,6 +1198,7 @@ steps: - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py + - pytest -v -s tests/distributed/test_sequence_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 ##### B200 test ##### diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 7b6204f296d2..43422b0cf4b3 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -18,6 +18,7 @@ from vllm.config.compilation import CompilationMode from vllm.config.model import RunnerOption from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils.torch_utils import is_torch_equal_or_newer from ..models.registry import HF_EXAMPLE_MODELS @@ -161,6 +162,7 @@ def _compare_sp( test_options: SPTestOptions, num_gpus_available: int, use_inductor_graph_partition: bool, + enable_async_tp: bool, *, method: Literal["generate", "encode"], is_multimodal: bool, @@ -247,6 +249,7 @@ def _compare_sp( "compile_sizes": [4, 8], "pass_config": { "enable_sequence_parallelism": True, + "enable_async_tp": enable_async_tp, "enable_fusion": enable_fusion, "enable_noop": True, }, @@ -306,6 +309,7 @@ def _compare_sp( ], ) @pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) +@pytest.mark.parametrize("enable_async_tp", [True, False]) @create_new_process_for_each_test() def test_tp_sp_generation( model_id: str, @@ -315,10 +319,19 @@ def test_tp_sp_generation( test_options: SPTestOptions, num_gpus_available, use_inductor_graph_partition: bool, + enable_async_tp: bool, ): if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + # Skip FP8 SP-only test on sm89 (compute capability 8.9) + if ( + model_id == "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" + and current_platform.get_device_capability() < (9, 0) + and (not enable_async_tp) + ): + pytest.skip("FP8 reduction support begins with sm90 capable devices.") + _compare_sp( model_id, parallel_setup, @@ -327,6 +340,7 @@ def test_tp_sp_generation( test_options, num_gpus_available, use_inductor_graph_partition, + enable_async_tp=enable_async_tp, method="generate", is_multimodal=False, ) From b1ff48efc1ca5609efcbf81f7a190a070e3ed2c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 3 Nov 2025 17:28:13 -0500 Subject: [PATCH 5/6] more robust fp8 check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/distributed/test_sequence_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 43422b0cf4b3..ec233d046d0e 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -326,7 +326,7 @@ def test_tp_sp_generation( # Skip FP8 SP-only test on sm89 (compute capability 8.9) if ( - model_id == "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" + "fp8" in model_id.lower() and 
current_platform.get_device_capability() < (9, 0) and (not enable_async_tp) ): From f03f8c32d06e1a0947d3ae8ad263975c6b76479c Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 5 Nov 2025 10:13:00 -0800 Subject: [PATCH 6/6] disable async tp tests Signed-off-by: angelayi --- tests/distributed/test_sequence_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index ec233d046d0e..f38c509775ed 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -309,7 +309,7 @@ def _compare_sp( ], ) @pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) -@pytest.mark.parametrize("enable_async_tp", [True, False]) +@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP @create_new_process_for_each_test() def test_tp_sp_generation( model_id: str,
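
Note: the configuration exercised by the new e2e tests in this series can also be tried offline with a sketch along the following lines. The model name, custom_ops strings, and PassConfig fields are taken from the diffs above; the LLM(..., compilation_config=...) entry point and the exact import paths are assumptions borrowed from the rest of the codebase and may need adjusting on other branches.

    # Minimal offline sketch (assumes 2 GPUs; the FP8 model needs compute
    # capability >= 9.0 per the skip added in tests/distributed/test_sequence_parallel.py).
    from vllm import LLM, SamplingParams
    from vllm.config import CompilationConfig, PassConfig
    from vllm.config.compilation import CompilationMode

    compilation_config = CompilationConfig(
        level=CompilationMode.VLLM_COMPILE,
        # With this series, the native rms_norm/quant_fp8 implementations are
        # matched as well, so "+rms_norm" is no longer forced by the
        # sequence-parallelism pass and the negative variants work too.
        custom_ops=["-rms_norm", "-quant_fp8"],
        pass_config=PassConfig(
            enable_noop=True,  # needed for fusion and for cleaning up the temporary slices
            enable_fusion=True,
            enable_sequence_parallelism=True,
            enable_async_tp=True,
        ),
    )

    llm = LLM(
        model="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
        tensor_parallel_size=2,
        max_model_len=1024,
        compilation_config=compilation_config,
    )
    print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16)))

The pass logs ("sequence_parallelism.py: Replaced N patterns" and "collective_fusion.py: Replaced N patterns") can then be compared against the Matches counts recorded in tests/compile/test_fusions_e2e.py.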