Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 12 additions & 39 deletions tensorboard/plugins/mesh/mesh_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,46 +48,27 @@ def __init__(self, context):
"""
# Retrieve the multiplexer from the context and store a reference to it.
self._multiplexer = context.multiplexer
self._tag_to_instance_tags = collections.defaultdict(list)
self._instance_tag_to_tag = dict()
self._instance_tag_to_metadata = dict()
self.prepare_metadata()

def prepare_metadata(self):
"""Processes all tags and caches metadata for each."""
if self._tag_to_instance_tags:
return
# This is a dictionary mapping from run to (tag to string content).
# To be clear, the values of the dictionary are dictionaries.
all_runs = self._multiplexer.PluginRunToTagToContent(
MeshPlugin.plugin_name
)

# tagToContent is itself a dictionary mapping tag name to string
# SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary
# to obtain a list of tags associated with each run. For each tag, estimate
# the number of samples.
self._tag_to_instance_tags = collections.defaultdict(list)
self._instance_tag_to_metadata = dict()
for run, tag_to_content in six.iteritems(all_runs):
for tag, content in six.iteritems(tag_to_content):
meta = metadata.parse_plugin_metadata(content)
self._instance_tag_to_metadata[(run, tag)] = meta
# Remember instance_name (instance_tag) for future reference.
self._tag_to_instance_tags[(run, meta.name)].append(tag)
self._instance_tag_to_tag[(run, tag)] = meta.name

def _instance_tag_metadata(self, run, instance_tag):
"""Gets the `MeshPluginData` proto for an instance tag."""
return self._instance_tag_to_metadata[(run, instance_tag)]
summary_metadata = self._multiplexer.SummaryMetadata(run, instance_tag)
content = summary_metadata.plugin_data.content
return metadata.parse_plugin_metadata(content)

def _tag(self, run, instance_tag):
"""Gets the user-facing tag name for an instance tag."""
return self._instance_tag_to_tag[(run, instance_tag)]
return self._instance_tag_metadata(run, instance_tag).name

def _instance_tags(self, run, tag):
"""Gets the instance tag names for a user-facing tag."""
return self._tag_to_instance_tags[(run, tag)]
index = self._multiplexer.GetAccumulator(run).PluginTagToContent(
metadata.PLUGIN_NAME
)
return [
instance_tag
for (instance_tag, content) in six.iteritems(index)
if tag == metadata.parse_plugin_metadata(content).name
]

@wrappers.Request.application
def _serve_tags(self, request):
Expand All @@ -107,9 +88,6 @@ def _serve_tags(self, request):
MeshPlugin.plugin_name
)

# Make sure we populate tags mapping structures.
self.prepare_metadata()

# tagToContent is itself a dictionary mapping tag name to string
# SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary
# to obtain a list of tags associated with each run. For each tag estimate
Expand Down Expand Up @@ -204,11 +182,6 @@ def _collect_tensor_events(self, request, step=None):
run = request.args.get("run")
tag = request.args.get("tag")

# TODO(b/128995556): investigate why this additional metadata mapping is
# necessary, it must have something todo with the lifecycle of the request.
# Make sure we populate tags mapping structures.
self.prepare_metadata()

tensor_events = [] # List of tuples (meta, tensor) that contain tag.
for instance_tag in self._instance_tags(run, tag):
tensors = self._multiplexer.Tensors(run, instance_tag)
Expand Down
11 changes: 0 additions & 11 deletions tensorboard/plugins/mesh/mesh_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,17 +262,6 @@ def testsEventsAlwaysSortedByStep(self):
# belong to the same mesh.
self.assertLessEqual(metadata[i - 1]["step"], metadata[i]["step"])

@mock.patch.object(
event_multiplexer.EventMultiplexer,
"PluginRunToTagToContent",
return_value={"bar": {"foo": "".encode("utf-8")}},
)
def testMetadataComputedOnce(self, run_to_tag_mock):
"""Tests that metadata mapping computed once."""
self.plugin.prepare_metadata()
self.plugin.prepare_metadata()
self.assertEqual(1, run_to_tag_mock.call_count)

def testIsActive(self):
self.assertTrue(self.plugin.is_active())

Expand Down