Skip to content

Commit 18f2f19

Browse files
Merge pull request #26525 from wenscarl:e2m1fn
PiperOrigin-RevId: 735457804
2 parents 73d20cd + ccbe9f7 commit 18f2f19

11 files changed, +70 -1 lines changed

jax/_src/dtypes.py

+20
Original file line number | Diff line number | Diff line change
@@ -109,6 +109,12 @@ def type(self) -> type: ...
109109
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
110110
_float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)
111111

112+
# fp4 support
113+
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
114+
float4_e2m1fn: type[np.generic] | None = None
115+
116+
_float4_e2m1fn_dtype: np.dtype | None = None
117+
112118
def supports_inf(dtype: DTypeLike) -> bool:
113119
"""Return true if the dtype supports infinity, else return False."""
114120
typ = np.dtype(dtype).type
@@ -144,6 +150,8 @@ def supports_inf(dtype: DTypeLike) -> bool:
144150
_float8_e5m2fnuz_dtype,
145151
]
146152

153+
_float4_dtypes: list[np.dtype] = []
154+
147155
# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
148156
if hasattr(ml_dtypes, "float8_e4m3"):
149157
float8_e4m3 = ml_dtypes.float8_e4m3
@@ -163,6 +171,12 @@ def supports_inf(dtype: DTypeLike) -> bool:
163171
_custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type]
164172
_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
165173
_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
174+
if hasattr(ml_dtypes, "float4_e2m1fn"):
175+
float4_e2m1fn = ml_dtypes.float4_e2m1fn
176+
_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
177+
_custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type]
178+
_custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
179+
_float4_dtypes.insert(0, _float4_e2m1fn_dtype)
166180

167181
# 2-bit integer support
168182
int2: type[np.generic] | None = None
@@ -716,6 +730,12 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
716730
"promotion path. To avoid unintended promotion, 8-bit floats do not support "
717731
"implicit promotion. If you'd like your inputs to be promoted to another type, "
718732
"you can do so explicitly using e.g. x.astype('float32')")
733+
elif any(n in _float4_dtypes for n in nodes):
734+
msg = (
735+
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
736+
"promotion path. To avoid unintended promotion, 4-bit floats do not support "
737+
"implicit promotion. If you'd like your inputs to be promoted to another type, "
738+
"you can do so explicitly using e.g. x.astype('float32')")
719739
elif any(n in _intn_dtypes for n in nodes):
720740
msg = (
721741
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "

jax/_src/export/serialization.fbs

+1
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,7 @@ enum DType: byte {
7575
f8_e5m2 = 20,
7676
f8_e5m2fnuz = 21,
7777
f8_e8m0fnu = 25,
78+
f4_e2m1fn = 26,
7879
}
7980

8081
table AbstractValue {

jax/_src/export/serialization.py

+2
Original file line number | Diff line number | Diff line change
@@ -365,6 +365,8 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
365365
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
366366
if dtypes._float8_e8m0fnu_dtype is not None:
367367
_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
368+
if dtypes._float4_e2m1fn_dtype is not None:
369+
_dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
368370
_dtype_kind_to_dtype = {
369371
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
370372
}

jax/_src/export/serialization_generated.py

+1
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@ class DType(object):
6262
f8_e5m2fnuz = 21
6363
f0 = 22
6464
f8_e8m0fnu = 25
65+
f4_e2m1fn = 26
6566

6667

6768
class ShardingKind(object):

jax/_src/interpreters/mlir.py

+3
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,9 @@ def _is_ir_values(x: IrValues) -> bool:
199199
if dtypes.float8_e8m0fnu is not None:
200200
_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
201201

202+
if dtypes.float4_e2m1fn is not None:
203+
_dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get
204+
202205
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
203206
if isinstance(dtype, core.bint):
204207
# TODO Support different-size underlying dtypes to take advantage of the

jax/_src/numpy/scalar_types.py

+2
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,8 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
9393
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
9494
float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
9595
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
96+
if dtypes.float4_e2m1fn is not None:
97+
float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
9698
bfloat16 = _make_scalar_type(dtypes.bfloat16)
9799
float16 = _make_scalar_type(np.float16)
98100
float32 = single = _make_scalar_type(np.float32)

jax/_src/public_test_util.py

+5
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,9 @@ def default_tolerance():
100100
if _dtypes.float8_e8m0fnu is not None:
101101
_default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
102102
default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
103+
if _dtypes.float4_e2m1fn is not None:
104+
_default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
105+
default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
103106

104107
def is_python_scalar(val):
105108
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
124127
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
125128
if _dtypes.float8_e8m0fnu is not None:
126129
custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
130+
if _dtypes.float4_e2m1fn is not None:
131+
custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)
127132

128133
def maybe_upcast(x):
129134
if x.dtype in custom_float_dtypes:

jax/_src/test_util.py

+2
Original file line number | Diff line number | Diff line change
@@ -1640,6 +1640,8 @@ def custom_floats(self):
16401640
float_dtypes += [_dtypes.float8_e4m3]
16411641
if _dtypes.float8_e8m0fnu is not None:
16421642
float_dtypes += [_dtypes.float8_e8m0fnu]
1643+
if _dtypes.float4_e2m1fn is not None:
1644+
float_dtypes += [_dtypes.float4_e2m1fn]
16431645
return self.supported(float_dtypes)
16441646

16451647
@_cached_property

jax/numpy/__init__.py

+1
Original file line number | Diff line number | Diff line change
@@ -310,6 +310,7 @@
310310
float8_e3m4 as float8_e3m4,
311311
float8_e4m3 as float8_e4m3,
312312
float8_e8m0fnu as float8_e8m0fnu,
313+
float4_e2m1fn as float4_e2m1fn,
313314
)
314315
except ImportError:
315316
pass

