```diff
@@ -11,7 +11,8 @@
                                     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
+from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
+                                  is_blackwell_deep_gemm_used)
 
 logger = init_logger(__name__)
 
@@ -50,6 +51,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     eps: tl.constexpr,
     fp8_min: tl.constexpr,
     fp8_max: tl.constexpr,
+    use_ue8m0: tl.constexpr,
 
     # Meta ---------------------------------------------------------------
     BLOCK: tl.constexpr,
@@ -92,7 +94,9 @@ def _silu_mul_fp8_quant_deep_gemm(
     y = x * y2
 
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(
+        tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
@@ -174,6 +178,7 @@ def silu_mul_fp8_quant_deep_gemm(
         eps,
         fp8_min,
         fp8_max,
+        is_blackwell_deep_gemm_used(),
         BLOCK=group_size,
         num_warps=4,
     )
@@ -290,14 +295,10 @@ def apply(
         # may lead to better performance.
         expected_m = max_num_tokens
         fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
-                                     out=workspace1,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+                                     workspace1, expert_num_tokens, expected_m)
 
         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                       expert_num_tokens)
 
-        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
-                                     out=output,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
+                                     expert_num_tokens, expected_m)
```
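The core of the change is the UE8M0 scale handling: when `use_ue8m0` is set (on Blackwell, where DeepGEMM expects power-of-two group scales), the kernel rounds the raw scale up to the next power of two via `exp2(ceil(log2(scale)))`. A power-of-two scale can be stored as a bare 8-bit exponent with no mantissa, which is what the UE8M0 format encodes. A minimal sketch of the same math in plain Python, with a hypothetical helper name (not vLLM code):

```python
import math

def ue8m0_round_up(scale_raw: float) -> float:
    """Round a positive scale up to the nearest power of two.

    Mirrors the kernel's `tl.math.exp2(tl.ceil(tl.log2(scale_raw)))` branch;
    a power-of-two scale fits in an 8-bit exponent field (UE8M0).
    """
    return 2.0 ** math.ceil(math.log2(scale_raw))

# Example: absmax = 5.6 with fp8_max = 448 (float8_e4m3) gives a raw scale
# of 0.0125, which rounds up to 2**-6 = 0.015625.
assert ue8m0_round_up(5.6 / 448.0) == 2.0 ** -6
```

Rounding up rather than to-nearest is the conservative choice: dividing by a larger scale can only shrink `|y / y_s|`, so quantized values stay inside `[fp8_min, fp8_max]` and the subsequent `tl.clamp` never saturates because of the rounding. Note also that `use_ue8m0` is a `tl.constexpr`, so Triton specializes the kernel at compile time and the non-Blackwell path carries no runtime branch.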