diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index a9b001a381..e0592854d6 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -78,7 +78,11 @@ def _migrate_graph_event(old_event, experimental_filter_graph=False): if experimental_filter_graph: try: graph_def = graph_pb2.GraphDef().FromString(graph_bytes) - except message.DecodeError: + # The reason for the RuntimeWarning catch here is b/27494216, whereby + # some proto parsers incorrectly raise that instead of DecodeError + # on certain kinds of malformed input. Triggering this seems to require + # a combination of mysterious circumstances. + except (message.DecodeError, RuntimeWarning): logger.warning( "Could not parse GraphDef of size %d. Skipping.", len(graph_bytes), diff --git a/tensorboard/dataclass_compat_test.py b/tensorboard/dataclass_compat_test.py index 5985287826..4105b92338 100644 --- a/tensorboard/dataclass_compat_test.py +++ b/tensorboard/dataclass_compat_test.py @@ -23,6 +23,8 @@ import numpy as np import tensorflow as tf +from google.protobuf import message + from tensorboard import dataclass_compat from tensorboard.backend.event_processing import event_file_loader from tensorboard.compat.proto import event_pb2 @@ -39,6 +41,12 @@ from tensorboard.util import tensor_util from tensorboard.util import test_util +try: + # python version >= 3.3 + from unittest import mock +except ImportError: + import mock # pylint: disable=unused-import + class MigrateEventTest(tf.test.TestCase): """Tests for `migrate_event`.""" @@ -254,14 +262,16 @@ def test_graph_def_experimental_filter_graph(self): self.assertProtoEquals(expected_graph_def, new_graph_def) def test_graph_def_experimental_filter_graph_corrupt(self): - # Simulate legacy graph event with an unparseable graph + # Simulate legacy graph event with an unparseable graph. + # We can't be sure whether this will produce `DecodeError` or + # `RuntimeWarning`, so we also check both cases below. old_event = event_pb2.Event() old_event.step = 0 old_event.wall_time = 456.75 # Careful: some proto parsers choke on byte arrays filled with 0, but # others don't (silently producing an empty proto, I guess). # Thus `old_event.graph_def = bytes(1024)` is an unreliable example. - old_event.graph_def = b"bogus" + old_event.graph_def = b"" new_events = self._migrate_event( old_event, experimental_filter_graph=True @@ -271,6 +281,50 @@ def test_graph_def_experimental_filter_graph_corrupt(self): self.assertLen(new_events, 1) self.assertProtoEquals(new_events[0], old_event) + def test_graph_def_experimental_filter_graph_DecodeError(self): + # Simulate raising DecodeError when parsing a graph event + old_event = event_pb2.Event() + old_event.step = 0 + old_event.wall_time = 456.75 + old_event.graph_def = b"" + + with mock.patch( + "tensorboard.compat.proto.graph_pb2.GraphDef" + ) as mockGraphDef: + instance = mockGraphDef.return_value + instance.FromString.side_effect = message.DecodeError + + new_events = self._migrate_event( + old_event, experimental_filter_graph=True + ) + + # _migrate_event emits both the original event and the migrated event, + # but here there is no migrated event becasue the graph was unparseable. + self.assertLen(new_events, 1) + self.assertProtoEquals(new_events[0], old_event) + + def test_graph_def_experimental_filter_graph_RuntimeWarning(self): + # Simulate raising RuntimeWarning when parsing a graph event + old_event = event_pb2.Event() + old_event.step = 0 + old_event.wall_time = 456.75 + old_event.graph_def = b"" + + with mock.patch( + "tensorboard.compat.proto.graph_pb2.GraphDef" + ) as mockGraphDef: + instance = mockGraphDef.return_value + instance.FromString.side_effect = RuntimeWarning + + new_events = self._migrate_event( + old_event, experimental_filter_graph=True + ) + + # _migrate_event emits both the original event and the migrated event, + # but here there is no migrated event becasue the graph was unparseable. + self.assertLen(new_events, 1) + self.assertProtoEquals(new_events[0], old_event) + if __name__ == "__main__": tf.test.main()