Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 80 additions & 62 deletions tensorboard/backend/event_processing/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,76 +109,103 @@ def list_runs(self, experiment_id):

def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
self._validate_experiment_id(experiment_id)
run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
return self._list(
provider.ScalarTimeSeries,
run_tag_content,
run_tag_filter,
summary_pb2.DATA_CLASS_SCALAR,
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
)
return self._list(provider.ScalarTimeSeries, index)

def read_scalars(
self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
):
self._validate_experiment_id(experiment_id)
self._validate_downsample(downsample)
index = self.list_scalars(
experiment_id, plugin_name, run_tag_filter=run_tag_filter
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
)
return self._read(_convert_scalar_event, index, downsample)

def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
self._validate_experiment_id(experiment_id)
run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
return self._list(
provider.TensorTimeSeries,
run_tag_content,
run_tag_filter,
summary_pb2.DATA_CLASS_TENSOR,
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR
)
return self._list(provider.TensorTimeSeries, index)

def read_tensors(
self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
):
self._validate_experiment_id(experiment_id)
self._validate_downsample(downsample)
index = self.list_tensors(
experiment_id, plugin_name, run_tag_filter=run_tag_filter
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR
)
return self._read(_convert_tensor_event, index, downsample)

def _list(
self,
construct_time_series,
run_tag_content,
run_tag_filter,
data_class_filter,
):
"""Helper to list scalar or tensor time series.
def _index(self, plugin_name, run_tag_filter, data_class_filter):
"""List time series and metadata matching the given filters.

This is like `_list`, but doesn't traverse `Tensors(...)` to
compute metadata that's not always needed.

Args:
construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`.
run_tag_content: Result of `_multiplexer.PluginRunToTagToContent(...)`.
run_tag_filter: As given by the client; may be `None`.
data_class_filter: A `summary_pb2.DataClass` value. Only time
series of this data class will be returned.
plugin_name: A string plugin name filter (required).
run_tag_filter: An `provider.RunTagFilter`, or `None`.
data_class_filter: A `summary_pb2.DataClass` filter (required).

Returns:
A list of objects of type given by `construct_time_series`,
suitable to be returned from `list_scalars` or `list_tensors`.
A nested dict `d` such that `d[run][tag]` is a
`SummaryMetadata` proto.
"""
result = {}
if run_tag_filter is None:
run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
for (run, tag_to_content) in six.iteritems(run_tag_content):
runs = run_tag_filter.runs
tags = run_tag_filter.tags

# Optimization for a common case, reading a single time series.
if runs and len(runs) == 1 and tags and len(tags) == 1:
(run,) = runs
(tag,) = tags
try:
metadata = self._multiplexer.SummaryMetadata(run, tag)
except KeyError:
return {}
all_metadata = {run: {tag: metadata}}
else:
all_metadata = self._multiplexer.AllSummaryMetadata()

result = {}
for (run, tag_to_metadata) in all_metadata.items():
if runs is not None and run not in runs:
continue
result_for_run = {}
for tag in tag_to_content:
if not self._test_run_tag(run_tag_filter, run, tag):
for (tag, metadata) in tag_to_metadata.items():
if tags is not None and tag not in tags:
continue
if (
self._multiplexer.SummaryMetadata(run, tag).data_class
!= data_class_filter
):
if metadata.data_class != data_class_filter:
continue
if metadata.plugin_data.plugin_name != plugin_name:
continue
result[run] = result_for_run
result_for_run[tag] = metadata

return result

def _list(self, construct_time_series, index):
"""Helper to list scalar or tensor time series.

Args:
construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`.
index: The result of `self._index(...)`.

Returns:
A list of objects of type given by `construct_time_series`,
suitable to be returned from `list_scalars` or `list_tensors`.
"""
result = {}
for (run, tag_to_metadata) in index.items():
result_for_run = {}
result[run] = result_for_run
for (tag, summary_metadata) in tag_to_metadata.items():
max_step = None
max_wall_time = None
for event in self._multiplexer.Tensors(run, tag):
Expand All @@ -202,7 +229,7 @@ def _read(self, convert_event, index, downsample):
Args:
convert_event: Takes `plugin_event_accumulator.TensorEvent` to
either `provider.ScalarDatum` or `provider.TensorDatum`.
index: The result of `list_scalars` or `list_tensors`.
index: The result of `self._index(...)`.
downsample: Non-negative `int`; how many samples to return per
time series.

Expand All @@ -224,23 +251,14 @@ def list_blob_sequences(
self, experiment_id, plugin_name, run_tag_filter=None
):
self._validate_experiment_id(experiment_id)
if run_tag_filter is None:
run_tag_filter = provider.RunTagFilter(runs=None, tags=None)

index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
)
result = {}
run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
for (run, tag_to_content) in six.iteritems(run_tag_content):
for (run, tag_to_metadata) in index.items():
result_for_run = {}
for tag in tag_to_content:
if not self._test_run_tag(run_tag_filter, run, tag):
continue
summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
if (
summary_metadata.data_class
!= summary_pb2.DATA_CLASS_BLOB_SEQUENCE
):
continue
result[run] = result_for_run
result[run] = result_for_run
for (tag, metadata) in tag_to_metadata.items():
max_step = None
max_wall_time = None
max_length = None
Expand All @@ -256,9 +274,9 @@ def list_blob_sequences(
max_step=max_step,
max_wall_time=max_wall_time,
max_length=max_length,
plugin_content=summary_metadata.plugin_data.content,
description=summary_metadata.summary_description,
display_name=summary_metadata.display_name,
plugin_content=metadata.plugin_data.content,
description=metadata.summary_description,
display_name=metadata.display_name,
)
return result

Expand All @@ -267,14 +285,14 @@ def read_blob_sequences(
):
self._validate_experiment_id(experiment_id)
self._validate_downsample(downsample)
index = self.list_blob_sequences(
experiment_id, plugin_name, run_tag_filter=run_tag_filter
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
)
result = {}
for (run, tags_for_run) in six.iteritems(index):
for (run, tags) in six.iteritems(index):
result_for_run = {}
result[run] = result_for_run
for (tag, metadata) in six.iteritems(tags_for_run):
for tag in tags:
events = self._multiplexer.Tensors(run, tag)
data_by_step = {}
for event in events:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,15 @@ def SummaryMetadata(self, tag):
"""
return self.summary_metadata[tag]

def AllSummaryMetadata(self):
"""Return summary metadata for all tags.

Returns:
A dict `d` such that `d[tag]` is a `SummaryMetadata` proto for
the keyed tag.
"""
return dict(self.summary_metadata)

def _ProcessEvent(self, event):
"""Called whenever an event is loaded."""
event = data_compat.migrate_event(event)
Expand Down
15 changes: 15 additions & 0 deletions tensorboard/backend/event_processing/plugin_event_multiplexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,21 @@ def SummaryMetadata(self, run, tag):
accumulator = self.GetAccumulator(run)
return accumulator.SummaryMetadata(tag)

def AllSummaryMetadata(self):
"""Return summary metadata for all time series.

Returns:
A nested dict `d` such that `d[run][tag]` is a
`SummaryMetadata` proto for the keyed time series.
"""
with self._accumulators_mutex:
# To avoid nested locks, we construct a copy of the run-accumulator map
items = list(six.iteritems(self._accumulators))
return {
run_name: accumulator.AllSummaryMetadata()
for run_name, accumulator in items
}

def Runs(self):
"""Return all the run names in the `EventMultiplexer`.

Expand Down