
Commit 569b012

tzj-fxz and LeiWang1999 authored
Low-bit kernels fix and implementation (#704)
* [MXFP4] Dequantize FP4 kernel example, MX scale todo
* [BugFix] Fix the bug of fp4&fp16 exponential bias
* [MXFP4] Add group scale factor for BF16xMXFP4 gemm
* [Lint]
* [Test] Add test script for BF16xMXFP4 gemm
* [Lint]
* [BugFix] Fix the shape of scale tensor
* Update example_dequant_gemm_fp4_hopper.py

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
1 parent 376ba9e · commit 569b012

File tree

3 files changed (+445, -10 lines)

examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py

Lines changed: 14 additions & 10 deletions
@@ -12,16 +12,18 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype:
     assert dtype == "float16"
     assert val.dtype == "uint8"
     # e_f4 == 0 -> e_f16 = 0
-    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
-    # s1e2n1
+    # e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14
+    # s1e2m1
     mask = tir.const((1 << nbit) - 1, "uint16")
     f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
     s = f4 >> tir.const(3, "uint16")
-    e_f4 = f4 & tir.const(7, "uint16")
-    e_f16 = e_f4 | tir.const(8, "uint16")
-    val_f16 = tir.reinterpret(
-        "float16",
-        ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
+    e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
+    e_f16 = e_f4 + tir.const(14, "uint16")
+    m_f4 = f4 & tir.const(1, "uint16")
+    m_f16 = m_f4
+    val_f16 = tir.reinterpret("float16",
+                              ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")
+                               | m_f16 << tir.const(9, "uint16")).astype("uint16"))
     # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
     return val_f16
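The corrected bias in the new comment follows from the standard exponent biases: binary16 uses bias 15 and e2m1 FP4 uses bias 1, so a normal FP4 exponent moves into the FP16 exponent field as e_f4 + (15 - 1) = e_f4 + 14, which the commit writes as 2^4 - 2^1. Below is a minimal pure-Python sketch of the same bit manipulation (helper names are illustrative, not from the repo), checked against the e2m1 value formula for the normal encodings; the e_f4 == 0 case is the zero/subnormal path that the commented-out tir.Select would handle.

import struct

def fp4_e2m1_to_f16_bits(f4: int) -> int:
    # Same bit layout as the kernel: s1e2m1 -> IEEE binary16 (s1e5m10).
    s = (f4 >> 3) & 1
    e_f4 = (f4 & 0b110) >> 1
    m_f4 = f4 & 0b1
    e_f16 = e_f4 + 14  # re-bias: 15 (binary16) - 1 (e2m1) = 14
    return (s << 15) | (e_f16 << 10) | (m_f4 << 9)

def f16_from_bits(bits: int) -> float:
    return struct.unpack("<e", struct.pack("<H", bits))[0]

for f4 in range(16):
    s, e, m = f4 >> 3, (f4 >> 1) & 0b11, f4 & 1
    if e == 0:
        continue  # zero/subnormal: not covered by the fast path
    reference = (-1.0) ** s * 2.0 ** (e - 1) * (1 + m / 2)
    assert f16_from_bits(fp4_e2m1_to_f16_bits(f4)) == reference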

@@ -39,9 +41,11 @@ def _convert(val, pos):
     mask = (1 << 4) - 1
     f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
     s = f4 >> 3
-    e_f4 = f4 & 7
-    e_f16 = e_f4 | 8
-    val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
+    e_f4 = (f4 & 6) >> 1
+    e_f16 = e_f4 + 14
+    m_f4 = f4 & 1
+    m_f16 = m_f4
+    val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF
     lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
     return lower_16_bits.view(torch.float16)
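As a quick sanity check of the updated _convert logic, here is a hypothetical standalone snippet (not from the commit): int32 intermediates are used to sidestep int16 overflow concerns, and the low-nibble-first packing order is assumed from the diff.

import torch

# Two FP4 e2m1 values packed per byte:
# low nibble 0b0011 -> +1.5, high nibble 0b1010 -> -1.0.
packed = torch.tensor([0b1010_0011], dtype=torch.uint8)

def decode(val, pos):
    # Extract the nibble at index `pos`, then rebuild binary16 bits.
    f4 = ((val >> (pos * 4)) & 0xF).to(torch.int32)
    s, e_f4, m = f4 >> 3, (f4 & 6) >> 1, f4 & 1
    bits = ((e_f4 + 14) | (s << 5)) << 10 | (m << 9)
    return (bits & 0xFFFF).to(torch.uint16).view(torch.float16)

print(decode(packed, 0).item(), decode(packed, 1).item())  # 1.5 -1.0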
