Skip to content
24 changes: 14 additions & 10 deletions examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype:
assert dtype == "float16"
assert val.dtype == "uint8"
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
# s1e2n1
# e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14
# s1e2m1
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = f4 & tir.const(7, "uint16")
e_f16 = e_f4 | tir.const(8, "uint16")
val_f16 = tir.reinterpret(
"float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
e_f16 = e_f4 + tir.const(14, "uint16")
m_f4 = f4 & tir.const(1, "uint16")
m_f16 = m_f4
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")
| m_f16 << tir.const(9, "uint16")).astype("uint16"))
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
return val_f16

Expand All @@ -39,9 +41,11 @@ def _convert(val, pos):
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = f4 & 7
e_f16 = e_f4 | 8
val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
e_f4 = (f4 & 6) >> 1
e_f16 = e_f4 + 14
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.float16)

Expand Down
Loading