[jax2tf] Disable some jax2tf primitive tests until TF bug is fixed
PiperOrigin-RevId: 500102079
gnecula authored and jax authors committed Jan 6, 2023
1 parent b1d8c71 commit 7cfea0a
Showing 3 changed files with 27 additions and 4 deletions.
12 changes: 11 additions & 1 deletion jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -972,8 +972,17 @@ def rem(cls, harness: primitive_harness.Harness):
np.uint8, np.int8, np.uint16, np.uint32, np.uint64, np.int8,
np.int16, np.int32, np.int64
],
skip_comparison=True,
# Only the harnesses with "singularity" will have divide by 0
enabled=("singularity" in harness.name)),
Jax2TfLimitation(
"TF division of inf by inf returns inf while in JAX returns nan",
dtypes=[
np.float32,
],
devices="gpu",
skip_comparison=True,
enabled=("singularity_inf_by_inf" in harness.name)),
]

@classmethod
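
For context on the new limitation: the inf/inf case it tracks can be sketched in isolation. The snippet below is illustrative only (not part of the commit) and assumes numpy, jax, and tensorflow are importable. On CPU both libraries follow IEEE-754 and return nan; the limitation above covers float32 on GPU, where TF is reported to return inf.

# Illustrative sketch of the inf/inf divergence tracked by the limitation above.
import numpy as np
import jax.numpy as jnp
import tensorflow as tf

x = np.float32(np.inf)
print(jnp.divide(x, x))      # JAX: nan (IEEE-754)
print(tf.math.divide(x, x))  # TF: nan on CPU; per the limitation, inf for float32 on GPU
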
@@ -1264,7 +1273,8 @@ def custom_assert(tst, result_jax, result_tf, *, err_msg, **_):
for arr_jax, arr_tf in zip(result_jax, result_tf):
tst.assertArraysEqual(arr_jax, arr_tf, err_msg=err_msg)
else:
mask_jax, mask_tf = np.isnan(first_arr_jax), np.isnan(first_arr_tf)
mask_jax = np.isnan(first_arr_jax) | np.isinf(first_arr_jax)
mask_tf = np.isnan(first_arr_tf) | np.isinf(first_arr_tf)
tst.assertArraysEqual(
first_arr_jax[~mask_jax], first_arr_tf[~mask_tf], err_msg=err_msg)

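The custom_assert change above widens the mask from nan-only to nan-or-inf before comparing the leading result arrays. A minimal, self-contained sketch of that masking idea (the array values below are made up for illustration):

# Compare two arrays only at positions that are neither nan nor inf.
import numpy as np

arr_jax = np.array([1.0, np.nan, 3.0, np.inf], dtype=np.float32)
arr_tf = np.array([1.0, np.inf, 3.0, np.nan], dtype=np.float32)

mask_jax = np.isnan(arr_jax) | np.isinf(arr_jax)
mask_tf = np.isnan(arr_tf) | np.isinf(arr_tf)
# Both sides reduce to [1.0, 3.0], so the comparison passes.
np.testing.assert_array_equal(arr_jax[~mask_jax], arr_tf[~mask_tf])
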
17 changes: 15 additions & 2 deletions jax/experimental/jax2tf/tests/primitives_test.py
@@ -102,7 +102,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
@primitive_harness.parameterized(
primitive_harness.all_harnesses,
include_jax_unimpl=False,
#one_containing="scatter_modes_out_of_bounds_shape=float32[1,5]",
#one_containing="",
)
@jtu.ignore_warning(
category=UserWarning, message="Using reduced precision for gradient.*")
@@ -115,16 +115,29 @@ def test_prim(self, harness: primitive_harness.Harness):
args = harness.dyn_args_maker(self.rng())
enable_xla = harness.params.get("enable_xla", True)
if config.jax2tf_default_experimental_native_lowering and not enable_xla:
return
raise unittest.SkipTest("experimental_native_lowering not supported with enable_xla=False")
if ("gather_from_take_indices" in harness.fullname and
"fill" in harness.fullname and
not enable_xla and
device in ("tpu",)):
raise unittest.SkipTest("b/262580493")

if (not config.jax_array and
device == "cpu" and
"top_k_sort_inf_nan_inshape=float32[5]_k=5" in harness.fullname):
raise unittest.SkipTest("Unexplained failure, but in old no_jax_array")

associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
try:
with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
self.ConvertAndCompare(func_jax, *args, limitations=limitations,
enable_xla=enable_xla)
except Exception as e:
# TODO(b/264596006): custom calls are not registered properly with TF in OSS
if (config.jax2tf_default_experimental_native_lowering and
"does not work with custom calls" in str(e)):
logging.warning("Supressing error %s", e)
raise unittest.SkipTest("b/264596006: custom calls in native lowering fail in TF")
else:
raise e

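The try/except added to test_prim follows a common pattern: run the conversion, turn one specific known error into a skip, and re-raise everything else. A standalone sketch of that pattern, with run_conversion as a hypothetical stand-in for ConvertAndCompare (not the commit's code):

# Sketch: skip a test only when a specific known error message appears.
import unittest

def run_conversion():
  # Hypothetical stand-in that always hits the known failure mode.
  raise RuntimeError("op foo does not work with custom calls")

class ExampleTest(unittest.TestCase):
  def test_known_limitation(self):
    try:
      run_conversion()
    except Exception as e:
      if "does not work with custom calls" in str(e):
        raise unittest.SkipTest("b/264596006: known custom-call failure") from e
      raise

if __name__ == "__main__":
  unittest.main()
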
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/tf_test_util.py
@@ -244,7 +244,7 @@ def log_message(extra):
if expect_tf_error:
# It is more ergonomic to print all successful modes once
logging.warning(log_message(
f"Unexpected success with known limitations {expect_tf_error}"))
f"Unexpected execution success with known limitations {expect_tf_error}"))
unexpected_successes.append(f"{mode}: {expect_tf_error}")

if (jtu.device_under_test() == "gpu" and
