Merged
8 changes: 7 additions & 1 deletion tests/quantization/test_compressed_tensors.py
@@ -667,7 +667,13 @@ def check_model(model):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, scheme)
if isinstance(qkv_proj.scheme, scheme) or isinstance(
qkv_proj.scheme, CompressedTensorsW4A16Fp4
) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
assert True
else:
raise AssertionError("FP4 Scheme Mismatch")

assert qkv_proj.scheme.group_size == 16

llm.apply_model(check_model)
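A note on reading the new check: Python evaluates the trailing "and" before the "or", so the test passes when the expected scheme is in use, or when the layer legitimately fell back to the weight-only FP4 scheme because cutlass NVFP4 is unavailable. Below is a minimal, explicitly parenthesized sketch of that condition as a helper (hypothetical, not part of the diff; the scheme classes are assumed to be importable from the schemes package):

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW4A16Fp4, CompressedTensorsW4A4Fp4)


def fp4_scheme_ok(actual_scheme, expected_cls) -> bool:
    # True when the layer uses the expected FP4 scheme, or when it fell back
    # to weight-only FP4 because cutlass NVFP4 is unsupported on this platform.
    return isinstance(actual_scheme, expected_cls) or (
        isinstance(actual_scheme, CompressedTensorsW4A16Fp4)
        and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported())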
@@ -374,7 +374,14 @@ def _get_scheme_from_parts(

if is_activation_quantization_format(self.quant_format):
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Fp4()
if CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
return CompressedTensorsW4A4Fp4()
else:
logger.warning_once(
"Current platform does not support cutlass NVFP4."
" Running CompressedTensorsW4A16Fp4.")
return CompressedTensorsW4A16Fp4(
has_input_global_scale=True)

if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
@@ -18,7 +18,8 @@

class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):

def __init__(self):
def __init__(self, has_input_global_scale: bool = False):
self.has_input_global_scale = has_input_global_scale
self.group_size = 16

@classmethod
@@ -64,6 +65,13 @@ def create_weights(self, layer: torch.nn.Module,

layer.register_parameter("weight_scale", weight_scale)

if self.has_input_global_scale:
input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_global_scale", input_global_scale)

def process_weights_after_loading(self, layer) -> None:
# Process parameters for marlin repacking

@@ -77,6 +85,10 @@ def process_weights_after_loading(self, layer) -> None:
requires_grad=False)
del layer.weight_global_scale

if self.has_input_global_scale:
layer.input_global_scale = torch.nn.Parameter(
layer.input_global_scale.data, requires_grad=False)

prepare_fp4_layer_for_marlin(layer)

def apply_weights(self,
@@ -9,8 +9,6 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
dequantize_to_dtype, ref_nvfp4_quant)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
@@ -21,53 +19,23 @@
__all__ = ["CompressedTensorsW4A4Fp4"]


def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)


class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):

def __init__(self):
self.group_size = 16
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
if not self.cutlass_nvfp4_supported:
logger.warning("Current platform does not support cutlass NVFP4."
" Running emulations.")

@classmethod
def get_min_capability(cls) -> int:
# dont restrict as emulations
return 80

def run_nvfp4_emulations(self, x: torch.Tensor, layer):
x_m, x_k = x.shape
output_dtype = x.dtype

# quantize input to (FP4 and interleaved block scale)
x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale,
self.group_size)
return 100
Comment on lines 28 to +29 (Contributor, medium):
Changing get_min_capability from 80 to 100 aligns with the PR's goal.
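For context on the 80 to 100 change: vLLM compares get_min_capability() against the device's compute capability encoded as a single integer (major * 10 + minor), so the new minimum corresponds to SM 10.0, i.e. Blackwell-class GPUs where the cutlass NVFP4 kernels are available. An illustrative sketch of that encoding (not part of the diff):

def capability_to_int(major: int, minor: int) -> int:
    # Same convention as DeviceCapability.to_int() in vLLM: major * 10 + minor.
    return major * 10 + minor


assert capability_to_int(8, 0) == 80    # Ampere, the old emulation-era minimum
assert capability_to_int(10, 0) == 100  # Blackwell, the new minimum for W4A4 NVFP4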


# dequantize input
x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
del x_fp4, x_blockscale

# dequantize weight
w_fp4 = layer.weight.data.view(torch.uint8)
w_blockscale = layer.weight_scale_swizzled.data
w_global_scale = layer.weight_global_scale
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
output_dtype, x.device, self.group_size)

# matmul
out = torch.matmul(x_dq, w_dq.t())
del w_dq, x_dq
return out
@classmethod
def cutlass_fp4_supported(cls) -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int( # noqa: E501
)
return cutlass_scaled_mm_supports_fp4(capability)
Comment on lines +31 to +38 (Contributor, medium):
Moving the cutlass_fp4_supported check to a class method is a good refactoring.
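As a usage note on this refactor: exposing the check as a classmethod lets both the config and the test query platform support without instantiating the W4A4 scheme first. A condensed sketch of the selection logic from this PR's config hunk (the wrapper function is hypothetical; the scheme classes are assumed importable from the schemes package):

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW4A16Fp4, CompressedTensorsW4A4Fp4)


def select_nvfp4_scheme():
    # Mirrors the new branch in _get_scheme_from_parts(): use the cutlass W4A4
    # kernel when the platform supports it, otherwise fall back to weight-only
    # FP4 while still loading the checkpoint's input global scale.
    if CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
        return CompressedTensorsW4A4Fp4()
    return CompressedTensorsW4A16Fp4(has_input_global_scale=True)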


def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
@@ -152,27 +120,24 @@ def process_weights_after_loading(self, layer) -> None:
# required by cutlass kernel; need Parameter, not ModelWeightParameter
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)

if self.cutlass_nvfp4_supported:
layer.alpha = Parameter(layer.input_global_scale *
layer.weight_global_scale,
requires_grad=False)
layer.alpha = Parameter(layer.input_global_scale *
layer.weight_global_scale,
requires_grad=False)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

if self.cutlass_nvfp4_supported:
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]

# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)

out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled,
1 / layer.alpha, output_dtype)
if bias is not None:
out = out + bias
return out.view(*output_shape)
return self.run_nvfp4_emulations(x, layer)
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled,
1 / layer.alpha, output_dtype)
if bias is not None:
out = out + bias
return out.view(*output_shape)
@@ -102,3 +102,32 @@ def ref_nvfp4_quant(x, global_scale, block_size):
clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
# both outputs are float32
return cast_to_fp4(clipped_x), scale.squeeze(-1)


def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale_swizzled: torch.Tensor,
weight_global_scale: torch.Tensor):
group_size = 16
x_m, x_k = x.shape
output_dtype = x.dtype

# quantize input to (FP4 and interleaved block scale)
x_fp4, x_blockscale = ref_nvfp4_quant(x, input_global_scale, group_size)

# dequantize input
x_fp4 = x_fp4.reshape(x_m, x_k // group_size, group_size)
x_blockscale = x_blockscale.unsqueeze(-1) / input_global_scale
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
del x_fp4, x_blockscale

# dequantize weight
w_fp4 = weight.data.view(torch.uint8)
w_dq = dequantize_to_dtype(w_fp4, weight_scale_swizzled.data,
weight_global_scale, output_dtype, x.device,
group_size)

# matmul
out = torch.matmul(x_dq, w_dq.t())
del w_dq, x_dq
return out
Comment on lines +107 to +133 (Contributor, medium):
Moving the run_nvfp4_emulations function here from the scheme class is a good separation of concerns.
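To show how the relocated helper can be consumed, here is a hypothetical caller (not part of the PR); the layer attribute names mirror the parameters registered by the FP4 schemes earlier in this diff:

from typing import Optional

import torch

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
    run_nvfp4_emulations)


def emulated_nvfp4_linear(layer: torch.nn.Module, x: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Dequantize-and-matmul emulation of an NVFP4 linear layer using the
    # packed weight and scales stored on the layer by the scheme.
    out = run_nvfp4_emulations(
        x,                              # BF16/FP16 activations
        layer.input_global_scale,       # per-tensor input global scale
        layer.weight,                   # packed FP4 weight bytes
        layer.weight_scale_swizzled,    # interleaved block scales
        layer.weight_global_scale)      # per-tensor weight global scale
    return out if bias is None else out + bias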