Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 47 additions & 34 deletions tensorboard/plugins/hparams/backend_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,9 @@ def __init__(self, tb_context, max_domain_discrete_len=10):
Typically, only tests should specify a value for this parameter.
"""
self._tb_context = tb_context
self._experiment_from_tag = None
self._experiment_from_tag_lock = threading.Lock()
self._max_domain_discrete_len = max_domain_discrete_len

def experiment(self):
def experiment(self, experiment_id):
"""Returns the experiment protobuffer defining the experiment.

This method first attempts to find a metadata.EXPERIMENT_TAG tag and
Expand All @@ -72,9 +70,9 @@ def experiment(self):
protobuffer can be built (possibly, because the event data has not been
completely loaded yet), returns None.
"""
experiment = self._find_experiment_tag()
experiment = self._find_experiment_tag(experiment_id)
if experiment is None:
return self._compute_experiment_from_runs()
return self._compute_experiment_from_runs(experiment_id)
return experiment

@property
Expand All @@ -89,72 +87,87 @@ def multiplexer(self):
def tb_context(self):
return self._tb_context

def hparams_metadata(self):
def hparams_metadata(self, experiment_id):
"""Reads summary metadata for all hparams time series.

Args:
experiment_id: String, from `plugin_util.experiment_id`.

Returns:
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
assert isinstance(experiment_id, str), (
experiment_id,
type(experiment_id),
)
return self._deprecated_multiplexer.PluginRunToTagToContent(
metadata.PLUGIN_NAME
)

def scalars_metadata(self):
def scalars_metadata(self, experiment_id):
"""Reads summary metadata for all scalar time series.

Args:
experiment_id: String, from `plugin_util.experiment_id`.

Returns:
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
assert isinstance(experiment_id, str), (
experiment_id,
type(experiment_id),
)
return self._deprecated_multiplexer.PluginRunToTagToContent(
scalar_metadata.PLUGIN_NAME
)

def read_scalars(self, run, tag):
def read_scalars(self, experiment_id, run, tag):
"""Reads values for a given scalar time series.

Args:
experiment_id: String.
run: String.
tag: String.

Returns:
A list of `plugin_event_accumulator.TensorEvent` values.
"""
assert isinstance(experiment_id, str), (
experiment_id,
type(experiment_id),
)
return self._deprecated_multiplexer.Tensors(run, tag)

def _find_experiment_tag(self):
def _find_experiment_tag(self, experiment_id):
"""Finds the experiment associcated with the metadata.EXPERIMENT_TAG
tag.

Caches the experiment if it was found.

Returns:
The experiment or None if no such experiment is found.
"""
with self._experiment_from_tag_lock:
if self._experiment_from_tag is None:
mapping = self.hparams_metadata()
for tag_to_content in mapping.values():
if metadata.EXPERIMENT_TAG in tag_to_content:
self._experiment_from_tag = metadata.parse_experiment_plugin_data(
tag_to_content[metadata.EXPERIMENT_TAG]
)
break
return self._experiment_from_tag

def _compute_experiment_from_runs(self):
mapping = self.hparams_metadata(experiment_id)
for tag_to_content in mapping.values():
if metadata.EXPERIMENT_TAG in tag_to_content:
experiment = metadata.parse_experiment_plugin_data(
tag_to_content[metadata.EXPERIMENT_TAG]
)
return experiment
return None

def _compute_experiment_from_runs(self, experiment_id):
"""Computes a minimal Experiment protocol buffer by scanning the
runs."""
hparam_infos = self._compute_hparam_infos()
hparam_infos = self._compute_hparam_infos(experiment_id)
if not hparam_infos:
return None
metric_infos = self._compute_metric_infos()
metric_infos = self._compute_metric_infos(experiment_id)
return api_pb2.Experiment(
hparam_infos=hparam_infos, metric_infos=metric_infos
)

def _compute_hparam_infos(self):
def _compute_hparam_infos(self, experiment_id):
"""Computes a list of api_pb2.HParamInfo from the current run, tag
info.

