From a09b65f4b251db6d5ee7f285d74808b08d4940d5 Mon Sep 17 00:00:00 2001 From: "Jae H. Yoo" Date: Mon, 23 Dec 2024 14:35:00 -0800 Subject: [PATCH] PR #19096: Add F4E2M1FN and F8E8M0FNU types Imported from GitHub PR https://github.com/openxla/xla/pull/19096 This PR adds F4E2M1FN primitive type (4-bit float with 2 bits exponent and 1 bit mantissa), F8E8M0FNU primitive type (8-bit float with 8 bits exponent, no mantissa and no sign) and enables loads/stores in the same way S4/U4 type is implemented. This will enable using microscaling (MX) formats ([RFC](https://github.com/openxla/xla/discussions/18085)), such as MXFP4. ```c... PiperOrigin-RevId: 709153611 --- ml_dtypes/include/mxfloat.h | 1 + 1 file changed, 1 insertion(+) diff --git a/ml_dtypes/include/mxfloat.h b/ml_dtypes/include/mxfloat.h index 6bdeea6a..252e8351 100644 --- a/ml_dtypes/include/mxfloat.h +++ b/ml_dtypes/include/mxfloat.h @@ -339,6 +339,7 @@ namespace Eigen { namespace numext { #define MXFLOAT_EIGEN_SIGNBIT_IMPL(Type) \ + template <> \ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Type signbit(const Type& x) { \ int8_t t = bit_cast(x) << (8 - Type::kBits); \ return bit_cast(t >> 7); \