
Commit bde2a1a

jpvillam-amd and jpvillam authored
[ROCm] Small functional changes for gptoss (#25201)
Signed-off-by: jpvillam <jpvillam@amd.com>
Co-authored-by: jpvillam <jpvillam@amd.com>
1 parent 5e25b12 commit bde2a1a

3 files changed: +26 additions, -6 deletions

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 6 additions & 3 deletions

@@ -212,12 +212,15 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
-        elif current_platform.is_rocm() or (
-                self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
-                or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+              or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128)
             hidden_size = round_up(hidden_size, 128)
+        elif current_platform.is_rocm():
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size_per_partition, 256)
+            hidden_size = round_up(hidden_size, 256)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64)
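
The reordering above changes the ROCm padding: previously ROCm shared the 128-wide padding branch with the SM100 FlashInfer CUTLASS and SM90 BF16 backends; it now gets its own branch that pads both sizes up to multiples of 256. A minimal sketch of the arithmetic, assuming round_up rounds up to the next multiple and using made-up sizes for illustration only:

def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

# Hypothetical per-partition sizes, not taken from any real model config.
intermediate_size_per_partition = 1440
hidden_size = 2880

# ROCm branch after this change: pad to multiples of 256.
print(round_up(intermediate_size_per_partition, 256))  # 1536
print(round_up(hidden_size, 256))                      # 3072

# FlashInfer SM100 CUTLASS / SM90 BF16 branch: pad to multiples of 128.
print(round_up(intermediate_size_per_partition, 128))  # 1536
print(round_up(hidden_size, 128))                      # 2944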

vllm/model_executor/layers/quantization/utils/mxfp4_utils.py

Lines changed: 14 additions & 3 deletions

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Callable, Optional
+from typing import Any, Callable, Optional

 import torch

@@ -21,15 +21,26 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
     from triton_kernels.tensor_details import layout
     from triton_kernels.tensor_details.layout import StridedLayout
+
+    value_layout_opts: dict[str, Any] = {}
+    scale_layout_opts: dict[str, Any] = {}
+
     if (current_platform.is_cuda()
             and current_platform.is_device_capability(90)
             and not is_torch_equal_or_newer("2.8.1")):
         logger.warning_once(
             "Mxfp4 on hopper is running on torch < 2.8.1, "
             "this cause swizling to be disabled, which may "
             "cause performance degradation. Please upgrade to torch nightly")
-        value_layout, value_layout_opts = StridedLayout, dict()
-        scale_layout, scale_layout_opts = StridedLayout, dict()
+        value_layout = StridedLayout
+        scale_layout = StridedLayout
+    elif current_platform.is_rocm():
+        from triton_kernels.tensor_details.layout import (GFX950MXScaleLayout,
+                                                           StridedLayout)
+
+        from vllm.platforms.rocm import on_gfx950
+        value_layout = StridedLayout
+        scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout
     else:
         value_layout, value_layout_opts = \
             layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
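
The net effect of the new branch is that ROCm keeps a strided layout for the quantized values and switches the scale tensor to GFX950MXScaleLayout only when running on gfx950, with the option dicts now initialized once up front. A standalone sketch of that selection logic, with placeholder classes standing in for the triton_kernels layouts and plain booleans standing in for the vLLM platform checks:

from typing import Any


class StridedLayout:  # placeholder for triton_kernels' StridedLayout
    pass


class GFX950MXScaleLayout:  # placeholder for the gfx950 MX-scale layout
    pass


def pick_layouts(is_rocm: bool, is_gfx950: bool):
    # Option dicts are created once; the ROCm branch leaves them empty.
    value_layout_opts: dict[str, Any] = {}
    scale_layout_opts: dict[str, Any] = {}
    if is_rocm:
        # Values stay strided on ROCm; scales use the gfx950-specific
        # layout only when the gfx950 architecture is detected.
        value_layout = StridedLayout
        scale_layout = GFX950MXScaleLayout if is_gfx950 else StridedLayout
    else:
        # Other platforms keep their existing defaults (elided here).
        value_layout = scale_layout = StridedLayout
    return (value_layout, value_layout_opts), (scale_layout, scale_layout_opts)


# Example: pick_layouts(is_rocm=True, is_gfx950=True)
# -> ((StridedLayout, {}), (GFX950MXScaleLayout, {}))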

vllm/platforms/rocm.py

Lines changed: 6 additions & 0 deletions

@@ -118,6 +118,12 @@ def on_gfx9() -> bool:
     return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])


+@cache
+def on_gfx950() -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    return any(arch in GPU_ARCH for arch in ["gfx950"])
+
+
 @cache
 def use_rocm_custom_paged_attention(
     qtype: torch.dtype,
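
on_gfx950 mirrors the existing on_gfx9 helper: it reads gcnArchName from the device properties, checks for the gfx950 architecture string, and memoizes the result via @cache. A short sketch of how a caller is expected to gate on it, in the spirit of the mxfp4_utils.py change above; the branch bodies are placeholders:

from vllm.platforms.rocm import on_gfx950

if on_gfx950():
    # gfx950-specific path, e.g. selecting the swizzled MX scale layout.
    ...
else:
    # Fallback for other ROCm architectures.
    ...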