Expand All @@ -167,7 +180,7 @@ def _compute_hparam_infos(self):
Returns:
A list of api_pb2.HParamInfo messages.
"""
run_to_tag_to_content = self.hparams_metadata()
run_to_tag_to_content = self.hparams_metadata(experiment_id)
# Construct a dict mapping an hparam name to its list of values.
hparams = collections.defaultdict(list)
for tag_to_content in run_to_tag_to_content.values():
Expand Down Expand Up @@ -236,13 +249,13 @@ def _compute_hparam_info_from_values(self, name, values):

return result

def _compute_metric_infos(self):
def _compute_metric_infos(self, experiment_id):
return (
api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag))
for tag, group in self._compute_metric_names()
for tag, group in self._compute_metric_names(experiment_id)
)

def _compute_metric_names(self):
def _compute_metric_names(self, experiment_id):
"""Computes the list of metric names from all the scalar (run, tag)
pairs.

Expand All @@ -268,9 +281,9 @@ def _compute_metric_names(self):
A python list containing pairs. Each pair is a (tag, group) pair
representing a metric name used in some session.
"""
session_runs = self._build_session_runs_set()
session_runs = self._build_session_runs_set(experiment_id)
metric_names_set = set()
run_to_tag_to_content = self.scalars_metadata()
run_to_tag_to_content = self.scalars_metadata(experiment_id)
for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
session = _find_longest_parent_path(session_runs, run)
if not session:
Expand All @@ -288,9 +301,9 @@ def _compute_metric_names(self):
metric_names_list.sort()
return metric_names_list

def _build_session_runs_set(self):
def _build_session_runs_set(self, experiment_id):
result = set()
run_to_tag_to_content = self.hparams_metadata()
run_to_tag_to_content = self.hparams_metadata(experiment_id)
for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
if metadata.SESSION_START_INFO_TAG in tag_to_content:
result.add(run)
Expand Down
8 changes: 4 additions & 4 deletions tensorboard/plugins/hparams/backend_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_experiment_with_experiment_tag(self):
}
}
ctxt = backend_context.Context(self._mock_tb_context)
self.assertProtoEquals(experiment, ctxt.experiment())
self.assertProtoEquals(experiment, ctxt.experiment(experiment_id="123"))

def test_experiment_without_experiment_tag(self):
self.session_1_start_info_ = """
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_experiment_without_experiment_tag(self):
}
"""
ctxt = backend_context.Context(self._mock_tb_context)
actual_exp = ctxt.experiment()
actual_exp = ctxt.experiment(experiment_id="123")
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

Expand Down Expand Up @@ -230,7 +230,7 @@ def test_experiment_without_experiment_tag_different_hparam_types(self):
}
"""
ctxt = backend_context.Context(self._mock_tb_context)
actual_exp = ctxt.experiment()
actual_exp = ctxt.experiment(experiment_id="123")
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

Expand Down Expand Up @@ -285,7 +285,7 @@ def test_experiment_without_experiment_tag_many_distinct_values(self):
ctxt = backend_context.Context(
self._mock_tb_context, max_domain_discrete_len=1
)
actual_exp = ctxt.experiment()
actual_exp = ctxt.experiment(experiment_id="123")
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

Expand Down
6 changes: 4 additions & 2 deletions tensorboard/plugins/hparams/get_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,23 @@
class Handler(object):
"""Handles a GetExperiment request."""

def __init__(self, context):
def __init__(self, context, experiment_id):
"""Constructor.

Args:
context: A backend_context.Context instance.
experiment_id: A string, as from `plugin_util.experiment_id`.
"""
self._context = context
self._experiment_id = experiment_id

