
Commit 04bcde1

Authored by Aleksandr Malyshev (maleksan85) and Doug Lehr; committed.

Llamas 3.1 405B fp4 changes upstreaming from 355_wip (vllm-project#25135)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 8c17428 commit 04bcde1

File tree

6 files changed: +301 −38 lines

vllm/envs.py

Lines changed: 16 additions & 0 deletions

@@ -106,6 +106,8 @@
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_AITER_MHA: bool = True
+    VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
+    VLLM_ROCM_USE_TRITON_ROPE: bool = False
     VLLM_ROCM_USE_AITER_FP8BMM: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True

@@ -934,6 +936,18 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
              ("true", "1")),

+    # Whether to use aiter fp4 gemm asm.
+    # By default is disabled.
+    "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in
+             ("true", "1")),
+
+    # Whether to use aiter rope.
+    # By default is disabled.
+    "VLLM_ROCM_USE_TRITON_ROPE":
+    lambda: (os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in
+             ("true", "1")),
+
     # Whether to use aiter triton fp8 bmm kernel
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_FP8BMM":

@@ -1539,6 +1553,8 @@ def compute_hash() -> str:
         "VLLM_ROCM_USE_AITER_RMSNORM",
         "VLLM_ROCM_USE_AITER_MLA",
         "VLLM_ROCM_USE_AITER_MHA",
+        "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
+        "VLLM_ROCM_USE_TRITON_ROPE",
         "VLLM_ROCM_USE_AITER_FP8BMM",
         "VLLM_ROCM_USE_SKINNY_GEMM",
         "VLLM_ROCM_FP8_PADDING",

vllm/model_executor/layers/linear.py

Lines changed: 20 additions & 8 deletions

@@ -323,6 +323,12 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
+        # If MergedReplicatedLinear, use output size of each partition.
+        if hasattr(self, "output_sizes"):
+            self.output_partition_sizes = self.output_sizes
+        else:
+            self.output_partition_sizes = [output_size]
+
         super().__init__(input_size,
                          output_size,
                          skip_bias_add,

@@ -335,7 +341,8 @@ def __init__(
         # All the linear layer supports quant method.
         assert self.quant_method is not None
         self.quant_method.create_weights(self,
-                                         self.input_size, [self.output_size],
+                                         self.input_size,
+                                         self.output_partition_sizes,
                                          self.input_size,
                                          self.output_size,
                                          self.params_dtype,

@@ -374,12 +381,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         param.data.copy_(loaded_weight)

     def forward(
-        self, x: torch.Tensor
+        self,
+        x: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
+
         output = self.quant_method.apply(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
+
         if not self.return_bias:
             return output
         return output, output_bias

@@ -413,7 +423,7 @@ class ColumnParallelLinear(LinearBase):
         output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3.
         prefix: The name of the layer in the state dict, including all parents
-            (e.g. model.layers.0.qkv_proj)
+                (e.g. model.layers.0.qkv_proj)
         return_bias: If true, return bias together with outputs in forward pass.
         disable_tp: If true, weights matrix won't be sharded through tp rank.
     """

@@ -535,13 +545,15 @@ def weight_loader_v2(self, param: BasevLLMParameter,
         param.load_column_parallel_weight(loaded_weight=loaded_weight)

     def forward(
-        self, input_
+        self,
+        input_,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None

         # Matrix multiply.
         assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_, bias)
+
         if self.gather_output and self.tp_size > 1:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)

@@ -1326,7 +1338,8 @@ def weight_loader_v2(self, param: BasevLLMParameter,
         param.load_row_parallel_weight(loaded_weight=loaded_weight)

     def forward(
-        self, input_
+        self,
+        input_,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_

@@ -1340,9 +1353,8 @@ def forward(
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self,
-                                                  input_parallel,
-                                                  bias=bias_)
+        output_parallel = self.quant_method.apply(self, input_parallel, bias_)
+
         if self.reduce_results and self.tp_size > 1:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
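
The new hasattr check in ReplicatedLinear.__init__ relies on a merged subclass assigning output_sizes before it calls the parent constructor, so create_weights then receives per-partition sizes instead of one merged size. A stripped-down, self-contained sketch of that pattern (the *Sketch classes are illustrative stand-ins, not vLLM's actual layers):

class ReplicatedLinearSketch:
    def __init__(self, input_size: int, output_size: int):
        # Mirrors the diff: prefer per-partition sizes if a subclass set them.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]
        self.input_size = input_size
        self.output_size = output_size


class MergedReplicatedLinearSketch(ReplicatedLinearSketch):
    def __init__(self, input_size: int, output_sizes: list[int]):
        # Set before super().__init__ so the hasattr check above picks it up.
        self.output_sizes = output_sizes
        super().__init__(input_size, sum(output_sizes))


merged = MergedReplicatedLinearSketch(4096, [1024, 1024, 512])
assert merged.output_partition_sizes == [1024, 1024, 512]
plain = ReplicatedLinearSketch(4096, 2048)
assert plain.output_partition_sizes == [2048]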

vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 1 addition & 0 deletions

@@ -402,6 +402,7 @@ def apply(self,
         scheme = layer.scheme
         if scheme is None:
             raise ValueError("A scheme must be defined for each layer")
+
         return scheme.apply_weights(layer, x, bias=bias)

vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 156 additions & 29 deletions

@@ -1,20 +1,104 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from functools import cache
 from typing import Any, Callable, Optional

 import torch
 import torch.nn.functional as F

-from vllm.logger import init_logger
+from vllm import envs
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 from vllm.platforms import current_platform

-logger = init_logger(__name__)
+
+@cache
+def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+try:
+    from aiter.ops.shuffle import shuffle_weight
+    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+    from aiter.ops.triton.quant import dynamic_mxfp4_quant
+
+    from vllm.utils import direct_register_custom_op
+    if is_rocm_aiter_fp4_asm_gemm_enabled():
+        from aiter import gemm_a4w4, per_1x32_f4_quant_hip
+
+    def gemm_with_dynamic_quant(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            weight_scale: torch.Tensor,
+            rocm_use_aiter_fp4_asm_gemm: bool = False,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16,
+            x_scales: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        M = x.shape[0]
+        if rocm_use_aiter_fp4_asm_gemm:
+            if x_scales is None:
+                # use hip quant kernel for performance
+                x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+            else:
+                x_q = x
+                x_s = x_scales
+
+            # 32 alignment is enough for dim0 padding of output for
+            # gemm_a4w4 kernel
+            y = torch.empty((M + 31) // 32 * 32,
+                            weight.shape[0],
+                            device=x_q.device,
+                            dtype=out_dtype)
+
+            gemm_a4w4(x_q,
+                      weight,
+                      x_s,
+                      weight_scale.view(x_s.dtype),
+                      y,
+                      bpreshuffle=True)
+            return y[:M]
+        else:
+            if x_scales is None:
+                x_q, x_s = dynamic_mxfp4_quant(x)
+            else:
+                x_q = x
+                x_s = x_scales
+            y = torch.empty(x_q.shape[0],
+                            weight.shape[0],
+                            device=x_q.device,
+                            dtype=out_dtype)
+
+            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
+            return y
+
+    def gemm_with_dynamic_quant_fake(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            weight_scale: torch.Tensor,
+            x_scales: torch.Tensor = None,
+            rocm_use_aiter_fp4_asm_gemm: bool = False,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    ) -> torch.Tensor:
+        return torch.empty((*x.shape[:-1], weight.shape[0]),
+                           dtype=out_dtype,
+                           device=x.device)
+
+    direct_register_custom_op(
+        op_name="gemm_with_dynamic_quant",
+        op_func=gemm_with_dynamic_quant,
+        mutates_args=[],
+        fake_impl=gemm_with_dynamic_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+except ImportError:
+    dynamic_mxfp4_quant = gemm_afp4wfp4 = None

 __all__ = ["QuarkW4A4MXFP4"]

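The fake implementation registered above only needs to reproduce the output shape contract so graph tracing can proceed, and the asm path pads dim 0 of the output to a multiple of 32 before slicing back to M rows. Both invariants can be checked with plain torch, without aiter installed (the helper names and shapes below are illustrative stand-ins):

import torch

def fake_out_shape(x: torch.Tensor, weight: torch.Tensor) -> tuple:
    # Mirrors gemm_with_dynamic_quant_fake: keep x's leading dims, replace the
    # last dim with weight.shape[0] (the number of output features).
    return (*x.shape[:-1], weight.shape[0])

def padded_rows(m: int) -> int:
    # Dim-0 padding used before calling gemm_a4w4: round M up to a multiple of 32.
    return (m + 31) // 32 * 32

x = torch.randn(7, 128)    # 7 tokens, hidden size 128 (stand-in shapes)
w = torch.empty(256, 64)   # 256 output features; packed fp4 weights use a different K
assert fake_out_shape(x, w) == (7, 256)
assert padded_rows(7) == 32 and padded_rows(32) == 32 and padded_rows(33) == 64
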
@@ -27,29 +111,15 @@ def __init__(self, weight_quant_spec: dict[str, Any],
         self.qscheme = "per_group"
         self.weight_quant_spec = weight_quant_spec
         self.input_quant_spec = input_quant_spec
-
-        self.static_input_scales = not input_quant_spec.get("is_dynamic")
-
-        if self.static_input_scales:
+        self.emulate = not current_platform.supports_mx()
+        self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
+        if not self.emulate and (dynamic_mxfp4_quant is None
+                                 or gemm_afp4wfp4 is None):
+            # Currently need these kernels if not emulating
             raise NotImplementedError(
-                "QuarkW4A4MXFP4 with static input scales is currently not "
-                "implemented. Please open an issue.")
-
-        if not current_platform.supports_mx():
-            self.emulate = True
-            logger.warning_once(
-                "The current platform does not support native MXFP4 "
-                "computation. Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision.")
-        else:
-            self.emulate = True
-            logger.warning_once(
-                "The current platform supports native MXFP4 "
-                "computation, but kernels are not yet integrated in vLLM. "
-                "Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision.")
+                f"{self.__class__.__name__} requires AITER to be installed "
+                "for non-emulation mode! Please refer to "
+                "https://github.com/ROCm/aiter for installation details.")

     @classmethod
     def get_min_capability(cls) -> int:
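
Putting the gating together: emulation is chosen whenever the platform lacks native MX support, the asm kernel only on ROCm with both VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_FP4_ASM_GEMM enabled, and otherwise the Triton gemm_afp4wfp4 path is used. A small pure-Python sketch that mirrors this selection logic (not vLLM code, just the decision table):

def fp4_kernel_path(supports_mx: bool, rocm: bool, use_aiter: bool,
                    use_asm_gemm: bool) -> str:
    # Mirrors QuarkW4A4MXFP4.__init__ plus is_rocm_aiter_fp4_asm_gemm_enabled():
    # emulate without native MX support, asm gemm only if all ROCm/AITER flags
    # line up, otherwise the Triton gemm_afp4wfp4 path.
    if not supports_mx:
        return "emulate (dequant + QDQ + F.linear)"
    if rocm and use_aiter and use_asm_gemm:
        return "aiter asm gemm_a4w4"
    return "aiter triton gemm_afp4wfp4"

assert fp4_kernel_path(False, True, True, True).startswith("emulate")
assert fp4_kernel_path(True, True, True, True) == "aiter asm gemm_a4w4"
assert fp4_kernel_path(True, True, True, False) == "aiter triton gemm_afp4wfp4"
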
@@ -58,8 +128,65 @@ def get_min_capability(cls) -> int:
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
-                                                requires_grad=False)
+
+        if self.emulate:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            try:
+                from quark.torch.export.nn.modules import realquantizer
+                from quark.torch.quantization.config.config import (
+                    QuantizationSpec)
+            except ImportError as err:
+                raise ImportError(
+                    "The package `amd-quark` is required to use AMD Quark "
+                    "MX-FP4 models. Please install it with `pip install "
+                    "amd-quark`.") from err
+
+            weight_quant_spec = QuantizationSpec.from_dict(
+                self.weight_quant_spec)
+
+            weight_quantizer = realquantizer.get_real_quantizer(
+                qspec=weight_quant_spec,
+                quantizer=None,
+                real_quantized=True,
+                reorder=False,
+                float_dtype=self.out_dtype,
+                scale_shape=layer.weight_scale.shape,
+                zero_point_shape=None,
+            )
+            weight_quantizer.scale.data = layer.weight_scale.data
+
+            layer.weight = torch.nn.Parameter(
+                weight_quantizer(layer.weight.data).to(self.out_dtype),
+                requires_grad=False,
+            )
+            layer.weight_scale = None
+
+            # This call is necessary to release the scales memory.
+            torch.cuda.empty_cache()
+        else:
+            if self.rocm_use_aiter_fp4_asm_gemm:
+                # shuffle weight scale
+                weight_scale_shuffle = layer.weight_scale.data
+                sm, sn = weight_scale_shuffle.shape
+                weight_scale_shuffle = weight_scale_shuffle.view(
+                    sm // 32, 2, 16, sn // 8, 2, 4, 1)
+                weight_scale_shuffle = weight_scale_shuffle.permute(
+                    0, 3, 5, 2, 4, 1, 6).contiguous()
+                weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
+                layer.weight_scale = torch.nn.Parameter(weight_scale_shuffle,
+                                                        requires_grad=False)
+
+                # shuffle weight
+                weight_shuffle = layer.weight.data
+                weight_shuffle = shuffle_weight(weight_shuffle,
+                                                layout=(16, 16))
+                layer.weight = torch.nn.Parameter(weight_shuffle,
+                                                  requires_grad=False)
+            else:
+                layer.weight_scale = torch.nn.Parameter(
+                    layer.weight_scale.data.T.contiguous(),
+                    requires_grad=False)

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: list[int],
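
The weight-scale shuffle in the asm branch above is a pure layout permutation of an (sm, sn) scale tensor (sm a multiple of 32, sn a multiple of 8); no values are recomputed. A plain-torch check of that property, independent of aiter (the helper name and shapes are illustrative):

import torch

def shuffle_scale_like_diff(scale: torch.Tensor) -> torch.Tensor:
    # Same view/permute/view sequence as process_weights_after_loading
    # uses for the non-emulated asm path.
    sm, sn = scale.shape
    out = scale.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
    out = out.permute(0, 3, 5, 2, 4, 1, 6).contiguous()
    return out.view(sm, sn)

scale = torch.arange(64 * 16, dtype=torch.float32).view(64, 16)
shuffled = shuffle_scale_like_diff(scale)
assert shuffled.shape == scale.shape
# Element multiset is unchanged: it is a permutation, not a recomputation.
assert torch.equal(shuffled.flatten().sort().values,
                   scale.flatten().sort().values)
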
@@ -104,9 +231,9 @@ def apply_weights(self,

         if self.emulate:
             dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
-
             x = quant_dequant_mxfp4(x)
-
             return F.linear(x, dq_w, bias)
         else:
-            raise NotImplementedError()
+            return torch.ops.vllm.gemm_with_dynamic_quant(
+                x, layer.weight, layer.weight_scale,
+                self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