tests/dtypes_test.py

+31 -1
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,12 @@
7373
float_dtypes += fp8_dtypes
7474
custom_float_dtypes += fp8_dtypes
7575

76+
fp4_dtypes = []
77+
if dtypes.float4_e2m1fn is not None:
78+
fp4_dtypes += [np.dtype(dtypes.float4_e2m1fn)]
79+
float_dtypes += fp4_dtypes
80+
custom_float_dtypes += fp4_dtypes
81+
7682
complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')]
7783

7884

@@ -238,6 +244,8 @@ def testPromoteDtypesStandard(self):
238244
continue
239245
if t1 in intn_dtypes:
240246
continue
247+
if t1 in fp4_dtypes:
248+
continue
241249
self.assertEqual(np.dtype(np.complex128),
242250
dtypes.promote_types(t1, np.complex128))
243251

@@ -247,6 +255,8 @@ def testPromoteDtypesStandard(self):
247255
continue
248256
if t2 in intn_dtypes:
249257
continue
258+
if t2 in fp4_dtypes:
259+
continue
250260
# Symmetry
251261
self.assertEqual(dtypes.promote_types(t1, t2),
252262
dtypes.promote_types(t2, t1))
@@ -261,6 +271,8 @@ def testPromoteDtypesStandard(self):
261271
# TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8.
262272
if t in fp8_dtypes:
263273
continue
274+
if t in fp4_dtypes:
275+
continue
264276
if t in intn_dtypes or i in intn_dtypes:
265277
continue
266278
self.assertEqual(t, dtypes.promote_types(t, i))
@@ -951,10 +963,12 @@ def testUnaryPromotion(self, dtype, weak_type):
951963
self.skipTest("XLA support for int2 and int4 is incomplete.")
952964
if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
953965
self.skipTest("TPU does not support float8_e8m0fnu.")
966+
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
967+
self.skipTest("TPU does not support float4_e2m1fn.")
954968
x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
955969
if weak_type:
956970
expected = dtypes.canonicalize_dtype(
957-
dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes] else x.dtype.kind])
971+
dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes, *fp4_dtypes] else x.dtype.kind])
958972
else:
959973
expected = x.dtype
960974
self.assertEqual(dtypes.result_type(x), expected)
@@ -971,6 +985,18 @@ def testFloat8PromotionError(self):
971985
".*8-bit floats do not support implicit promotion"):
972986
x + y
973987

988+
@jax.numpy_dtype_promotion('standard')
989+
def testFloat4PromotionError(self):
990+
for dtype in fp4_dtypes:
991+
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
992+
# TPU does not support float4_e2m1fn.
993+
continue
994+
x = jnp.array(1, dtype=dtype)
995+
y = jnp.array(1, dtype='float32')
996+
with self.assertRaisesRegex(dtypes.TypePromotionError,
997+
".*4-bit floats do not support implicit promotion"):
998+
x + y
999+
9741000
@jax.numpy_dtype_promotion('standard')
9751001
@jtu.run_on_devices('tpu')
9761002
def testInt2PromotionError(self):
@@ -995,6 +1021,8 @@ def testInt2PromotionError(self):
9951021
def testBinaryNonPromotion(self, dtype, weak_type, promotion):
9961022
if dtype in fp8_dtypes:
9971023
self.skipTest("XLA support for float8 is incomplete.")
1024+
if dtype in fp4_dtypes:
1025+
self.skipTest("XLA support for float4 is incomplete.")
9981026
if dtype in intn_dtypes:
9991027
self.skipTest("XLA support for int2 and int4 is incomplete.")
10001028
# Regression test for https://github.com/jax-ml/jax/issues/6051
@@ -1027,6 +1055,8 @@ def testArrayRepr(self, dtype, weak_type):
10271055
self.skipTest('XLA support for int2 is incomplete.')
10281056
if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
10291057
self.skipTest('TPU does not support float8_e8m0fnu.')
1058+
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
1059+
self.skipTest('TPU does not support float4_e2m1fn.')
10301060
val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
10311061
rep = repr(val)
10321062
self.assertStartsWith(rep, 'Array(')

tests/export_test.py

+2
Original file line number | Diff line number | Diff line change
@@ -1014,6 +1014,8 @@ def test_poly_numeric_dtypes(self, dtype=np.int32):
10141014
self.skipTest(f"TODO: serialization not supported for {str(dtype)}")
10151015
if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
10161016
self.skipTest("TPU does not support float8_e8m0fnu.")
1017+
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
1018+
self.skipTest("TPU does not support float4_e2m1fn.")
10171019
@jax.jit
10181020
def f_jax(x):
10191021
return x + x

0 commit comments

Comments
 (0)