Skip to content

Commit e9fb769

Browse files
davidsoergelbileschi
authored andcommitted
Catch RuntimeWarning when parsing GraphDef protos. (#3508)
Under certain mysterious circumstances, attempting to parse a malformed serialized `GraphDef` can result in the raising of a `RuntimeWarning` instead of a `DecodeError`. Here we catch both in order to treat them in the same way.
1 parent 339be0e commit e9fb769

File tree

2 files changed

+61
-3
lines changed

2 files changed

+61
-3
lines changed

tensorboard/dataclass_compat.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,11 @@ def _migrate_graph_event(old_event, experimental_filter_graph=False):
7878
if experimental_filter_graph:
7979
try:
8080
graph_def = graph_pb2.GraphDef().FromString(graph_bytes)
81-
except message.DecodeError:
81+
# The reason for the RuntimeWarning catch here is b/27494216, whereby
82+
# some proto parsers incorrectly raise that instead of DecodeError
83+
# on certain kinds of malformed input. Triggering this seems to require
84+
# a combination of mysterious circumstances.
85+
except (message.DecodeError, RuntimeWarning):
8286
logger.warning(
8387
"Could not parse GraphDef of size %d. Skipping.",
8488
len(graph_bytes),

tensorboard/dataclass_compat_test.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import numpy as np
2424
import tensorflow as tf
2525

26+
from google.protobuf import message
27+
2628
from tensorboard import dataclass_compat
2729
from tensorboard.backend.event_processing import event_file_loader
2830
from tensorboard.compat.proto import event_pb2
@@ -39,6 +41,12 @@
3941
from tensorboard.util import tensor_util
4042
from tensorboard.util import test_util
4143

44+
try:
45+
# python version >= 3.3
46+
from unittest import mock
47+
except ImportError:
48+
import mock # pylint: disable=unused-import
49+
4250

4351
class MigrateEventTest(tf.test.TestCase):
4452
"""Tests for `migrate_event`."""
@@ -254,14 +262,16 @@ def test_graph_def_experimental_filter_graph(self):
254262
self.assertProtoEquals(expected_graph_def, new_graph_def)
255263

256264
def test_graph_def_experimental_filter_graph_corrupt(self):
257-
# Simulate legacy graph event with an unparseable graph
265+
# Simulate legacy graph event with an unparseable graph.
266+
# We can't be sure whether this will produce `DecodeError` or
267+
# `RuntimeWarning`, so we also check both cases below.
258268
old_event = event_pb2.Event()
259269
old_event.step = 0
260270
old_event.wall_time = 456.75
261271
# Careful: some proto parsers choke on byte arrays filled with 0, but
262272
# others don't (silently producing an empty proto, I guess).
263273
# Thus `old_event.graph_def = bytes(1024)` is an unreliable example.
264-
old_event.graph_def = b"bogus"
274+
old_event.graph_def = b"<malformed>"
265275

266276
new_events = self._migrate_event(
267277
old_event, experimental_filter_graph=True
@@ -271,6 +281,50 @@ def test_graph_def_experimental_filter_graph_corrupt(self):
271281
self.assertLen(new_events, 1)
272282
self.assertProtoEquals(new_events[0], old_event)
273283

284+
def test_graph_def_experimental_filter_graph_DecodeError(self):
285+
# Simulate raising DecodeError when parsing a graph event
286+
old_event = event_pb2.Event()
287+
old_event.step = 0
288+
old_event.wall_time = 456.75
289+
old_event.graph_def = b"<malformed>"
290+
291+
with mock.patch(
292+
"tensorboard.compat.proto.graph_pb2.GraphDef"
293+
) as mockGraphDef:
294+
instance = mockGraphDef.return_value
295+
instance.FromString.side_effect = message.DecodeError
296+
297+
new_events = self._migrate_event(
298+
old_event, experimental_filter_graph=True
299+
)
300+
301+
# _migrate_event emits both the original event and the migrated event,
302+
# but here there is no migrated event becasue the graph was unparseable.
303+
self.assertLen(new_events, 1)
304+
self.assertProtoEquals(new_events[0], old_event)
305+
306+
def test_graph_def_experimental_filter_graph_RuntimeWarning(self):
307+
# Simulate raising RuntimeWarning when parsing a graph event
308+
old_event = event_pb2.Event()
309+
old_event.step = 0
310+
old_event.wall_time = 456.75
311+
old_event.graph_def = b"<malformed>"
312+
313+
with mock.patch(
314+
"tensorboard.compat.proto.graph_pb2.GraphDef"
315+
) as mockGraphDef:
316+
instance = mockGraphDef.return_value
317+
instance.FromString.side_effect = RuntimeWarning
318+
319+
new_events = self._migrate_event(
320+
old_event, experimental_filter_graph=True
321+
)
322+
323+
# _migrate_event emits both the original event and the migrated event,
324+
# but here there is no migrated event becasue the graph was unparseable.
325+
self.assertLen(new_events, 1)
326+
self.assertProtoEquals(new_events[0], old_event)
327+
274328

275329
if __name__ == "__main__":
276330
tf.test.main()

0 commit comments

Comments
 (0)