diff --git a/tensorboard/plugins/hparams/backend_context.py b/tensorboard/plugins/hparams/backend_context.py index 630c50caf2..8cd93edf22 100644 --- a/tensorboard/plugins/hparams/backend_context.py +++ b/tensorboard/plugins/hparams/backend_context.py @@ -78,13 +78,51 @@ def experiment(self): return experiment @property - def multiplexer(self): + def _deprecated_multiplexer(self): return self._tb_context.multiplexer + @property + def multiplexer(self): + raise NotImplementedError("Do not read `Context.multiplexer` directly") + @property def tb_context(self): return self._tb_context + def hparams_metadata(self): + """Reads summary metadata for all hparams time series. + + Returns: + A dict `d` such that `d[run][tag]` is a `bytes` value with the + summary metadata content for the keyed time series. + """ + return self._deprecated_multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + + def scalars_metadata(self): + """Reads summary metadata for all scalar time series. + + Returns: + A dict `d` such that `d[run][tag]` is a `bytes` value with the + summary metadata content for the keyed time series. + """ + return self._deprecated_multiplexer.PluginRunToTagToContent( + scalar_metadata.PLUGIN_NAME + ) + + def read_scalars(self, run, tag): + """Reads values for a given scalar time series. + + Args: + run: String. + tag: String. + + Returns: + A list of `plugin_event_accumulator.TensorEvent` values. + """ + return self._deprecated_multiplexer.Tensors(run, tag) + def _find_experiment_tag(self): """Finds the experiment associcated with the metadata.EXPERIMENT_TAG tag. @@ -96,9 +134,7 @@ def _find_experiment_tag(self): """ with self._experiment_from_tag_lock: if self._experiment_from_tag is None: - mapping = self.multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME - ) + mapping = self.hparams_metadata() for tag_to_content in mapping.values(): if metadata.EXPERIMENT_TAG in tag_to_content: self._experiment_from_tag = metadata.parse_experiment_plugin_data( @@ -131,9 +167,7 @@ def _compute_hparam_infos(self): Returns: A list of api_pb2.HParamInfo messages. """ - run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME - ) + run_to_tag_to_content = self.hparams_metadata() # Construct a dict mapping an hparam name to its list of values. hparams = collections.defaultdict(list) for tag_to_content in run_to_tag_to_content.values(): @@ -236,9 +270,7 @@ def _compute_metric_names(self): """ session_runs = self._build_session_runs_set() metric_names_set = set() - run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent( - scalar_metadata.PLUGIN_NAME - ) + run_to_tag_to_content = self.scalars_metadata() for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): session = _find_longest_parent_path(session_runs, run) if not session: @@ -258,9 +290,7 @@ def _compute_metric_names(self): def _build_session_runs_set(self): result = set() - run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME - ) + run_to_tag_to_content = self.hparams_metadata() for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): if metadata.SESSION_START_INFO_TAG in tag_to_content: result.add(run) diff --git a/tensorboard/plugins/hparams/list_session_groups.py b/tensorboard/plugins/hparams/list_session_groups.py index b5a86281ef..9aca9d39e8 100644 --- a/tensorboard/plugins/hparams/list_session_groups.py +++ b/tensorboard/plugins/hparams/list_session_groups.py @@ -72,9 +72,7 @@ def _build_session_groups(self): # in the 'groups_by_name' dict. We create the SessionGroup object, if this # is the first session of that group we encounter. groups_by_name = {} - run_to_tag_to_content = self._context.multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME - ) + run_to_tag_to_content = self._context.hparams_metadata() for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): if metadata.SESSION_START_INFO_TAG not in tag_to_content: continue @@ -159,7 +157,7 @@ def _build_session_metric_values(self, session_name): metric_name = metric_info.name try: metric_eval = metrics.last_metric_eval( - self._context.multiplexer, session_name, metric_name + self._context, session_name, metric_name ) except KeyError: # It's ok if we don't find the metric in the session. diff --git a/tensorboard/plugins/hparams/metrics.py b/tensorboard/plugins/hparams/metrics.py index f8e7b88c12..be6bffb7e8 100644 --- a/tensorboard/plugins/hparams/metrics.py +++ b/tensorboard/plugins/hparams/metrics.py @@ -23,6 +23,7 @@ import six from tensorboard.plugins.hparams import api_pb2 +from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.util import tensor_util @@ -46,12 +47,11 @@ def run_tag_from_session_and_metric(session_name, metric_name): return run, tag -def last_metric_eval(multiplexer, session_name, metric_name): +def last_metric_eval(context, session_name, metric_name): """Returns the last evaluations of the given metric at the given session. Args: - multiplexer: The EventMultiplexer instance allowing access to - the exported summary data. + context: A `backend_context.Context` value. session_name: String. The session name for which to get the metric evaluations. metric_name: api_pb2.MetricName proto. The name of the metric to use. @@ -68,7 +68,7 @@ def last_metric_eval(multiplexer, session_name, metric_name): """ try: run, tag = run_tag_from_session_and_metric(session_name, metric_name) - tensor_events = multiplexer.Tensors(run=run, tag=tag) + tensor_events = context.read_scalars(run, tag) except KeyError as e: raise KeyError( "Can't find metric %s for session: %s. Underlying error message: %s"