Skip to content

Commit

Permalink
Partially re-enable some tests for TF2.9.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 507798778
  • Loading branch information
arnoegw authored and tensorflower-gardener committed Feb 7, 2023
1 parent c356dea commit 033a54b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
1 change: 0 additions & 1 deletion tensorflow_gnn/runner/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ py_strict_test(
name = "attribution_test",
srcs = ["attribution_test.py"],
srcs_version = "PY3",
tags = ["tf_at_least_2_10"],
deps = [
":attribution",
"//:expect_tensorflow_installed",
Expand Down
20 changes: 20 additions & 0 deletions tensorflow_gnn/runner/utils/attribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class AttributionTest(tf.test.TestCase):
def test_counterfactual_random(self):
counterfactual = attribution.counterfactual(self.gt, random=True, seed=8191)

# TODO(b/266817638): Remove when fixed.
if tf.__version__.startswith("2.9."):
self.skipTest("Bad test: exepected values depend on TF2.10+ RNG seeds")

self.assertAllEqual(
counterfactual.context.features["h"],
as_tensor((0.49280962, 0.466383)))
Expand Down Expand Up @@ -95,6 +99,10 @@ def test_subtract_graph_features(self):
}
}))

# TODO(b/266817638): Remove when fixed.
if tf.__version__.startswith("2.9."):
self.skipTest("Bad test: exepected values depend on TF2.10+ RNG seeds")

self.assertAllClose(
deltas.context.features["h"],
as_tensor((.514 - .4, .433 - .8)))
Expand All @@ -116,6 +124,10 @@ def test_interpolate(self):

self.assertLen(interpolations, 4)

# TODO(b/266817638): Remove when fixed.
if tf.__version__.startswith("2.9."):
self.skipTest("Bad test: exepected values depend on TF2.10+ RNG seeds")

# Interpolation 0
self.assertAllEqual(
interpolations[0].context.features["h"],
Expand Down Expand Up @@ -179,6 +191,10 @@ def test_interpolate(self):
def test_sum_graph_features(self):
summation = attribution.sum_graph_features((self.gt,) * 4)

# TODO(b/266817638): Remove when fixed.
if tf.__version__.startswith("2.9."):
self.skipTest("Bad test: exepected values depend on TF2.10+ RNG seeds")

self.assertAllEqual(
summation.context.features["h"],
as_tensor((.514 * 4, .433 * 4)))
Expand Down Expand Up @@ -258,6 +274,10 @@ def test_integrated_gradients_exporter(self):
outputs = saved_model.signatures["integrated_gradients"](**kwargs)
gt = outputs["output"]

# TODO(b/266817638): Remove when fixed.
if tf.__version__.startswith("2.9."):
self.skipTest("Bad test: exepected values depend on TF2.10+ RNG seeds")

# The above GNN passes a single message over the only edge type before
# collecting a seed node for activations. The above graph is a line:
# seed --weight 0--> node 1 --weight 1--> node 2.
Expand Down

0 comments on commit 033a54b

Please sign in to comment.