Skip to content

Commit

Permalink
[jax2tf] Fix bfloat16 failures on CPU/GPU with latest tf-nightly. (#4060
Browse files Browse the repository at this point in the history
)
  • Loading branch information
bchetioui authored Aug 14, 2020
1 parent aea64e8 commit bd14f23
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion jax/experimental/jax2tf/tests/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,13 @@ def test_select_and_gather_add(self, harness: primitive_harness.Harness):
def test_unary_elementwise(self, harness: primitive_harness.Harness):
dtype = harness.params["dtype"]
lax_name = harness.params["lax_name"]
if (lax_name in ("acosh", "asinh", "atanh", "bessel_i0e", "bessel_i1e", "digamma",
"erf", "erf_inv", "erfc", "lgamma", "round", "rsqrt") and
dtype is dtypes.bfloat16 and
jtu.device_under_test() in ["cpu", "gpu"]):
raise unittest.SkipTest(f"bfloat16 support is missing from '{lax_name}' TF kernel on {jtu.device_under_test()} devices.")
# TODO(bchetioui): do they have bfloat16 support, though?
if lax_name in ("sinh", "cosh", "atanh", "asinh", "acosh") and dtype is np.float16:
if lax_name in ("sinh", "cosh", "atanh", "asinh", "acosh", "erf_inv") and dtype is np.float16:
raise unittest.SkipTest("b/158006398: float16 support is missing from '%s' TF kernel" % lax_name)
arg, = harness.dyn_args_maker(self.rng())
custom_assert = None
Expand Down

0 comments on commit bd14f23

Please sign in to comment.