From e20928ae061936f13480d707767ba04a3937afd5 Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 18 Apr 2025 16:30:04 +0000 Subject: [PATCH 1/3] fix moe padding Signed-off-by: charlifu --- vllm/model_executor/layers/fused_moe/layer.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 43fb311289fd..41b4bc1542fb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -113,12 +113,11 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w13_weight.data), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w2_weight.data), - requires_grad=False) + # Padding the weight for better performance on ROCm + layer.w13_weight.data = self._maybe_pad_weight( + layer.w13_weight.data) + layer.w2_weight.data = self._maybe_pad_weight( + layer.w2_weight.data) # Lazy import to avoid importing triton. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled, shuffle_weights) @@ -127,10 +126,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight.data, layer.w2_weight.data) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight.data = shuffled_w13 + layer.w2_weight.data = shuffled_w2 if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: From 20a10cd030054c0eb262cb62f0596d2ff4887c8c Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 18 Apr 2025 16:48:24 +0000 Subject: [PATCH 2/3] linting Signed-off-by: charlifu --- vllm/model_executor/layers/fused_moe/layer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 41b4bc1542fb..3cdf3c97a7d3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -114,10 +114,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) # Padding the weight for better performance on ROCm - layer.w13_weight.data = self._maybe_pad_weight( - layer.w13_weight.data) - layer.w2_weight.data = self._maybe_pad_weight( - layer.w2_weight.data) + layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) # Lazy import to avoid importing triton. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled, shuffle_weights) From 317f267df1527de4a39fb1b6027af63936c6dcb1 Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 28 Apr 2025 14:50:11 +0000 Subject: [PATCH 3/3] fix fp8 wvSplitKQ not called on mi250 mi300 Signed-off-by: charlifu --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 0114a309b987..8ab45d610053 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -156,7 +156,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, input_2d: torch.Tensor, output_shape: List) -> torch.Tensor: from vllm.platforms.rocm import on_mi250_mi300 - if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300( + if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, current_platform.get_cu_count())