Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tensorboard/dataclass_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ def _migrate_graph_event(old_event, experimental_filter_graph=False):
if experimental_filter_graph:
try:
graph_def = graph_pb2.GraphDef().FromString(graph_bytes)
except message.DecodeError:
# The reason for the RuntimeWarning catch here is b/27494216, whereby
# some proto parsers incorrectly raise that instead of DecodeError
# on certain kinds of malformed input. Triggering this seems to require
# a combination of mysterious circumstances.
except (message.DecodeError, RuntimeWarning):
logger.warning(
"Could not parse GraphDef of size %d. Skipping.",
len(graph_bytes),
Expand Down
58 changes: 56 additions & 2 deletions tensorboard/dataclass_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import numpy as np
import tensorflow as tf

from google.protobuf import message

from tensorboard import dataclass_compat
from tensorboard.backend.event_processing import event_file_loader
from tensorboard.compat.proto import event_pb2
Expand All @@ -39,6 +41,12 @@
from tensorboard.util import tensor_util
from tensorboard.util import test_util

try:
# python version >= 3.3
from unittest import mock
except ImportError:
import mock # pylint: disable=unused-import


class MigrateEventTest(tf.test.TestCase):
"""Tests for `migrate_event`."""
Expand Down Expand Up @@ -254,14 +262,16 @@ def test_graph_def_experimental_filter_graph(self):
self.assertProtoEquals(expected_graph_def, new_graph_def)

def test_graph_def_experimental_filter_graph_corrupt(self):
# Simulate legacy graph event with an unparseable graph
# Simulate legacy graph event with an unparseable graph.
# We can't be sure whether this will produce `DecodeError` or
# `RuntimeWarning`, so we also check both cases below.
old_event = event_pb2.Event()
old_event.step = 0
old_event.wall_time = 456.75
# Careful: some proto parsers choke on byte arrays filled with 0, but
# others don't (silently producing an empty proto, I guess).
# Thus `old_event.graph_def = bytes(1024)` is an unreliable example.
old_event.graph_def = b"bogus"
old_event.graph_def = b"<malformed>"

new_events = self._migrate_event(
old_event, experimental_filter_graph=True
Expand All @@ -271,6 +281,50 @@ def test_graph_def_experimental_filter_graph_corrupt(self):
self.assertLen(new_events, 1)
self.assertProtoEquals(new_events[0], old_event)

def test_graph_def_experimental_filter_graph_DecodeError(self):
# Simulate raising DecodeError when parsing a graph event
old_event = event_pb2.Event()
old_event.step = 0
old_event.wall_time = 456.75
old_event.graph_def = b"<malformed>"

with mock.patch(
"tensorboard.compat.proto.graph_pb2.GraphDef"
) as mockGraphDef:
instance = mockGraphDef.return_value
instance.FromString.side_effect = message.DecodeError

new_events = self._migrate_event(
old_event, experimental_filter_graph=True
)

# _migrate_event emits both the original event and the migrated event,
# but here there is no migrated event becasue the graph was unparseable.
self.assertLen(new_events, 1)
self.assertProtoEquals(new_events[0], old_event)

def test_graph_def_experimental_filter_graph_RuntimeWarning(self):
# Simulate raising RuntimeWarning when parsing a graph event
old_event = event_pb2.Event()
old_event.step = 0
old_event.wall_time = 456.75
old_event.graph_def = b"<malformed>"

with mock.patch(
"tensorboard.compat.proto.graph_pb2.GraphDef"
) as mockGraphDef:
instance = mockGraphDef.return_value
instance.FromString.side_effect = RuntimeWarning

new_events = self._migrate_event(
old_event, experimental_filter_graph=True
)

# _migrate_event emits both the original event and the migrated event,
# but here there is no migrated event becasue the graph was unparseable.
self.assertLen(new_events, 1)
self.assertProtoEquals(new_events[0], old_event)


if __name__ == "__main__":
tf.test.main()