def run(self):
"""Handles the request specified on construction.

Returns:
An Experiment object.
"""
experiment = self._context.experiment()
experiment = self._context.experiment(self._experiment_id)
if experiment is None:
raise error.HParamsError(
"Can't find an HParams-plugin experiment data in"
Expand Down
17 changes: 11 additions & 6 deletions tensorboard/plugins/hparams/hparams_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def frontend_metadata(self):
# ---- /download_data- -------------------------------------------------------
@wrappers.Request.application
def download_data_route(self, request):
experiment_id = plugin_util.experiment_id(request.environ)
try:
response_format = request.args.get("format")
columns_visibility = json.loads(
Expand All @@ -91,9 +92,11 @@ def download_data_route(self, request):
request, api_pb2.ListSessionGroupsRequest
)
session_groups = list_session_groups.Handler(
self._context, request_proto
self._context, experiment_id, request_proto
).run()
experiment = get_experiment.Handler(
self._context, experiment_id
).run()
experiment = get_experiment.Handler(self._context).run()
body, mime_type = download_data.Handler(
self._context,
experiment,
Expand All @@ -109,6 +112,7 @@ def download_data_route(self, request):
# ---- /experiment -----------------------------------------------------------
@wrappers.Request.application
def get_experiment_route(self, request):
experiment_id = plugin_util.experiment_id(request.environ)
try:
# This backend currently ignores the request parameters, but (for a POST)
# we must advance the input stream to skip them -- otherwise the next HTTP
Expand All @@ -117,7 +121,7 @@ def get_experiment_route(self, request):
return http_util.Respond(
request,
json_format.MessageToJson(
get_experiment.Handler(self._context).run(),
get_experiment.Handler(self._context, experiment_id).run(),
including_default_value_fields=True,
),
"application/json",
Expand All @@ -129,6 +133,7 @@ def get_experiment_route(self, request):
# ---- /session_groups -------------------------------------------------------
@wrappers.Request.application
def list_session_groups_route(self, request):
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.ListSessionGroupsRequest
Expand All @@ -137,7 +142,7 @@ def list_session_groups_route(self, request):
request,
json_format.MessageToJson(
list_session_groups.Handler(
self._context, request_proto
self._context, experiment_id, request_proto
).run(),
including_default_value_fields=True,
),
Expand All @@ -150,7 +155,7 @@ def list_session_groups_route(self, request):
# ---- /metric_evals ---------------------------------------------------------
@wrappers.Request.application
def list_metric_evals_route(self, request):
experiment = plugin_util.experiment_id(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.ListMetricEvalsRequest
Expand All @@ -162,7 +167,7 @@ def list_metric_evals_route(self, request):
request,
json.dumps(
list_metric_evals.Handler(
request_proto, scalars_plugin, experiment
request_proto, scalars_plugin, experiment_id
).run()
),
"application/json",
Expand Down
15 changes: 11 additions & 4 deletions tensorboard/plugins/hparams/list_session_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,22 @@
class Handler(object):
"""Handles a ListSessionGroups request."""

def __init__(self, context, request):
def __init__(self, context, experiment_id, request):
"""Constructor.

Args:
context: A backend_context.Context instance.
experiment_id: A string, as from `plugin_util.experiment_id`.
request: A ListSessionGroupsRequest protobuf.
"""
self._context = context
self._experiment_id = experiment_id
self._request = request
self._extractors = _create_extractors(request.col_params)
self._filters = _create_filters(request.col_params, self._extractors)
# Since an context.experiment() call may search through all the runs, we
# cache it here.
self._experiment = context.experiment()
self._experiment = context.experiment(experiment_id)

def run(self):
"""Handles the request specified on construction.
Expand All @@ -72,7 +74,9 @@ def _build_session_groups(self):
# in the 'groups_by_name' dict. We create the SessionGroup object, if this
# is the first session of that group we encounter.
groups_by_name = {}
run_to_tag_to_content = self._context.hparams_metadata()
run_to_tag_to_content = self._context.hparams_metadata(
self._experiment_id
)
for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
if metadata.SESSION_START_INFO_TAG not in tag_to_content:
continue
Expand Down Expand Up @@ -157,7 +161,10 @@ def _build_session_metric_values(self, session_name):
metric_name = metric_info.name
try:
metric_eval = metrics.last_metric_eval(
self._context, session_name, metric_name
self._context,
self._experiment_id,
session_name,
metric_name,
)
except KeyError:
# It's ok if we don't find the metric in the session.
Expand Down
4 changes: 3 additions & 1 deletion tensorboard/plugins/hparams/list_session_groups_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,9 @@ def _run_handler(self, request):
request_proto = api_pb2.ListSessionGroupsRequest()
text_format.Merge(request, request_proto)
handler = list_session_groups.Handler(
backend_context.Context(self._mock_tb_context), request_proto
context=backend_context.Context(self._mock_tb_context),
experiment_id="123",
request=request_proto,
)
response = handler.run()
# Sort the metric values repeated field in each session group to
Expand Down
Loading