diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index d8b41f1ed8f80..e414898853636 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -81,7 +81,8 @@ def is_dim_blocked(dim, shape, group_shape):
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
             input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight, weight_group_shape,
+        return apply_w8a8_block_fp8_linear(input, weight,
+                                           list(weight_group_shape),
                                            weight_scale)
     else:
         # Despite having linear in the name it doesn't conform to
@@ -159,12 +160,18 @@ def block_quant_to_tensor_quant(
 
 
 # Quantize to fp8 assuming one scale per group of elements with shape
-# group_shape
+# group_shape, example group shapes:
+# * (-1, -1) for per-tensor quantization
+# * (1, -1) for per-row quantization
+# * (-1, 1) for per-column quantization
+# * (128, 128) for 128x128 deepseek style block quantization
+# * (1, 128) for deepseek style activation quantization
+#   (i.e. per-token-per-group)
 def fp8_quantize(
     x: torch.Tensor,
     group_shape: Tuple[int, int],
     dtype: Optional[torch.dtype] = None,
-):
+) -> Tuple[torch.Tensor, torch.Tensor]:
     group_shape = _normalize_quant_group_shape(x, group_shape)
     if dtype is None:
         dtype = (torch.float8_e4m3fnuz
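
For readers unfamiliar with the group_shape convention documented in the added comment, the sketch below is a minimal, self-contained illustration of one-scale-per-group fp8 quantization under that convention. It is not vLLM's fp8_quantize (which goes through _normalize_quant_group_shape and fused kernels); the helper names here are hypothetical and assume a 2D input whose dimensions divide evenly by the group shape.

```python
from typing import Tuple

import torch


def _normalize_group_shape(x: torch.Tensor,
                           group_shape: Tuple[int, int]) -> Tuple[int, int]:
    # -1 means "span the whole dimension": (-1, -1) -> per-tensor,
    # (1, -1) -> per-row, (-1, 1) -> per-column, (1, 128) -> per-token-per-group.
    rows, cols = x.shape
    return (rows if group_shape[0] == -1 else group_shape[0],
            cols if group_shape[1] == -1 else group_shape[1])


def groupwise_fp8_quantize(
    x: torch.Tensor,
    group_shape: Tuple[int, int],
    dtype: torch.dtype = torch.float8_e4m3fn,  # e4m3fnuz on ROCm
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize x to fp8 with one scale per (gr x gc) group of elements."""
    gr, gc = _normalize_group_shape(x, group_shape)
    rows, cols = x.shape
    assert rows % gr == 0 and cols % gc == 0, "dims must divide by group shape"

    # View the 2D tensor as (row_groups, gr, col_groups, gc) so each
    # quantization group occupies one (gr, gc) tile of the view.
    tiles = x.view(rows // gr, gr, cols // gc, gc)

    finfo = torch.finfo(dtype)
    # One scale per tile, chosen so the tile's absolute max maps to fp8 max.
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = (amax / finfo.max).to(torch.float32)

    q = (tiles / scale).clamp(finfo.min, finfo.max).to(dtype)
    return q.reshape(rows, cols), scale.squeeze(3).squeeze(1)


# Example: deepseek-style per-token-per-group activation quantization.
x = torch.randn(4, 256)
q, s = groupwise_fp8_quantize(x, (1, 128))
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```

The returned scale tensor has one entry per group, e.g. a (4, 2) scale for a (4, 256) input with group_shape (1, 128), which is the per-token-per-group layout the new comment describes.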