diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index b16fef871419..8ebe694eefd0 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1276,7 +1276,7 @@ def scaled_fp8_quant(
             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         # num_token_padding not implemented for this case
-        assert (scale.numel() == 1 or num_token_padding is None)
+        assert (scale.numel() == 1 and num_token_padding is None)
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
 
     return output, scale
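
A minimal sketch (not part of the patch) of why the operator change matters in the static-scale branch: with `or`, a per-tensor scale (`scale.numel() == 1`) satisfies the assert even when `num_token_padding` is set, so the unimplemented padding case slips through; with `and`, any non-`None` padding is rejected. The helper `check_static_branch` below is hypothetical and only reproduces the assertion logic, not the quantization itself.

```python
# Hypothetical stand-in for the assert in the static-scale branch of
# scaled_fp8_quant; it performs no quantization, only the guard check.
def check_static_branch(scale_numel: int, num_token_padding, use_and: bool) -> bool:
    """Return True if the assert would pass for these inputs."""
    if use_and:
        return scale_numel == 1 and num_token_padding is None
    return scale_numel == 1 or num_token_padding is None

# Per-tensor scale combined with token padding, which the static kernel
# does not implement:
print(check_static_branch(1, 128, use_and=False))  # True  -> old assert lets it through
print(check_static_branch(1, 128, use_and=True))   # False -> new assert rejects it
```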