Skip to content

Commit 9bc0439

Browse files
gneculajax authors
authored andcommitted
Disable flaky python callback test.
PiperOrigin-RevId: 575893965
1 parent 9b1a656 commit 9bc0439

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/jaxpr_effects_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,8 @@ def f(x):
605605
jax.effects_barrier()
606606
self.assertListEqual(log, [2., 3.])
607607

608+
# TODO(b/307211483): Investigate failure
609+
@jtu.skip_on_devices("tpu")
608610
def test_ordered_effect_remains_ordered_across_multiple_devices(self):
609611
if jax.device_count() < 2:
610612
raise unittest.SkipTest("Test requires >= 2 devices.")
@@ -632,8 +634,8 @@ def g(x):
632634
f(jnp.ones((500, 500)))
633635
g(3.)
634636
jax.effects_barrier()
635-
x_, y_ = float(jnp.log(1.25e8)), 3.
636-
expected_log = [x_, y_, x_, y_, x_, y_]
637+
f_, g_ = float(jnp.log(1.25e8)), 3.
638+
expected_log = [f_, g_, f_, g_, f_, g_]
637639
self.assertListEqual(log, expected_log)
638640

639641
def test_different_threads_get_different_tokens(self):

0 commit comments

Comments
 (0)