diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index 3249b6f954..e853af599c 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -139,6 +139,12 @@ def _migrate_value(value, initial_metadata): return _migrate_hparams_value(value) if plugin_name == pr_curves_metadata.PLUGIN_NAME: return _migrate_pr_curve_value(value) + if plugin_name in [ + graphs_metadata.PLUGIN_NAME_RUN_METADATA, + graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH, + graphs_metadata.PLUGIN_NAME_KERAS_MODEL, + ]: + return _migrate_graph_sub_plugin_value(value) return (value,) @@ -191,3 +197,12 @@ def _migrate_pr_curve_value(value): if value.HasField("metadata"): value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR return (value,) + + +def _migrate_graph_sub_plugin_value(value): + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE + shape = value.tensor.tensor_shape.dim + if not shape: + shape.add(size=1) + return (value,) diff --git a/tensorboard/dataclass_compat_test.py b/tensorboard/dataclass_compat_test.py index d176899705..5d9eb87929 100644 --- a/tensorboard/dataclass_compat_test.py +++ b/tensorboard/dataclass_compat_test.py @@ -370,6 +370,43 @@ def test_run_metadata(self): graphs_metadata.PLUGIN_NAME_TAGGED_RUN_METADATA, ) + def test_graph_sub_plugins(self): + # Tests for `graph_run_metadata`, `graph_run_metadata_graph`, + # and `graph_keras_model` plugins. We fabricate these since it's + # not straightforward to get handles to them. + for plugin_name in [ + graphs_metadata.PLUGIN_NAME_RUN_METADATA, + graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH, + graphs_metadata.PLUGIN_NAME_KERAS_MODEL, + ]: + with self.subTest(plugin_name): + old_event = event_pb2.Event() + old_event.step = 123 + old_event.wall_time = 456.75 + old_value = old_event.summary.value.add() + old_value.metadata.plugin_data.plugin_name = plugin_name + old_value.metadata.plugin_data.content = b"1" + old_tensor = tensor_util.make_tensor_proto(b"2+2=4") + # input data are scalar tensors + self.assertEqual(tensor_util.make_ndarray(old_tensor).shape, ()) + old_value.tensor.CopyFrom(old_tensor) + + new_events = self._migrate_event(old_event) + self.assertLen(new_events, 1) + self.assertLen(new_events[0].summary.value, 1) + new_value = new_events[0].summary.value[0] + ndarray = tensor_util.make_ndarray(new_value.tensor) + self.assertEqual(ndarray.shape, (1,)) + self.assertEqual(ndarray.item(), b"2+2=4") + self.assertEqual( + new_value.metadata.data_class, + summary_pb2.DATA_CLASS_BLOB_SEQUENCE, + ) + self.assertEqual( + new_value.metadata.plugin_data.plugin_name, plugin_name + ) + self.assertEqual(new_value.metadata.plugin_data.content, b"1") + if __name__ == "__main__": tf.test.main() diff --git a/tensorboard/plugins/graph/BUILD b/tensorboard/plugins/graph/BUILD index 6197d6afbc..8c38414b6e 100644 --- a/tensorboard/plugins/graph/BUILD +++ b/tensorboard/plugins/graph/BUILD @@ -14,6 +14,7 @@ py_library( ":graph_util", ":keras_util", ":metadata", + "//tensorboard:errors", "//tensorboard:plugin_util", "//tensorboard/backend:http_util", "//tensorboard/backend:process_graph", @@ -21,7 +22,6 @@ py_library( "//tensorboard/plugins:base_plugin", "@com_google_protobuf//:protobuf_python", "@org_pocoo_werkzeug", - "@org_pythonhosted_six", ], ) diff --git a/tensorboard/plugins/graph/graphs_plugin.py b/tensorboard/plugins/graph/graphs_plugin.py index 4b27a55b9f..aa45e447ad 100644 --- a/tensorboard/plugins/graph/graphs_plugin.py +++ b/tensorboard/plugins/graph/graphs_plugin.py @@ -19,9 +19,9 @@ from __future__ import print_function import json -import six from werkzeug import wrappers +from tensorboard import errors from tensorboard import plugin_util from tensorboard.backend import http_util from tensorboard.backend import process_graph @@ -48,13 +48,7 @@ def __init__(self, context): Args: context: A base_plugin.TBContext instance. """ - self._multiplexer = context.multiplexer - if not self._multiplexer or ( - context.flags and context.flags.generic_data == "true" - ): - self._data_provider = context.data_provider - else: - self._data_provider = None + self._data_provider = context.data_provider def get_plugin_apps(self): return { @@ -112,31 +106,17 @@ def add_row_item(run, tag=None): ) return (run_item, tag_item) - if self._data_provider: - mapping = self._data_provider.list_blob_sequences( - ctx, - experiment_id=experiment, - plugin_name=metadata.PLUGIN_NAME, - ) - for (run_name, tag_to_time_series) in six.iteritems(mapping): - for tag in tag_to_time_series: - if tag == metadata.RUN_GRAPH_NAME: - (run_item, _) = add_row_item(run_name, None) - run_item["run_graph"] = True - else: - (_, tag_item) = add_row_item(run_name, tag) - tag_item["op_graph"] = True - return result - - mapping = self._multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH + mapping = self._data_provider.list_blob_sequences( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH, ) - for run_name, tag_to_content in six.iteritems(mapping): - for (tag, content) in six.iteritems(tag_to_content): + for (run_name, tags) in mapping.items(): + for (tag, tag_data) in tags.items(): # The Summary op is defined in TensorFlow and does not use a stringified proto # as a content of plugin data. It contains single string that denotes a version. # https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L789-L790 - if content != b"1": + if tag_data.plugin_content != b"1": logger.warning( "Ignoring unrecognizable version of RunMetadata." ) @@ -146,12 +126,14 @@ def add_row_item(run, tag=None): # Tensors associated with plugin name metadata.PLUGIN_NAME_RUN_METADATA # contain both op graph and profile information. - mapping = self._multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME_RUN_METADATA + mapping = self._data_provider.list_blob_sequences( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME_RUN_METADATA, ) - for run_name, tag_to_content in six.iteritems(mapping): - for (tag, content) in six.iteritems(tag_to_content): - if content != b"1": + for (run_name, tags) in mapping.items(): + for (tag, tag_data) in tags.items(): + if tag_data.plugin_content != b"1": logger.warning( "Ignoring unrecognizable version of RunMetadata." ) @@ -162,12 +144,14 @@ def add_row_item(run, tag=None): # Tensors associated with plugin name metadata.PLUGIN_NAME_KERAS_MODEL # contain serialized Keras model in JSON format. - mapping = self._multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME_KERAS_MODEL + mapping = self._data_provider.list_blob_sequences( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME_KERAS_MODEL, ) - for run_name, tag_to_content in six.iteritems(mapping): - for (tag, content) in six.iteritems(tag_to_content): - if content != b"1": + for (run_name, tags) in mapping.items(): + for (tag, tag_data) in tags.items(): + if tag_data.plugin_content != b"1": logger.warning( "Ignoring unrecognizable version of RunMetadata." ) @@ -175,8 +159,10 @@ def add_row_item(run, tag=None): (_, tag_item) = add_row_item(run_name, tag) tag_item["conceptual_graph"] = True - mapping = self._multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME + mapping = self._data_provider.list_blob_sequences( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, ) for (run_name, tags) in mapping.items(): if metadata.RUN_GRAPH_NAME in tags: @@ -184,8 +170,10 @@ def add_row_item(run, tag=None): run_item["run_graph"] = True # Top level `Event.tagged_run_metadata` represents profile data only. - mapping = self._multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME_TAGGED_RUN_METADATA + mapping = self._data_provider.list_blob_sequences( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME_TAGGED_RUN_METADATA, ) for (run_name, tags) in mapping.items(): for tag in tags: @@ -194,58 +182,68 @@ def add_row_item(run, tag=None): return result - def graph_impl( - self, - ctx, - run, - tag, - is_conceptual, - experiment=None, - limit_attr_size=None, - large_attrs_key=None, - ): - """Result of the form `(body, mime_type)`, or `None` if no graph - exists.""" - if self._data_provider: - if tag is None: - tag = metadata.RUN_GRAPH_NAME - graph_blob_sequences = self._data_provider.read_blob_sequences( + def _read_blob(self, ctx, experiment, plugin_names, run, tag): + for plugin_name in plugin_names: + blob_sequences = self._data_provider.read_blob_sequences( ctx, experiment_id=experiment, - plugin_name=metadata.PLUGIN_NAME, + plugin_name=plugin_name, run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), downsample=1, ) - blob_datum_list = graph_blob_sequences.get(run, {}).get(tag, ()) + blob_sequence_data = blob_sequences.get(run, {}).get(tag, ()) try: - blob_ref = blob_datum_list[0].values[0] + blob_ref = blob_sequence_data[0].values[0] except IndexError: - return None - # Always use the blob_key approach for now, even if there is a direct url. - graph_raw = self._data_provider.read_blob( + continue + return self._data_provider.read_blob( ctx, blob_key=blob_ref.blob_key ) - # This method ultimately returns pbtxt, but we have to deserialize and - # later reserialize this anyway, because a) this way we accept binary - # protobufs too, and b) below we run `prepare_graph_for_ui` on the graph. - graph = graph_pb2.GraphDef.FromString(graph_raw) + raise errors.NotFound() - elif is_conceptual: - tensor_events = self._multiplexer.Tensors(run, tag) - # Take the first event if there are multiple events written from different - # steps. + def graph_impl( + self, + ctx, + run, + tag, + is_conceptual, + experiment=None, + limit_attr_size=None, + large_attrs_key=None, + ): + """Result of the form `(body, mime_type)`; may raise `NotFound`.""" + if is_conceptual: keras_model_config = json.loads( - tensor_events[0].tensor_proto.string_val[0] + self._read_blob( + ctx, + experiment, + [metadata.PLUGIN_NAME_KERAS_MODEL], + run, + tag, + ) ) graph = keras_util.keras_model_to_graph_def(keras_model_config) - elif tag: - tensor_events = self._multiplexer.Tensors(run, tag) - # Take the first event if there are multiple events written from different - # steps. - run_metadata = config_pb2.RunMetadata.FromString( - tensor_events[0].tensor_proto.string_val[0] + elif tag is None: + graph_raw = self._read_blob( + ctx, + experiment, + [metadata.PLUGIN_NAME], + run, + metadata.RUN_GRAPH_NAME, ) + graph = graph_pb2.GraphDef.FromString(graph_raw) + + else: + # Op graph: could be either of two plugins. (Cf. `info_impl`.) + plugins = [ + metadata.PLUGIN_NAME_RUN_METADATA, + metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH, + ] + raw_run_metadata = self._read_blob( + ctx, experiment, plugins, run, tag + ) + run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata) graph = graph_util.merge_graph_defs( [ func_graph.pre_optimization_graph @@ -253,13 +251,6 @@ def graph_impl( ] ) - else: - tensor_events = self._multiplexer.Tensors( - run, metadata.RUN_GRAPH_NAME - ) - graph_raw = tensor_events[0].tensor_proto.string_val[0] - graph = graph_pb2.GraphDef.FromString(graph_raw) - # This next line might raise a ValueError if the limit parameters # are invalid (size is negative, size present but key absent, etc.). process_graph.prepare_graph_for_ui( @@ -267,20 +258,15 @@ def graph_impl( ) return (str(graph), "text/x-protobuf") # pbtxt - def run_metadata_impl(self, run, tag): - """Result of the form `(body, mime_type)`, or `None` if no data - exists.""" - if self._data_provider: - # TODO(davidsoergel, wchargin): Consider plumbing run metadata through data providers. - return None - tensor_events = self._multiplexer.Tensors(run, tag) - if tensor_events is None: - return None - # Take the first event if there are multiple events written from different - # steps. - run_metadata = config_pb2.RunMetadata.FromString( - tensor_events[0].tensor_proto.string_val[0] - ) + def run_metadata_impl(self, ctx, experiment, run, tag): + """Result of the form `(body, mime_type)`; may raise `NotFound`.""" + # Profile graph: could be either of two plugins. (Cf. `info_impl`.) + plugins = [ + metadata.PLUGIN_NAME_TAGGED_RUN_METADATA, + metadata.PLUGIN_NAME_RUN_METADATA, + ] + raw_run_metadata = self._read_blob(ctx, experiment, plugins, run, tag) + run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata) return (str(run_metadata), "text/x-protobuf") # pbtxt @wrappers.Request.application @@ -332,21 +318,14 @@ def graph_route(self, request): ) except ValueError as e: return http_util.Respond(request, e.message, "text/plain", code=400) - else: - if result is not None: - ( - body, - mime_type, - ) = result # pylint: disable=unpacking-non-sequence - return http_util.Respond(request, body, mime_type) - else: - return http_util.Respond( - request, "404 Not Found", "text/plain", code=404 - ) + (body, mime_type) = result + return http_util.Respond(request, body, mime_type) @wrappers.Request.application def run_metadata_route(self, request): """Given a tag and a run, return the session.run() metadata.""" + ctx = plugin_util.context(request.environ) + experiment = plugin_util.experiment_id(request.environ) tag = request.args.get("tag") run = request.args.get("run") if tag is None: @@ -357,11 +336,5 @@ def run_metadata_route(self, request): return http_util.Respond( request, 'query parameter "run" is required', "text/plain", 400 ) - result = self.run_metadata_impl(run, tag) - if result is not None: - (body, mime_type) = result # pylint: disable=unpacking-non-sequence - return http_util.Respond(request, body, mime_type) - else: - return http_util.Respond( - request, "404 Not Found", "text/plain", code=404 - ) + (body, mime_type) = self.run_metadata_impl(ctx, experiment, run, tag) + return http_util.Respond(request, body, mime_type) diff --git a/tensorboard/plugins/graph/graphs_plugin_test.py b/tensorboard/plugins/graph/graphs_plugin_test.py index a8dc86c4dd..264b82ea23 100644 --- a/tensorboard/plugins/graph/graphs_plugin_test.py +++ b/tensorboard/plugins/graph/graphs_plugin_test.py @@ -64,7 +64,7 @@ def with_runs(run_specs): - """Run a test with a bare multiplexer and with a `data_provider`. + """Run a test with a `data_provider`. The decorated function will receive an initialized `GraphsPlugin` object as its first positional argument. @@ -77,11 +77,6 @@ def decorator(fn): @functools.wraps(fn) def wrapper(self, *args, **kwargs): (logdir, multiplexer) = self.load_runs(run_specs) - with self.subTest("bare multiplexer"): - ctx = base_plugin.TBContext( - logdir=logdir, multiplexer=multiplexer - ) - fn(self, graphs_plugin.GraphsPlugin(ctx), *args, **kwargs) with self.subTest("generic data provider"): flags = argparse.Namespace(generic_data="true") provider = data_provider.MultiplexerDataProvider( @@ -237,13 +232,6 @@ def test_info(self, plugin): }, } - if plugin._data_provider: - # Hack, for now. - # Data providers don't yet pass RunMetadata, so this entry excludes it. - expected["_RUN_WITH_GRAPH_WITH_METADATA"]["tags"] = {} - # Data providers don't yet pass RunMetadata, so this entry is completely omitted. - del expected["_RUN_WITHOUT_GRAPH_WITH_METADATA"] - actual = plugin.info_impl(context.RequestContext(), "eid") self.assertEqual(expected, actual) @@ -295,17 +283,14 @@ def test_graph_large_attrs(self, plugin): @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) def test_run_metadata(self, plugin): + ctx = context.RequestContext() result = plugin.run_metadata_impl( - _RUN_WITH_GRAPH_WITH_METADATA[0], self._METADATA_TAG + ctx, "123", _RUN_WITH_GRAPH_WITH_METADATA[0], self._METADATA_TAG ) - if plugin._data_provider: - # Hack, for now - self.assertEqual(result, None) - else: - (metadata_pbtxt, mime_type) = result - self.assertEqual(mime_type, "text/x-protobuf") - text_format.Parse(metadata_pbtxt, config_pb2.RunMetadata()) - # If it parses, we're happy. + (metadata_pbtxt, mime_type) = result + self.assertEqual(mime_type, "text/x-protobuf") + text_format.Parse(metadata_pbtxt, config_pb2.RunMetadata()) + # If it parses, we're happy. @with_runs([_RUN_WITH_GRAPH_WITHOUT_METADATA]) def test_is_active(self, plugin):