diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index d68aa22bed0c..516bf4513816 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -667,7 +667,13 @@ def check_model(model):
         qkv_proj = layer.self_attn.qkv_proj
         assert isinstance(qkv_proj.quant_method,
                           CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, scheme)
+        if isinstance(qkv_proj.scheme, scheme) or isinstance(
+                qkv_proj.scheme, CompressedTensorsW4A16Fp4
+        ) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
+            assert True
+        else:
+            raise AssertionError("FP4 Scheme Mismatch")
+        assert qkv_proj.scheme.group_size == 16
 
     llm.apply_model(check_model)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 28c62fc5e58b..e5702c871cc9 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -374,7 +374,14 @@ def _get_scheme_from_parts(
 
         if is_activation_quantization_format(self.quant_format):
             if self._is_fp4a4_nvfp4(weight_quant, input_quant):
-                return CompressedTensorsW4A4Fp4()
+                if CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
+                    return CompressedTensorsW4A4Fp4()
+                else:
+                    logger.warning_once(
+                        "Current platform does not support cutlass NVFP4."
+                        " Running CompressedTensorsW4A16Fp4.")
+                    return CompressedTensorsW4A16Fp4(
+                        has_input_global_scale=True)
 
             if self._is_fp8_w8a8(weight_quant, input_quant):
                 is_fp8_w8a8_supported = self._check_scheme_supported(
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
index 8202ce951496..96dccf04d490 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
@@ -18,7 +18,8 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
 
-    def __init__(self):
+    def __init__(self, has_input_global_scale: bool = False):
+        self.has_input_global_scale = has_input_global_scale
         self.group_size = 16
 
     @classmethod
@@ -64,6 +65,13 @@ def create_weights(self, layer: torch.nn.Module,
 
         layer.register_parameter("weight_scale", weight_scale)
 
+        if self.has_input_global_scale:
+            input_global_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes),
+                                 dtype=torch.float32),
+                weight_loader=weight_loader)
+            layer.register_parameter("input_global_scale", input_global_scale)
+
     def process_weights_after_loading(self, layer) -> None:
         # Process parameters for marlin repacking
 
@@ -77,6 +85,10 @@ def process_weights_after_loading(self, layer) -> None:
                                             requires_grad=False)
         del layer.weight_global_scale
 
+        if self.has_input_global_scale:
+            layer.input_global_scale = torch.nn.Parameter(
+                layer.input_global_scale.data, requires_grad=False)
+
         prepare_fp4_layer_for_marlin(layer)
 
     def apply_weights(self,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
index 9899db3243a4..32718972a627 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -9,8 +9,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype, ref_nvfp4_quant)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -21,53 +19,23 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
 
-def cutlass_fp4_supported() -> bool:
-    if not current_platform.is_cuda():
-        return False
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-    return cutlass_scaled_mm_supports_fp4(capability)
-
-
 class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
 
     def __init__(self):
         self.group_size = 16
-        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
-        if not self.cutlass_nvfp4_supported:
-            logger.warning("Current platform does not support cutlass NVFP4."
-                           " Running emulations.")
 
     @classmethod
     def get_min_capability(cls) -> int:
-        # dont restrict as emulations
-        return 80
-
-    def run_nvfp4_emulations(self, x: torch.Tensor, layer):
-        x_m, x_k = x.shape
-        output_dtype = x.dtype
-
-        # quantize input to (FP4 and interleaved block scale)
-        x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale,
-                                              self.group_size)
+        return 100
 
-        # dequantize input
-        x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
-        x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale
-        x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
-        del x_fp4, x_blockscale
-
-        # dequantize weight
-        w_fp4 = layer.weight.data.view(torch.uint8)
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_global_scale = layer.weight_global_scale
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   output_dtype, x.device, self.group_size)
-
-        # matmul
-        out = torch.matmul(x_dq, w_dq.t())
-        del w_dq, x_dq
-        return out
+    @classmethod
+    def cutlass_fp4_supported(cls) -> bool:
+        if not current_platform.is_cuda():
+            return False
+        capability_tuple = current_platform.get_device_capability()
+        capability = -1 if capability_tuple is None else capability_tuple.to_int(  # noqa: E501
+        )
+        return cutlass_scaled_mm_supports_fp4(capability)
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: list[int],
@@ -152,27 +120,24 @@ def process_weights_after_loading(self, layer) -> None:
         # required by cutlass kernel; need Parameter, not ModelWeightParameter
         layer.weight = Parameter(layer.weight_packed.data,
                                  requires_grad=False)
 
-        if self.cutlass_nvfp4_supported:
-            layer.alpha = Parameter(layer.input_global_scale *
-                                    layer.weight_global_scale,
-                                    requires_grad=False)
+        layer.alpha = Parameter(layer.input_global_scale *
+                                layer.weight_global_scale,
+                                requires_grad=False)
 
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        if self.cutlass_nvfp4_supported:
-            output_dtype = x.dtype
-            output_shape = [x.shape[0], layer.weight.shape[0]]
+        output_dtype = x.dtype
+        output_shape = [x.shape[0], layer.weight.shape[0]]
 
-            # quantize BF16 or FP16 to (FP4 and interleaved block scale)
-            x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
+        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
+        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
 
-            out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
-                                        layer.weight_scale_swizzled,
-                                        1 / layer.alpha, output_dtype)
-            if bias is not None:
-                out = out + bias
-            return out.view(*output_shape)
-        return self.run_nvfp4_emulations(x, layer)
+        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
+                                    layer.weight_scale_swizzled,
+                                    1 / layer.alpha, output_dtype)
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
index c4ef3ce24c03..d5ce6d7ad757 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
@@ -102,3 +102,32 @@ def ref_nvfp4_quant(x, global_scale, block_size):
     clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
     # both outputs are float32
     return cast_to_fp4(clipped_x), scale.squeeze(-1)
+
+
+def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor,
+                         weight: torch.Tensor,
+                         weight_scale_swizzled: torch.Tensor,
+                         weight_global_scale: torch.Tensor):
+    group_size = 16
+    x_m, x_k = x.shape
+    output_dtype = x.dtype
+
+    # quantize input to (FP4 and interleaved block scale)
+    x_fp4, x_blockscale = ref_nvfp4_quant(x, input_global_scale, group_size)
+
+    # dequantize input
+    x_fp4 = x_fp4.reshape(x_m, x_k // group_size, group_size)
+    x_blockscale = x_blockscale.unsqueeze(-1) / input_global_scale
+    x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
+    del x_fp4, x_blockscale
+
+    # dequantize weight
+    w_fp4 = weight.data.view(torch.uint8)
+    w_dq = dequantize_to_dtype(w_fp4, weight_scale_swizzled.data,
+                               weight_global_scale, output_dtype, x.device,
+                               group_size)
+
+    # matmul
+    out = torch.matmul(x_dq, w_dq.t())
+    del w_dq, x_dq
+    return out
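
Note (illustration, not part of the patch): run_nvfp4_emulations is now a module-level helper in nvfp4_emulation_utils.py rather than a method on CompressedTensorsW4A4Fp4. A minimal usage sketch follows, assuming a layer that already carries the input_global_scale, weight (packed FP4 data), weight_scale_swizzled, and weight_global_scale parameters created by the schemes above; the emulated_forward wrapper below is hypothetical.

    from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
        run_nvfp4_emulations)

    def emulated_forward(layer, x):
        # Dequantize the FP4 input and weight, then matmul in x.dtype.
        return run_nvfp4_emulations(
            x=x,
            input_global_scale=layer.input_global_scale,
            weight=layer.weight,
            weight_scale_swizzled=layer.weight_scale_swizzled,
            weight_global_scale=layer.weight_global_scale,
        )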