diff --git a/tensorboard/BUILD b/tensorboard/BUILD
index 9cde235f77..f0ed46decf 100644
--- a/tensorboard/BUILD
+++ b/tensorboard/BUILD
@@ -497,7 +497,6 @@ py_library(
     srcs = ["dataclass_compat.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorboard/backend:process_graph",
         "//tensorboard/compat/proto:protos_all_py_pb2",
         "//tensorboard/plugins/graph:metadata",
         "//tensorboard/plugins/histogram:metadata",
diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py
index e0592854d6..c21bd13daf 100644
--- a/tensorboard/dataclass_compat.py
+++ b/tensorboard/dataclass_compat.py
@@ -25,11 +25,7 @@ from __future__ import division
 from __future__ import print_function
 
-from google.protobuf import message
-from tensorboard.backend import process_graph
 from tensorboard.compat.proto import event_pb2
-from tensorboard.compat.proto import graph_pb2
 from tensorboard.compat.proto import summary_pb2
 from tensorboard.compat.proto import types_pb2
 from tensorboard.plugins.graph import metadata as graphs_metadata
@@ -39,60 +35,32 @@ from tensorboard.plugins.scalar import metadata as scalars_metadata
 from tensorboard.plugins.text import metadata as text_metadata
 from tensorboard.util import tensor_util
-from tensorboard.util import tb_logging
-
-logger = tb_logging.get_logger()
-
 
-def migrate_event(event, experimental_filter_graph=False):
+def migrate_event(event):
     """Migrate an event to a sequence of events.
 
     Args:
       event: An `event_pb2.Event`. The caller transfers ownership of the
        event to this method; the event may be mutated, and may or may not
        appear in the returned sequence.
-      experimental_filter_graph: When a graph event is encountered, process the
-        GraphDef to filter out attributes that are too large to be shown in the
-        graph UI.
 
     Returns:
       A sequence of `event_pb2.Event`s to use instead of `event`.
     """
     if event.HasField("graph_def"):
-        return _migrate_graph_event(
-            event, experimental_filter_graph=experimental_filter_graph
-        )
+        return _migrate_graph_event(event)
     if event.HasField("summary"):
         return _migrate_summary_event(event)
     return (event,)
 
 
-def _migrate_graph_event(old_event, experimental_filter_graph=False):
+def _migrate_graph_event(old_event):
     result = event_pb2.Event()
     result.wall_time = old_event.wall_time
     result.step = old_event.step
     value = result.summary.value.add(tag=graphs_metadata.RUN_GRAPH_NAME)
     graph_bytes = old_event.graph_def
-
-    # TODO(@davidsoergel): Move this stopgap to a more appropriate place.
-    if experimental_filter_graph:
-        try:
-            graph_def = graph_pb2.GraphDef().FromString(graph_bytes)
-        # The reason for the RuntimeWarning catch here is b/27494216, whereby
-        # some proto parsers incorrectly raise that instead of DecodeError
-        # on certain kinds of malformed input. Triggering this seems to require
-        # a combination of mysterious circumstances.
-        except (message.DecodeError, RuntimeWarning):
-            logger.warning(
-                "Could not parse GraphDef of size %d. Skipping.",
-                len(graph_bytes),
-            )
-            return (old_event,)
-        # Use the default filter parameters:
-        # limit_attr_size=1024, large_attrs_key="_too_large_attrs"
-        process_graph.prepare_graph_for_ui(graph_def)
-        graph_bytes = graph_def.SerializeToString()
-
     value.tensor.CopyFrom(tensor_util.make_tensor_proto([graph_bytes]))
     value.metadata.plugin_data.plugin_name = graphs_metadata.PLUGIN_NAME
     # `value.metadata.plugin_data.content` left as the empty proto
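For orientation (not part of the patch): after this change, `migrate_event`
wraps legacy graph bytes in a tensor-valued summary without touching them;
any filtering now happens on the uploader side. A minimal sketch of that
pass-through behavior, assuming a TensorBoard dev checkout where these
modules import:

    from tensorboard import dataclass_compat
    from tensorboard.compat.proto import event_pb2, graph_pb2
    from tensorboard.plugins.graph import metadata as graphs_metadata
    from tensorboard.util import tensor_util

    graph_def = graph_pb2.GraphDef()
    graph_def.node.add(name="alice", op="Person")
    old_event = event_pb2.Event(step=0, wall_time=456.75)
    old_event.graph_def = graph_def.SerializeToString()

    # A graph event migrates to a summary event whose tensor holds the
    # serialized GraphDef as a rank-1 string tensor, tagged for the graphs
    # plugin (per _migrate_graph_event above).
    new_events = dataclass_compat.migrate_event(old_event)
    graph_values = [
        v
        for e in new_events
        for v in e.summary.value
        if v.metadata.plugin_data.plugin_name == graphs_metadata.PLUGIN_NAME
    ]
    (graph_bytes,) = tensor_util.make_ndarray(graph_values[0].tensor)
    assert graph_bytes == graph_def.SerializeToString()  # unfiltered pass-through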
Skipping.", - len(graph_bytes), - ) - return (old_event,) - # Use the default filter parameters: - # limit_attr_size=1024, large_attrs_key="_too_large_attrs" - process_graph.prepare_graph_for_ui(graph_def) - graph_bytes = graph_def.SerializeToString() - value.tensor.CopyFrom(tensor_util.make_tensor_proto([graph_bytes])) value.metadata.plugin_data.plugin_name = graphs_metadata.PLUGIN_NAME # `value.metadata.plugin_data.content` left as the empty proto diff --git a/tensorboard/dataclass_compat_test.py b/tensorboard/dataclass_compat_test.py index 4105b92338..d8f6bf3842 100644 --- a/tensorboard/dataclass_compat_test.py +++ b/tensorboard/dataclass_compat_test.py @@ -51,13 +51,11 @@ class MigrateEventTest(tf.test.TestCase): """Tests for `migrate_event`.""" - def _migrate_event(self, old_event, experimental_filter_graph=False): + def _migrate_event(self, old_event): """Like `migrate_event`, but performs some sanity checks.""" old_event_copy = event_pb2.Event() old_event_copy.CopyFrom(old_event) - new_events = dataclass_compat.migrate_event( - old_event, experimental_filter_graph - ) + new_events = dataclass_compat.migrate_event(old_event) for event in new_events: # ensure that wall time and step are preserved self.assertEqual(event.wall_time, old_event.wall_time) self.assertEqual(event.step, old_event.step) @@ -223,108 +221,6 @@ def test_graph_def(self): self.assertProtoEquals(graph_def, new_graph_def) - def test_graph_def_experimental_filter_graph(self): - # Create a `GraphDef` - graph_def = graph_pb2.GraphDef() - graph_def.node.add(name="alice", op="Person") - graph_def.node.add(name="bob", op="Person") - - graph_def.node[1].attr["small"].s = b"small_attr_value" - graph_def.node[1].attr["large"].s = ( - b"large_attr_value" * 100 # 1600 bytes > 1024 limit - ) - graph_def.node.add( - name="friendship", op="Friendship", input=["alice", "bob"] - ) - - # Simulate legacy graph event - old_event = event_pb2.Event() - old_event.step = 0 - old_event.wall_time = 456.75 - old_event.graph_def = graph_def.SerializeToString() - - new_events = self._migrate_event( - old_event, experimental_filter_graph=True - ) - - new_event = new_events[1] - tensor = tensor_util.make_ndarray(new_event.summary.value[0].tensor) - new_graph_def_bytes = tensor[0] - new_graph_def = graph_pb2.GraphDef.FromString(new_graph_def_bytes) - - expected_graph_def = graph_pb2.GraphDef() - expected_graph_def.CopyFrom(graph_def) - del expected_graph_def.node[1].attr["large"] - expected_graph_def.node[1].attr["_too_large_attrs"].list.s.append( - b"large" - ) - - self.assertProtoEquals(expected_graph_def, new_graph_def) - - def test_graph_def_experimental_filter_graph_corrupt(self): - # Simulate legacy graph event with an unparseable graph. - # We can't be sure whether this will produce `DecodeError` or - # `RuntimeWarning`, so we also check both cases below. - old_event = event_pb2.Event() - old_event.step = 0 - old_event.wall_time = 456.75 - # Careful: some proto parsers choke on byte arrays filled with 0, but - # others don't (silently producing an empty proto, I guess). - # Thus `old_event.graph_def = bytes(1024)` is an unreliable example. - old_event.graph_def = b"" - - new_events = self._migrate_event( - old_event, experimental_filter_graph=True - ) - # _migrate_event emits both the original event and the migrated event, - # but here there is no migrated event becasue the graph was unparseable. 
diff --git a/tensorboard/uploader/BUILD b/tensorboard/uploader/BUILD
index c18fd1c87f..b87c30aa82 100644
--- a/tensorboard/uploader/BUILD
+++ b/tensorboard/uploader/BUILD
@@ -99,6 +99,7 @@ py_library(
         "//tensorboard:data_compat",
         "//tensorboard:dataclass_compat",
         "//tensorboard:expect_grpc_installed",
+        "//tensorboard/backend:process_graph",
         "//tensorboard/backend/event_processing:directory_loader",
         "//tensorboard/backend/event_processing:event_file_loader",
         "//tensorboard/backend/event_processing:io_wrapper",
@@ -109,6 +110,7 @@ py_library(
         "//tensorboard/util:grpc_util",
         "//tensorboard/util:tb_logging",
         "//tensorboard/util:tensor_util",
+        "@com_google_protobuf//:protobuf_python",
         "@org_pythonhosted_six",
     ],
 )
@@ -125,6 +127,7 @@ py_test(
         "//tensorboard:expect_grpc_testing_installed",
         "//tensorboard:expect_tensorflow_installed",
         "//tensorboard/compat/proto:protos_all_py_pb2",
+        "//tensorboard/plugins/graph:metadata",
         "//tensorboard/plugins/histogram:summary_v2",
         "//tensorboard/plugins/scalar:metadata",
         "//tensorboard/plugins/scalar:summary_v2",
@@ -132,6 +135,7 @@ py_test(
         "//tensorboard/uploader/proto:protos_all_py_pb2",
         "//tensorboard/uploader/proto:protos_all_py_pb2_grpc",
         "//tensorboard/util:test_util",
+        "@com_google_protobuf//:protobuf_python",
         "@org_pythonhosted_mock",
     ],
 )
diff --git a/tensorboard/uploader/uploader.py b/tensorboard/uploader/uploader.py
index e5168aef23..f413a1cf38 100644
--- a/tensorboard/uploader/uploader.py
+++ b/tensorboard/uploader/uploader.py
@@ -25,16 +25,21 @@
 import grpc
 import six
 
+from google.protobuf import message
+from tensorboard.compat.proto import graph_pb2
 from tensorboard.compat.proto import summary_pb2
+from tensorboard.compat.proto import types_pb2
 from tensorboard.uploader.proto import write_service_pb2
 from tensorboard.uploader.proto import experiment_pb2
 from tensorboard.uploader import logdir_loader
 from tensorboard.uploader import util
 from tensorboard import data_compat
 from tensorboard import dataclass_compat
+from tensorboard.backend import process_graph
 from tensorboard.backend.event_processing import directory_loader
 from tensorboard.backend.event_processing import event_file_loader
 from tensorboard.backend.event_processing import io_wrapper
+from tensorboard.plugins.graph import metadata as graphs_metadata
 from tensorboard.plugins.scalar import metadata as scalar_metadata
 from tensorboard.util import grpc_util
 from tensorboard.util import tb_logging
@@ -425,12 +430,11 @@ def _run_values(self, run_to_events):
         for (run_name, events) in six.iteritems(run_to_events):
             for event in events:
                 v2_event = data_compat.migrate_event(event)
-                dataclass_events = dataclass_compat.migrate_event(
-                    v2_event, experimental_filter_graph=True
-                )
-                for dataclass_event in dataclass_events:
-                    if dataclass_event.summary:
-                        for value in dataclass_event.summary.value:
+                events = dataclass_compat.migrate_event(v2_event)
+                events = _filter_graph_defs(events)
+                for event in events:
+                    if event.summary:
+                        for value in event.summary.value:
                             yield (run_name, event, value)
@@ -833,3 +837,41 @@ def _varint_cost(n):
         result += 1
         n >>= 7
     return result
+
+
+def _filter_graph_defs(events):
+    for e in events:
+        for v in e.summary.value:
+            if (
+                v.metadata.plugin_data.plugin_name
+                != graphs_metadata.PLUGIN_NAME
+            ):
+                continue
+            if v.tag == graphs_metadata.RUN_GRAPH_NAME:
+                data = list(v.tensor.string_val)
+                filtered_data = [_filtered_graph_bytes(x) for x in data]
+                filtered_data = [x for x in filtered_data if x is not None]
+                if filtered_data != data:
+                    new_tensor = tensor_util.make_tensor_proto(
+                        filtered_data, dtype=types_pb2.DT_STRING
+                    )
+                    v.tensor.CopyFrom(new_tensor)
+        yield e
+
+
+def _filtered_graph_bytes(graph_bytes):
+    try:
+        graph_def = graph_pb2.GraphDef().FromString(graph_bytes)
+    # The reason for the RuntimeWarning catch here is b/27494216, whereby
+    # some proto parsers incorrectly raise that instead of DecodeError
+    # on certain kinds of malformed input. Triggering this seems to require
+    # a combination of mysterious circumstances.
+    except (message.DecodeError, RuntimeWarning):
+        logger.warning(
+            "Could not parse GraphDef of size %d. Skipping.", len(graph_bytes),
+        )
+        return None
+    # Use the default filter parameters:
+    # limit_attr_size=1024, large_attrs_key="_too_large_attrs"
+    process_graph.prepare_graph_for_ui(graph_def)
+    return graph_def.SerializeToString()
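To make the relocated filter concrete (not part of the patch):
`_filtered_graph_bytes` defers to `process_graph.prepare_graph_for_ui`, whose
defaults, per the comment above, are `limit_attr_size=1024` and
`large_attrs_key="_too_large_attrs"`. A sketch of the observable effect,
mirroring the expectations in the tests, assuming a TensorBoard dev checkout:

    from tensorboard.backend import process_graph
    from tensorboard.compat.proto import graph_pb2

    graph_def = graph_pb2.GraphDef()
    node = graph_def.node.add(name="n", op="Const")
    node.attr["small"].s = b"tiny"
    node.attr["large"].s = b"x" * 2048  # exceeds the 1024-byte default limit

    process_graph.prepare_graph_for_ui(graph_def)  # mutates the proto in place

    assert "large" not in node.attr  # oversized attr dropped...
    assert list(node.attr["_too_large_attrs"].list.s) == [b"large"]  # ...but recorded
    assert node.attr["small"].s == b"tiny"  # small attrs untouched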
diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py
index b14fbb6e65..32da8024bb 100644
--- a/tensorboard/uploader/uploader_test.py
+++ b/tensorboard/uploader/uploader_test.py
@@ -33,6 +33,7 @@
 import tensorflow as tf
 
+from google.protobuf import message
 from tensorboard.uploader.proto import experiment_pb2
 from tensorboard.uploader.proto import scalar_pb2
 from tensorboard.uploader.proto import write_service_pb2
@@ -359,6 +360,67 @@ def test_upload_skip_large_blob(self):
         self.assertEqual(0, mock_rate_limiter.tick.call_count)
         self.assertEqual(1, mock_blob_rate_limiter.tick.call_count)
 
+    def test_filter_graphs(self):
+        # Three graphs: one short, one long, one corrupt.
+        bytes_0 = _create_example_graph_bytes(123)
+        bytes_1 = _create_example_graph_bytes(9999)
+        # invalid (truncated) proto: length-delimited field 1 (0x0a) of
+        # length 0x7f specified, but only len("bogus") = 5 bytes given
+        bytes_2 = b"\x0a\x7fbogus"
+
+        logdir = self.get_temp_dir()
+        for (i, b) in enumerate([bytes_0, bytes_1, bytes_2]):
+            run_dir = os.path.join(logdir, "run_%04d" % i)
+            event = event_pb2.Event(step=0, wall_time=123 * i, graph_def=b)
+            with tb_test_util.FileWriter(run_dir) as writer:
+                writer.add_event(event)
+
+        limiter = mock.create_autospec(util.RateLimiter)
+        limiter.tick.side_effect = [None, AbortUploadError]
+        mock_client = _create_mock_client()
+        uploader = _create_uploader(
+            mock_client,
+            logdir,
+            logdir_poll_rate_limiter=limiter,
+            allowed_plugins=[
+                scalars_metadata.PLUGIN_NAME,
+                graphs_metadata.PLUGIN_NAME,
+            ],
+        )
+        uploader.create_experiment()
+
+        with self.assertRaises(AbortUploadError):
+            uploader.start_uploading()
+
+        actual_blobs = []
+        for call in mock_client.WriteBlob.call_args_list:
+            requests = call[0][0]
+            actual_blobs.append(b"".join(r.data for r in requests))
+
+        actual_graph_defs = []
+        for blob in actual_blobs:
+            try:
+                actual_graph_defs.append(graph_pb2.GraphDef.FromString(blob))
+            except message.DecodeError:
+                actual_graph_defs.append(None)
+
+        with self.subTest("graphs with small attr values should be unchanged"):
+            expected_graph_def_0 = graph_pb2.GraphDef.FromString(bytes_0)
+            self.assertEqual(actual_graph_defs[0], expected_graph_def_0)
+
+        with self.subTest("large attr values should be filtered out"):
+            expected_graph_def_1 = graph_pb2.GraphDef.FromString(bytes_1)
+            del expected_graph_def_1.node[1].attr["large"]
+            expected_graph_def_1.node[1].attr["_too_large_attrs"].list.s.append(
+                b"large"
+            )
+            requests = list(mock_client.WriteBlob.call_args[0][0])
+            self.assertEqual(actual_graph_defs[1], expected_graph_def_1)
+
+        with self.subTest("corrupt graphs should be skipped"):
+            self.assertLen(actual_blobs, 2)
+
     def test_upload_server_error(self):
         mock_client = _create_mock_client()
         mock_rate_limiter = mock.create_autospec(util.RateLimiter)
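A closing note on the corrupt fixture (not part of the patch):
b"\x0a\x7fbogus" is malformed by construction, per the wire-format reasoning
in the test comment above. A worked decoding of those bytes:

    from google.protobuf import message
    from tensorboard.compat.proto import graph_pb2

    bogus = b"\x0a\x7fbogus"
    tag, length = bogus[0], bogus[1]
    # Tag byte 0x0a encodes (field_number << 3) | wire_type: field 1,
    # wire type 2 (length-delimited).
    assert tag >> 3 == 1 and tag & 0x07 == 2
    # The length byte promises a 127-byte payload, but only 5 bytes follow,
    # so any conforming parser must reject the message as truncated.
    assert length == 0x7F and len(bogus) - 2 == 5

    try:
        graph_pb2.GraphDef.FromString(bogus)
    except message.DecodeError:
        pass  # rejected, as the uploader test expects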