@@ -20,6 +20,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
 
 logger = init_logger(__name__)
 
@@ -256,6 +257,7 @@ def _per_token_group_quant_fp8(
     # Information for float8
     fp8_min,
     fp8_max,
+    use_ue8m0: tl.constexpr,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
@@ -285,7 +287,8 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
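For readers skimming the diff: UE8M0 is a scale format with an unsigned 8-bit exponent and no mantissa bits, so a scale is only representable if it is an exact power of two. The new branch therefore rounds the raw scale up, never down, which keeps `y / y_s` within `[fp8_min, fp8_max]` so the clamp does not saturate because of the rounding. A minimal plain-PyTorch sketch of the same rounding, written here for illustration (the `ue8m0_scale` helper is ours, not vLLM API):

```python
import torch

def ue8m0_scale(absmax: torch.Tensor, fp8_max: float) -> torch.Tensor:
    # Raw per-group scale, exactly as computed in the kernel above.
    scale_raw = absmax / fp8_max
    # Round up to the nearest power of two so the scale is exactly
    # representable with an 8-bit exponent and zero mantissa (UE8M0).
    return torch.exp2(torch.ceil(torch.log2(scale_raw)))

absmax = torch.tensor([0.7, 3.2, 448.0])
print(ue8m0_scale(absmax, fp8_max=448.0))
# tensor([0.0020, 0.0078, 1.0000]) -- each value is 2**k for an integer k
```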
@@ -309,6 +312,7 @@ def _per_token_group_quant_fp8_colmajor(
     # Information for float8
     fp8_min,
     fp8_max,
+    use_ue8m0: tl.constexpr,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
@@ -347,7 +351,8 @@ def _per_token_group_quant_fp8_colmajor(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
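The column-major kernel picks up the identical rounding; the two kernels differ only in how the scale tensor is laid out in memory. As a quick illustration (ours, not vLLM code) of what column-major scales mean for a `(tokens, groups)` tensor:

```python
import torch

tokens, groups = 16, 4
# Row-major scales: the group index is the contiguous (stride-1) dimension.
s_row = torch.empty(tokens, groups)
# Column-major scales: allocate transposed and view back, making the token
# index contiguous instead -- the layout consumers such as DeepGEMM expect.
s_col = torch.empty(groups, tokens).t()
print(s_row.stride(), s_col.stride())  # (4, 1) (1, 16)
```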
@@ -373,9 +378,9 @@ def per_token_group_quant_fp8(
         is supported for now.
         column_major_scales: Outputs scales in column major.
         out_q: Optional output tensor. If not provided, function will create.
-    Returns:
-        tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
-        scaling factor for quantization.
+    Returns:
+        tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
+        scaling factor.
     """
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
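For orientation, a hypothetical call site, with the import path and signature inferred from this diff rather than quoted from the docs:

```python
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
# x_q: FP8 tensor with the same shape as x.
# x_s: one scale per 128-element group, shape (16, 512 // 128) == (16, 4);
#      on Blackwell with DeepGEMM active, each scale is a power of two.
```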
@@ -418,6 +425,7 @@ def per_token_group_quant_fp8(
         eps,
         fp8_min=fp8_min,
         fp8_max=fp8_max,
+        use_ue8m0=is_blackwell_deep_gemm_used(),
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
@@ -433,6 +441,7 @@ def per_token_group_quant_fp8(
         eps,
         fp8_min=fp8_min,
         fp8_max=fp8_max,
+        use_ue8m0=is_blackwell_deep_gemm_used(),
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
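Both launch sites gate the rounding on `is_blackwell_deep_gemm_used()`, so every other platform keeps the raw scales bit-for-bit unchanged. A cheap invariant one could assert in a test (our sketch, not part of the PR): with UE8M0 active, every emitted scale is an exact power of two.

```python
import torch

def all_powers_of_two(scales: torch.Tensor) -> bool:
    # frexp returns mantissas in [0.5, 1); an exact power of two has
    # mantissa exactly 0.5.
    mantissa, _ = torch.frexp(scales.float())
    return bool((mantissa == 0.5).all())

# e.g., after: x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
# assert all_powers_of_two(x_s)
```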