Skip to content
56 changes: 43 additions & 13 deletions tensorboard/plugins/hparams/backend_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,51 @@ def experiment(self):
return experiment

@property
def multiplexer(self):
def _deprecated_multiplexer(self):
return self._tb_context.multiplexer

@property
def multiplexer(self):
raise NotImplementedError("Do not read `Context.multiplexer` directly")

@property
def tb_context(self):
return self._tb_context

def hparams_metadata(self):
"""Reads summary metadata for all hparams time series.

Returns:
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
return self._deprecated_multiplexer.PluginRunToTagToContent(
metadata.PLUGIN_NAME
)

def scalars_metadata(self):
"""Reads summary metadata for all scalar time series.

Returns:
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
return self._deprecated_multiplexer.PluginRunToTagToContent(
scalar_metadata.PLUGIN_NAME
)

def read_scalars(self, run, tag):
"""Reads values for a given scalar time series.

Args:
run: String.
tag: String.

Returns:
A list of `plugin_event_accumulator.TensorEvent` values.
"""
return self._deprecated_multiplexer.Tensors(run, tag)

def _find_experiment_tag(self):
"""Finds the experiment associcated with the metadata.EXPERIMENT_TAG
tag.
Expand All @@ -96,9 +134,7 @@ def _find_experiment_tag(self):
"""
with self._experiment_from_tag_lock:
if self._experiment_from_tag is None:
mapping = self.multiplexer.PluginRunToTagToContent(
metadata.PLUGIN_NAME
)
mapping = self.hparams_metadata()
for tag_to_content in mapping.values():
if metadata.EXPERIMENT_TAG in tag_to_content:
self._experiment_from_tag = metadata.parse_experiment_plugin_data(
Expand Down Expand Up @@ -131,9 +167,7 @@ def _compute_hparam_infos(self):
Returns:
A list of api_pb2.HParamInfo messages.
"""
run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent(
metadata.PLUGIN_NAME
)
run_to_tag_to_content = self.hparams_metadata()
# Construct a dict mapping an hparam name to its list of values.
hparams = collections.defaultdict(list)
for tag_to_content in run_to_tag_to_content.values():
Expand Down Expand Up @@ -236,9 +270,7 @@ def _compute_metric_names(self):
"""
session_runs = self._build_session_runs_set()
metric_names_set = set()
run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent(
scalar_metadata.PLUGIN_NAME
)
run_to_tag_to_content = self.scalars_metadata()
for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
session = _find_longest_parent_path(session_runs, run)
if not session:
Expand All @@ -258,9 +290,7 @@ def _compute_metric_names(self):

def _build_session_runs_set(self):
result = set()
run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent(
metadata.PLUGIN_NAME
)
run_to_tag_to_content = self.hparams_metadata()
for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
if metadata.SESSION_START_INFO_TAG in tag_to_content:
result.add(run)
Expand Down
6 changes: 2 additions & 4 deletions tensorboard/plugins/hparams/list_session_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def _build_session_groups(self):
# in the 'groups_by_name' dict. We create the SessionGroup object, if this
# is the first session of that group we encounter.
groups_by_name = {}
run_to_tag_to_content = self._context.multiplexer.PluginRunToTagToContent(
metadata.PLUGIN_NAME
)
run_to_tag_to_content = self._context.hparams_metadata()
for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
if metadata.SESSION_START_INFO_TAG not in tag_to_content:
continue
Expand Down Expand Up @@ -159,7 +157,7 @@ def _build_session_metric_values(self, session_name):
metric_name = metric_info.name
try:
metric_eval = metrics.last_metric_eval(
self._context.multiplexer, session_name, metric_name
self._context, session_name, metric_name
)
except KeyError:
# It's ok if we don't find the metric in the session.
Expand Down
8 changes: 4 additions & 4 deletions tensorboard/plugins/hparams/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import six

from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.util import tensor_util


Expand All @@ -46,12 +47,11 @@ def run_tag_from_session_and_metric(session_name, metric_name):
return run, tag


def last_metric_eval(multiplexer, session_name, metric_name):
def last_metric_eval(context, session_name, metric_name):
"""Returns the last evaluations of the given metric at the given session.

Args:
multiplexer: The EventMultiplexer instance allowing access to
the exported summary data.
context: A `backend_context.Context` value.
session_name: String. The session name for which to get the metric
evaluations.
metric_name: api_pb2.MetricName proto. The name of the metric to use.
Expand All @@ -68,7 +68,7 @@ def last_metric_eval(multiplexer, session_name, metric_name):
"""
try:
run, tag = run_tag_from_session_and_metric(session_name, metric_name)
tensor_events = multiplexer.Tensors(run=run, tag=tag)
tensor_events = context.read_scalars(run, tag)
except KeyError as e:
raise KeyError(
"Can't find metric %s for session: %s. Underlying error message: %s"
Expand Down