diff --git a/tensorboard/backend/event_processing/data_provider.py b/tensorboard/backend/event_processing/data_provider.py index 54c7f04038..ae3f6ba49f 100644 --- a/tensorboard/backend/event_processing/data_provider.py +++ b/tensorboard/backend/event_processing/data_provider.py @@ -109,76 +109,103 @@ def list_runs(self, experiment_id): def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None): self._validate_experiment_id(experiment_id) - run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name) - return self._list( - provider.ScalarTimeSeries, - run_tag_content, - run_tag_filter, - summary_pb2.DATA_CLASS_SCALAR, + index = self._index( + plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR ) + return self._list(provider.ScalarTimeSeries, index) def read_scalars( self, experiment_id, plugin_name, downsample=None, run_tag_filter=None ): + self._validate_experiment_id(experiment_id) self._validate_downsample(downsample) - index = self.list_scalars( - experiment_id, plugin_name, run_tag_filter=run_tag_filter + index = self._index( + plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR ) return self._read(_convert_scalar_event, index, downsample) def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None): self._validate_experiment_id(experiment_id) - run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name) - return self._list( - provider.TensorTimeSeries, - run_tag_content, - run_tag_filter, - summary_pb2.DATA_CLASS_TENSOR, + index = self._index( + plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR ) + return self._list(provider.TensorTimeSeries, index) def read_tensors( self, experiment_id, plugin_name, downsample=None, run_tag_filter=None ): + self._validate_experiment_id(experiment_id) self._validate_downsample(downsample) - index = self.list_tensors( - experiment_id, plugin_name, run_tag_filter=run_tag_filter + index = self._index( + plugin_name, run_tag_filter, 
summary_pb2.DATA_CLASS_TENSOR ) return self._read(_convert_tensor_event, index, downsample) - def _list( - self, - construct_time_series, - run_tag_content, - run_tag_filter, - data_class_filter, - ): - """Helper to list scalar or tensor time series. + def _index(self, plugin_name, run_tag_filter, data_class_filter): + """List time series and metadata matching the given filters. + + This is like `_list`, but doesn't traverse `Tensors(...)` to + compute metadata that's not always needed. Args: - construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`. - run_tag_content: Result of `_multiplexer.PluginRunToTagToContent(...)`. - run_tag_filter: As given by the client; may be `None`. - data_class_filter: A `summary_pb2.DataClass` value. Only time - series of this data class will be returned. + plugin_name: A string plugin name filter (required). + run_tag_filter: A `provider.RunTagFilter`, or `None`. + data_class_filter: A `summary_pb2.DataClass` filter (required). Returns: - A list of objects of type given by `construct_time_series`, - suitable to be returned from `list_scalars` or `list_tensors`. + A nested dict `d` such that `d[run][tag]` is a + `SummaryMetadata` proto. """ - result = {} if run_tag_filter is None: run_tag_filter = provider.RunTagFilter(runs=None, tags=None) - for (run, tag_to_content) in six.iteritems(run_tag_content): + runs = run_tag_filter.runs + tags = run_tag_filter.tags + + # Optimization for a common case, reading a single time series. 
+ if runs and len(runs) == 1 and tags and len(tags) == 1: + (run,) = runs + (tag,) = tags + try: + metadata = self._multiplexer.SummaryMetadata(run, tag) + except KeyError: + return {} + all_metadata = {run: {tag: metadata}} + else: + all_metadata = self._multiplexer.AllSummaryMetadata() + + result = {} + for (run, tag_to_metadata) in all_metadata.items(): + if runs is not None and run not in runs: + continue result_for_run = {} - for tag in tag_to_content: - if not self._test_run_tag(run_tag_filter, run, tag): + for (tag, metadata) in tag_to_metadata.items(): + if tags is not None and tag not in tags: continue - if ( - self._multiplexer.SummaryMetadata(run, tag).data_class - != data_class_filter - ): + if metadata.data_class != data_class_filter: + continue + if metadata.plugin_data.plugin_name != plugin_name: continue result[run] = result_for_run + result_for_run[tag] = metadata + + return result + + def _list(self, construct_time_series, index): + """Helper to list scalar or tensor time series. + + Args: + construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`. + index: The result of `self._index(...)`. + + Returns: + A list of objects of type given by `construct_time_series`, + suitable to be returned from `list_scalars` or `list_tensors`. + """ + result = {} + for (run, tag_to_metadata) in index.items(): + result_for_run = {} + result[run] = result_for_run + for (tag, summary_metadata) in tag_to_metadata.items(): max_step = None max_wall_time = None for event in self._multiplexer.Tensors(run, tag): @@ -202,7 +229,7 @@ def _read(self, convert_event, index, downsample): Args: convert_event: Takes `plugin_event_accumulator.TensorEvent` to either `provider.ScalarDatum` or `provider.TensorDatum`. - index: The result of `list_scalars` or `list_tensors`. + index: The result of `self._index(...)`. downsample: Non-negative `int`; how many samples to return per time series. 
@@ -224,23 +251,14 @@ def list_blob_sequences( self, experiment_id, plugin_name, run_tag_filter=None ): self._validate_experiment_id(experiment_id) - if run_tag_filter is None: - run_tag_filter = provider.RunTagFilter(runs=None, tags=None) - + index = self._index( + plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE + ) result = {} - run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name) - for (run, tag_to_content) in six.iteritems(run_tag_content): + for (run, tag_to_metadata) in index.items(): result_for_run = {} - for tag in tag_to_content: - if not self._test_run_tag(run_tag_filter, run, tag): - continue - summary_metadata = self._multiplexer.SummaryMetadata(run, tag) - if ( - summary_metadata.data_class - != summary_pb2.DATA_CLASS_BLOB_SEQUENCE - ): - continue - result[run] = result_for_run + result[run] = result_for_run + for (tag, metadata) in tag_to_metadata.items(): max_step = None max_wall_time = None max_length = None @@ -256,9 +274,9 @@ def list_blob_sequences( max_step=max_step, max_wall_time=max_wall_time, max_length=max_length, - plugin_content=summary_metadata.plugin_data.content, - description=summary_metadata.summary_description, - display_name=summary_metadata.display_name, + plugin_content=metadata.plugin_data.content, + description=metadata.summary_description, + display_name=metadata.display_name, ) return result @@ -267,14 +285,14 @@ def read_blob_sequences( ): self._validate_experiment_id(experiment_id) self._validate_downsample(downsample) - index = self.list_blob_sequences( - experiment_id, plugin_name, run_tag_filter=run_tag_filter + index = self._index( + plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE ) result = {} - for (run, tags_for_run) in six.iteritems(index): + for (run, tags) in six.iteritems(index): result_for_run = {} result[run] = result_for_run - for (tag, metadata) in six.iteritems(tags_for_run): + for tag in tags: events = self._multiplexer.Tensors(run, tag) data_by_step = 
{} for event in events: diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator.py b/tensorboard/backend/event_processing/plugin_event_accumulator.py index 54ed732e10..5455a1ddd1 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator.py @@ -283,6 +283,15 @@ def SummaryMetadata(self, tag): """ return self.summary_metadata[tag] + def AllSummaryMetadata(self): + """Return summary metadata for all tags. + + Returns: + A dict `d` such that `d[tag]` is a `SummaryMetadata` proto for + the keyed tag. + """ + return dict(self.summary_metadata) + def _ProcessEvent(self, event): """Called whenever an event is loaded.""" event = data_compat.migrate_event(event) diff --git a/tensorboard/backend/event_processing/plugin_event_multiplexer.py b/tensorboard/backend/event_processing/plugin_event_multiplexer.py index 1e2ab7bbc0..d2d52d9486 100644 --- a/tensorboard/backend/event_processing/plugin_event_multiplexer.py +++ b/tensorboard/backend/event_processing/plugin_event_multiplexer.py @@ -456,6 +456,21 @@ def SummaryMetadata(self, run, tag): accumulator = self.GetAccumulator(run) return accumulator.SummaryMetadata(tag) + def AllSummaryMetadata(self): + """Return summary metadata for all time series. + + Returns: + A nested dict `d` such that `d[run][tag]` is a + `SummaryMetadata` proto for the keyed time series. + """ + with self._accumulators_mutex: + # To avoid nested locks, we construct a copy of the run-accumulator map + items = list(six.iteritems(self._accumulators)) + return { + run_name: accumulator.AllSummaryMetadata() + for run_name, accumulator in items + } + def Runs(self): """Return all the run names in the `EventMultiplexer`.