Commit f51cbe0

review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 31, 2025
1 parent 3d12a04 commit f51cbe0
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -81,7 +81,8 @@ def is_dim_blocked(dim, shape, group_shape):
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
         input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight, weight_group_shape,
+        return apply_w8a8_block_fp8_linear(input, weight,
+                                           list(weight_group_shape),
                                            weight_scale)
     else:
         # Despite having linear in the name it doesn't conform to
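
For context on the condition above: a minimal sketch of when the blocked path fires, assuming is_dim_blocked(dim, shape, group_shape) returns True when that dimension is quantized in groups larger than a single element but smaller than the whole dimension (its body sits outside this hunk):

    def is_dim_blocked(dim, shape, group_shape):
        # Assumed semantics; the real body is not shown in this diff.
        return 1 < group_shape < shape[dim]

    # DeepSeek-style W8A8 block quantization satisfies all three checks
    # (the shapes below are hypothetical, for illustration only):
    weight_shape = (4096, 4096)       # weight dimensions
    weight_group_shape = (128, 128)   # 128x128 weight blocks
    input_group_shape = (1, 128)      # per-token, per-128-column activations
    assert is_dim_blocked(0, weight_shape, weight_group_shape[0])
    assert is_dim_blocked(1, weight_shape, weight_group_shape[1])
    assert input_group_shape == (1, weight_group_shape[1])

The change itself only converts weight_group_shape from a tuple to a list at the call site, presumably to match the parameter type that apply_w8a8_block_fp8_linear expects.
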
@@ -159,12 +160,18 @@ def block_quant_to_tensor_quant(


 # Quantize to fp8 assuming one scale per group of elements with shape
-# group_shape
+# group_shape, example group shapes:
+# * (-1, -1) for per-tensor quantization
+# * (1, -1) for per-row quantization
+# * (-1, 1) for per-column quantization
+# * (128, 128) for 128x128 deepseek style block quantization
+# * (1, 128) for deepseek style activation quantization
+#   (i.e. per-token-per-group)
 def fp8_quantize(
     x: torch.Tensor,
     group_shape: Tuple[int, int],
     dtype: Optional[torch.dtype] = None,
-):
+) -> Tuple[torch.Tensor, torch.Tensor]:
     group_shape = _normalize_quant_group_shape(x, group_shape)
     if dtype is None:
         dtype = (torch.float8_e4m3fnuz
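
The expanded comment documents the group_shape convention, where -1 spans an entire dimension. Below is a minimal reference sketch of that convention, not the vLLM implementation: reference_fp8_quantize is a hypothetical name, the real fp8_quantize normalizes shapes via _normalize_quant_group_shape, and (as the last context line shows) it may select torch.float8_e4m3fnuz on some platforms:

    import torch

    def reference_fp8_quantize(x, group_shape):
        # Normalize -1 to the full extent of that dimension.
        gm = x.shape[0] if group_shape[0] == -1 else group_shape[0]
        gn = x.shape[1] if group_shape[1] == -1 else group_shape[1]
        m, n = x.shape
        assert m % gm == 0 and n % gn == 0
        # One scale per (gm x gn) group: view x as (m//gm, gm, n//gn, gn).
        blocks = x.reshape(m // gm, gm, n // gn, gn)
        amax = blocks.abs().amax(dim=(1, 3), keepdim=True).float()
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = amax.clamp(min=1e-12) / fp8_max
        q = (blocks / scale).clamp(-fp8_max, fp8_max).reshape(m, n)
        return q.to(torch.float8_e4m3fn), scale.reshape(m // gm, n // gn)

    x = torch.randn(256, 512)
    q, s = reference_fp8_quantize(x, (-1, -1))    # per-tensor: one scale
    q, s = reference_fp8_quantize(x, (1, -1))     # per-row: 256 scales
    q, s = reference_fp8_quantize(x, (128, 128))  # deepseek style: 2x4 scales
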
