 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from functools import cache
 from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.functional as F
 
-from vllm.logger import init_logger
+from vllm import envs
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 from vllm.platforms import current_platform
 
-logger = init_logger(__name__)
+
+@cache
+def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+try:
+    from aiter.ops.shuffle import shuffle_weight
+    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+    from aiter.ops.triton.quant import dynamic_mxfp4_quant
+
+    from vllm.utils import direct_register_custom_op
+    if is_rocm_aiter_fp4_asm_gemm_enabled():
+        from aiter import gemm_a4w4, per_1x32_f4_quant_hip
+
+    def gemm_with_dynamic_quant(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        rocm_use_aiter_fp4_asm_gemm: bool = False,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16,
+        x_scales: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        M = x.shape[0]
+        if rocm_use_aiter_fp4_asm_gemm:
+            if x_scales is None:
+                # use hip quant kernel for performance
+                x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+            else:
+                x_q = x
+                x_s = x_scales
+
+            # 32 alignment is enough for dim0 padding of output for
+            # gemm_a4w4 kernel
+            y = torch.empty((M + 31) // 32 * 32,
+                            weight.shape[0],
+                            device=x_q.device,
+                            dtype=out_dtype)
+
+            gemm_a4w4(x_q,
+                      weight,
+                      x_s,
+                      weight_scale.view(x_s.dtype),
+                      y,
+                      bpreshuffle=True)
+            return y[:M]
+        else:
+            if x_scales is None:
+                x_q, x_s = dynamic_mxfp4_quant(x)
+            else:
+                x_q = x
+                x_s = x_scales
+            y = torch.empty(x_q.shape[0],
+                            weight.shape[0],
+                            device=x_q.device,
+                            dtype=out_dtype)
+
+            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
+            return y
+
+    def gemm_with_dynamic_quant_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        rocm_use_aiter_fp4_asm_gemm: bool = False,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16,
+        x_scales: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return torch.empty((*x.shape[:-1], weight.shape[0]),
+                           dtype=out_dtype,
+                           device=x.device)
+
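+    # Expose the AITER path as a torch custom op; the fake impl provides
+    # shape/dtype inference for tracing (e.g. torch.compile).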
+    direct_register_custom_op(
+        op_name="gemm_with_dynamic_quant",
+        op_func=gemm_with_dynamic_quant,
+        mutates_args=[],
+        fake_impl=gemm_with_dynamic_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+except ImportError:
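+    # Sentinels: __init__ checks these to require AITER when running in
+    # non-emulation mode on platforms with native MXFP4 support.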
+    dynamic_mxfp4_quant = gemm_afp4wfp4 = None
 
 __all__ = ["QuarkW4A4MXFP4"]
 
@@ -27,29 +111,15 @@ def __init__(self, weight_quant_spec: dict[str, Any],
         self.qscheme = "per_group"
         self.weight_quant_spec = weight_quant_spec
         self.input_quant_spec = input_quant_spec
-
-        self.static_input_scales = not input_quant_spec.get("is_dynamic")
-
-        if self.static_input_scales:
+        self.emulate = not current_platform.supports_mx()
+        self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
+        if not self.emulate and (dynamic_mxfp4_quant is None
+                                 or gemm_afp4wfp4 is None):
+            # Currently need these kernels if not emulating
             raise NotImplementedError(
-                "QuarkW4A4MXFP4 with static input scales is currently not "
-                "implemented. Please open an issue.")
-
-        if not current_platform.supports_mx():
-            self.emulate = True
-            logger.warning_once(
-                "The current platform does not support native MXFP4 "
-                "computation. Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision.")
-        else:
-            self.emulate = True
-            logger.warning_once(
-                "The current platform supports native MXFP4 "
-                "computation, but kernels are not yet integrated in vLLM. "
-                "Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision.")
+                f"{self.__class__.__name__} requires AITER to be installed "
+                "for non-emulation mode! Please refer to "
+                "https://github.com/ROCm/aiter for installation details.")
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -58,8 +128,65 @@ def get_min_capability(cls) -> int:
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
-                                                requires_grad=False)
+
+        if self.emulate:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            try:
+                from quark.torch.export.nn.modules import realquantizer
+                from quark.torch.quantization.config.config import (
+                    QuantizationSpec)
+            except ImportError as err:
+                raise ImportError(
+                    "The package `amd-quark` is required to use AMD Quark "
+                    "MX-FP4 models. Please install it with `pip install "
+                    "amd-quark`.") from err
+
+            weight_quant_spec = QuantizationSpec.from_dict(
+                self.weight_quant_spec)
+
+            weight_quantizer = realquantizer.get_real_quantizer(
+                qspec=weight_quant_spec,
+                quantizer=None,
+                real_quantized=True,
+                reorder=False,
+                float_dtype=self.out_dtype,
+                scale_shape=layer.weight_scale.shape,
+                zero_point_shape=None,
+            )
+            weight_quantizer.scale.data = layer.weight_scale.data
+
+            layer.weight = torch.nn.Parameter(
+                weight_quantizer(layer.weight.data).to(self.out_dtype),
+                requires_grad=False,
+            )
+            layer.weight_scale = None
+
+            # This call is necessary to release the scales memory.
+            torch.cuda.empty_cache()
+        else:
+            if self.rocm_use_aiter_fp4_asm_gemm:
+                # shuffle weight scale
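+                # Rearrange the per-block e8m0 scales into the tiled,
+                # preshuffled layout that the gemm_a4w4 assembly kernel
+                # expects when invoked with bpreshuffle=True.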
+                weight_scale_shuffle = layer.weight_scale.data
+                sm, sn = weight_scale_shuffle.shape
+                weight_scale_shuffle = weight_scale_shuffle.view(
+                    sm // 32, 2, 16, sn // 8, 2, 4, 1)
+                weight_scale_shuffle = weight_scale_shuffle.permute(
+                    0, 3, 5, 2, 4, 1, 6).contiguous()
+                weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
+                layer.weight_scale = torch.nn.Parameter(weight_scale_shuffle,
+                                                        requires_grad=False)
+
+                # shuffle weight
+                weight_shuffle = layer.weight.data
+                weight_shuffle = shuffle_weight(weight_shuffle,
+                                                layout=(16, 16))
+                layer.weight = torch.nn.Parameter(weight_shuffle,
+                                                  requires_grad=False)
+            else:
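+                # Triton gemm_afp4wfp4 path: store the scales transposed and
+                # contiguous; gemm_with_dynamic_quant transposes them back
+                # (weight_scale.T) when calling the kernel.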
+                layer.weight_scale = torch.nn.Parameter(
+                    layer.weight_scale.data.T.contiguous(),
+                    requires_grad=False)
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: list[int],
@@ -104,9 +231,9 @@ def apply_weights(self,
 
         if self.emulate:
             dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
-
             x = quant_dequant_mxfp4(x)
-
             return F.linear(x, dq_w, bias)
         else:
-            raise NotImplementedError()
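+            # Non-emulated path: dispatch through the registered custom op;
+            # weights and scales were preshuffled or transposed in
+            # process_weights_after_loading to match the selected kernel.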
+            return torch.ops.vllm.gemm_with_dynamic_quant(
+                x, layer.weight, layer.weight_scale,
+                self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)