Skip to content

Commit ccbe9f7

Browse files
committed
Fix lint
1 parent c099e80 commit ccbe9f7

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

jax/_src/dtypes.py

+6
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,12 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
730730
"promotion path. To avoid unintended promotion, 8-bit floats do not support "
731731
"implicit promotion. If you'd like your inputs to be promoted to another type, "
732732
"you can do so explicitly using e.g. x.astype('float32')")
733+
elif any(n in _float4_dtypes for n in nodes):
734+
msg = (
735+
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
736+
"promotion path. To avoid unintended promotion, 4-bit floats do not support "
737+
"implicit promotion. If you'd like your inputs to be promoted to another type, "
738+
"you can do so explicitly using e.g. x.astype('float32')")
733739
elif any(n in _intn_dtypes for n in nodes):
734740
msg = (
735741
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "

tests/dtypes_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,8 @@ def testFloat8PromotionError(self):
989989
def testFloat4PromotionError(self):
990990
for dtype in fp4_dtypes:
991991
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
992-
self.skipTest("TPU does not support float4_e2m1fn.")
992+
# TPU does not support float4_e2m1fn.
993+
continue
993994
x = jnp.array(1, dtype=dtype)
994995
y = jnp.array(1, dtype='float32')
995996
with self.assertRaisesRegex(dtypes.TypePromotionError,
@@ -1055,7 +1056,7 @@ def testArrayRepr(self, dtype, weak_type):
10551056
if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
10561057
self.skipTest('TPU does not support float8_e8m0fnu.')
10571058
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
1058-
self.skipTest('TPU does not support float4_e2m1fn.')
1059+
self.skipTest('TPU does not support float4_e2m1fn.')
10591060
val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
10601061
rep = repr(val)
10611062
self.assertStartsWith(rep, 'Array(')

0 commit comments

Comments
 (0)