diff --git a/tensorboard/plugins/mesh/mesh_plugin.py b/tensorboard/plugins/mesh/mesh_plugin.py index 6c7020cca4..62f093597f 100644 --- a/tensorboard/plugins/mesh/mesh_plugin.py +++ b/tensorboard/plugins/mesh/mesh_plugin.py @@ -48,46 +48,27 @@ def __init__(self, context): """ # Retrieve the multiplexer from the context and store a reference to it. self._multiplexer = context.multiplexer - self._tag_to_instance_tags = collections.defaultdict(list) - self._instance_tag_to_tag = dict() - self._instance_tag_to_metadata = dict() - self.prepare_metadata() - - def prepare_metadata(self): - """Processes all tags and caches metadata for each.""" - if self._tag_to_instance_tags: - return - # This is a dictionary mapping from run to (tag to string content). - # To be clear, the values of the dictionary are dictionaries. - all_runs = self._multiplexer.PluginRunToTagToContent( - MeshPlugin.plugin_name - ) - - # tagToContent is itself a dictionary mapping tag name to string - # SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary - # to obtain a list of tags associated with each run. For each tag, estimate - # the number of samples. - self._tag_to_instance_tags = collections.defaultdict(list) - self._instance_tag_to_metadata = dict() - for run, tag_to_content in six.iteritems(all_runs): - for tag, content in six.iteritems(tag_to_content): - meta = metadata.parse_plugin_metadata(content) - self._instance_tag_to_metadata[(run, tag)] = meta - # Remember instance_name (instance_tag) for future reference. - self._tag_to_instance_tags[(run, meta.name)].append(tag) - self._instance_tag_to_tag[(run, tag)] = meta.name def _instance_tag_metadata(self, run, instance_tag): """Gets the `MeshPluginData` proto for an instance tag.""" - return self._instance_tag_to_metadata[(run, instance_tag)] + summary_metadata = self._multiplexer.SummaryMetadata(run, instance_tag) + content = summary_metadata.plugin_data.content + return metadata.parse_plugin_metadata(content) def _tag(self, run, instance_tag): """Gets the user-facing tag name for an instance tag.""" - return self._instance_tag_to_tag[(run, instance_tag)] + return self._instance_tag_metadata(run, instance_tag).name def _instance_tags(self, run, tag): """Gets the instance tag names for a user-facing tag.""" - return self._tag_to_instance_tags[(run, tag)] + index = self._multiplexer.GetAccumulator(run).PluginTagToContent( + metadata.PLUGIN_NAME + ) + return [ + instance_tag + for (instance_tag, content) in six.iteritems(index) + if tag == metadata.parse_plugin_metadata(content).name + ] @wrappers.Request.application def _serve_tags(self, request): @@ -107,9 +88,6 @@ def _serve_tags(self, request): MeshPlugin.plugin_name ) - # Make sure we populate tags mapping structures. - self.prepare_metadata() - # tagToContent is itself a dictionary mapping tag name to string # SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary # to obtain a list of tags associated with each run. For each tag estimate @@ -204,11 +182,6 @@ def _collect_tensor_events(self, request, step=None): run = request.args.get("run") tag = request.args.get("tag") - # TODO(b/128995556): investigate why this additional metadata mapping is - # necessary, it must have something todo with the lifecycle of the request. - # Make sure we populate tags mapping structures. - self.prepare_metadata() - tensor_events = [] # List of tuples (meta, tensor) that contain tag. for instance_tag in self._instance_tags(run, tag): tensors = self._multiplexer.Tensors(run, instance_tag) diff --git a/tensorboard/plugins/mesh/mesh_plugin_test.py b/tensorboard/plugins/mesh/mesh_plugin_test.py index 21452be7b9..88d3875c1e 100644 --- a/tensorboard/plugins/mesh/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh/mesh_plugin_test.py @@ -262,17 +262,6 @@ def testsEventsAlwaysSortedByStep(self): # belong to the same mesh. self.assertLessEqual(metadata[i - 1]["step"], metadata[i]["step"]) - @mock.patch.object( - event_multiplexer.EventMultiplexer, - "PluginRunToTagToContent", - return_value={"bar": {"foo": "".encode("utf-8")}}, - ) - def testMetadataComputedOnce(self, run_to_tag_mock): - """Tests that metadata mapping computed once.""" - self.plugin.prepare_metadata() - self.plugin.prepare_metadata() - self.assertEqual(1, run_to_tag_mock.call_count) - def testIsActive(self): self.assertTrue(self.plugin.is_active())