73
73
float_dtypes += fp8_dtypes
74
74
custom_float_dtypes += fp8_dtypes
75
75
76
+ fp4_dtypes = []
77
+ if dtypes .float4_e2m1fn is not None :
78
+ fp4_dtypes += [np .dtype (dtypes .float4_e2m1fn )]
79
+ float_dtypes += fp4_dtypes
80
+ custom_float_dtypes += fp4_dtypes
81
+
76
82
complex_dtypes = [np .dtype ('complex64' ), np .dtype ('complex128' )]
77
83
78
84
@@ -238,6 +244,8 @@ def testPromoteDtypesStandard(self):
238
244
continue
239
245
if t1 in intn_dtypes :
240
246
continue
247
+ if t1 in fp4_dtypes :
248
+ continue
241
249
self .assertEqual (np .dtype (np .complex128 ),
242
250
dtypes .promote_types (t1 , np .complex128 ))
243
251
@@ -247,6 +255,8 @@ def testPromoteDtypesStandard(self):
247
255
continue
248
256
if t2 in intn_dtypes :
249
257
continue
258
+ if t2 in fp4_dtypes :
259
+ continue
250
260
# Symmetry
251
261
self .assertEqual (dtypes .promote_types (t1 , t2 ),
252
262
dtypes .promote_types (t2 , t1 ))
@@ -261,6 +271,8 @@ def testPromoteDtypesStandard(self):
261
271
# TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8.
262
272
if t in fp8_dtypes :
263
273
continue
274
+ if t in fp4_dtypes :
275
+ continue
264
276
if t in intn_dtypes or i in intn_dtypes :
265
277
continue
266
278
self .assertEqual (t , dtypes .promote_types (t , i ))
@@ -951,10 +963,12 @@ def testUnaryPromotion(self, dtype, weak_type):
951
963
self .skipTest ("XLA support for int2 and int4 is incomplete." )
952
964
if dtype == dtypes .float8_e8m0fnu and jtu .test_device_matches (['tpu' ]):
953
965
self .skipTest ("TPU does not support float8_e8m0fnu." )
966
+ if dtype == dtypes .float4_e2m1fn and jtu .test_device_matches (['tpu' ]):
967
+ self .skipTest ("TPU does not support float4_e2m1fn." )
954
968
x = lax_internal ._convert_element_type (0 , dtype , weak_type = weak_type )
955
969
if weak_type :
956
970
expected = dtypes .canonicalize_dtype (
957
- dtypes ._default_types ['f' if x .dtype in ["bfloat16" , * fp8_dtypes ] else x .dtype .kind ])
971
+ dtypes ._default_types ['f' if x .dtype in ["bfloat16" , * fp8_dtypes , * fp4_dtypes ] else x .dtype .kind ])
958
972
else :
959
973
expected = x .dtype
960
974
self .assertEqual (dtypes .result_type (x ), expected )
@@ -971,6 +985,18 @@ def testFloat8PromotionError(self):
971
985
".*8-bit floats do not support implicit promotion" ):
972
986
x + y
973
987
988
+ @jax .numpy_dtype_promotion ('standard' )
989
+ def testFloat4PromotionError (self ):
990
+ for dtype in fp4_dtypes :
991
+ if dtype == dtypes .float4_e2m1fn and jtu .test_device_matches (['tpu' ]):
992
+ # TPU does not support float4_e2m1fn.
993
+ continue
994
+ x = jnp .array (1 , dtype = dtype )
995
+ y = jnp .array (1 , dtype = 'float32' )
996
+ with self .assertRaisesRegex (dtypes .TypePromotionError ,
997
+ ".*4-bit floats do not support implicit promotion" ):
998
+ x + y
999
+
974
1000
@jax .numpy_dtype_promotion ('standard' )
975
1001
@jtu .run_on_devices ('tpu' )
976
1002
def testInt2PromotionError (self ):
@@ -995,6 +1021,8 @@ def testInt2PromotionError(self):
995
1021
def testBinaryNonPromotion (self , dtype , weak_type , promotion ):
996
1022
if dtype in fp8_dtypes :
997
1023
self .skipTest ("XLA support for float8 is incomplete." )
1024
+ if dtype in fp4_dtypes :
1025
+ self .skipTest ("XLA support for float4 is incomplete." )
998
1026
if dtype in intn_dtypes :
999
1027
self .skipTest ("XLA support for int2 and int4 is incomplete." )
1000
1028
# Regression test for https://github.com/jax-ml/jax/issues/6051
@@ -1027,6 +1055,8 @@ def testArrayRepr(self, dtype, weak_type):
1027
1055
self .skipTest ('XLA support for int2 is incomplete.' )
1028
1056
if dtype == dtypes .float8_e8m0fnu and jtu .test_device_matches (['tpu' ]):
1029
1057
self .skipTest ('TPU does not support float8_e8m0fnu.' )
1058
+ if dtype == dtypes .float4_e2m1fn and jtu .test_device_matches (['tpu' ]):
1059
+ self .skipTest ('TPU does not support float4_e2m1fn.' )
1030
1060
val = lax_internal ._convert_element_type (0 , dtype , weak_type = weak_type )
1031
1061
rep = repr (val )
1032
1062
self .assertStartsWith (rep , 'Array(' )
0 commit comments