diff --git a/tensorboard/plugins/hparams/backend_context.py b/tensorboard/plugins/hparams/backend_context.py index 8cd93edf22..5ba2db7e7b 100644 --- a/tensorboard/plugins/hparams/backend_context.py +++ b/tensorboard/plugins/hparams/backend_context.py @@ -53,11 +53,9 @@ def __init__(self, tb_context, max_domain_discrete_len=10): Typically, only tests should specify a value for this parameter. """ self._tb_context = tb_context - self._experiment_from_tag = None - self._experiment_from_tag_lock = threading.Lock() self._max_domain_discrete_len = max_domain_discrete_len - def experiment(self): + def experiment(self, experiment_id): """Returns the experiment protobuffer defining the experiment. This method first attempts to find a metadata.EXPERIMENT_TAG tag and @@ -72,9 +70,9 @@ def experiment(self): protobuffer can be built (possibly, because the event data has not been completely loaded yet), returns None. """ - experiment = self._find_experiment_tag() + experiment = self._find_experiment_tag(experiment_id) if experiment is None: - return self._compute_experiment_from_runs() + return self._compute_experiment_from_runs(experiment_id) return experiment @property @@ -89,72 +87,87 @@ def multiplexer(self): def tb_context(self): return self._tb_context - def hparams_metadata(self): + def hparams_metadata(self, experiment_id): """Reads summary metadata for all hparams time series. + Args: + experiment_id: String, from `plugin_util.experiment_id`. + Returns: A dict `d` such that `d[run][tag]` is a `bytes` value with the summary metadata content for the keyed time series. """ + assert isinstance(experiment_id, str), ( + experiment_id, + type(experiment_id), + ) return self._deprecated_multiplexer.PluginRunToTagToContent( metadata.PLUGIN_NAME ) - def scalars_metadata(self): + def scalars_metadata(self, experiment_id): """Reads summary metadata for all scalar time series. + Args: + experiment_id: String, from `plugin_util.experiment_id`. + Returns: A dict `d` such that `d[run][tag]` is a `bytes` value with the summary metadata content for the keyed time series. """ + assert isinstance(experiment_id, str), ( + experiment_id, + type(experiment_id), + ) return self._deprecated_multiplexer.PluginRunToTagToContent( scalar_metadata.PLUGIN_NAME ) - def read_scalars(self, run, tag): + def read_scalars(self, experiment_id, run, tag): """Reads values for a given scalar time series. Args: + experiment_id: String. run: String. tag: String. Returns: A list of `plugin_event_accumulator.TensorEvent` values. """ + assert isinstance(experiment_id, str), ( + experiment_id, + type(experiment_id), + ) return self._deprecated_multiplexer.Tensors(run, tag) - def _find_experiment_tag(self): + def _find_experiment_tag(self, experiment_id): """Finds the experiment associcated with the metadata.EXPERIMENT_TAG tag. - Caches the experiment if it was found. - Returns: The experiment or None if no such experiment is found. """ - with self._experiment_from_tag_lock: - if self._experiment_from_tag is None: - mapping = self.hparams_metadata() - for tag_to_content in mapping.values(): - if metadata.EXPERIMENT_TAG in tag_to_content: - self._experiment_from_tag = metadata.parse_experiment_plugin_data( - tag_to_content[metadata.EXPERIMENT_TAG] - ) - break - return self._experiment_from_tag - - def _compute_experiment_from_runs(self): + mapping = self.hparams_metadata(experiment_id) + for tag_to_content in mapping.values(): + if metadata.EXPERIMENT_TAG in tag_to_content: + experiment = metadata.parse_experiment_plugin_data( + tag_to_content[metadata.EXPERIMENT_TAG] + ) + return experiment + return None + + def _compute_experiment_from_runs(self, experiment_id): """Computes a minimal Experiment protocol buffer by scanning the runs.""" - hparam_infos = self._compute_hparam_infos() + hparam_infos = self._compute_hparam_infos(experiment_id) if not hparam_infos: return None - metric_infos = self._compute_metric_infos() + metric_infos = self._compute_metric_infos(experiment_id) return api_pb2.Experiment( hparam_infos=hparam_infos, metric_infos=metric_infos ) - def _compute_hparam_infos(self): + def _compute_hparam_infos(self, experiment_id): """Computes a list of api_pb2.HParamInfo from the current run, tag info. @@ -167,7 +180,7 @@ def _compute_hparam_infos(self): Returns: A list of api_pb2.HParamInfo messages. """ - run_to_tag_to_content = self.hparams_metadata() + run_to_tag_to_content = self.hparams_metadata(experiment_id) # Construct a dict mapping an hparam name to its list of values. hparams = collections.defaultdict(list) for tag_to_content in run_to_tag_to_content.values(): @@ -236,13 +249,13 @@ def _compute_hparam_info_from_values(self, name, values): return result - def _compute_metric_infos(self): + def _compute_metric_infos(self, experiment_id): return ( api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag)) - for tag, group in self._compute_metric_names() + for tag, group in self._compute_metric_names(experiment_id) ) - def _compute_metric_names(self): + def _compute_metric_names(self, experiment_id): """Computes the list of metric names from all the scalar (run, tag) pairs. @@ -268,9 +281,9 @@ def _compute_metric_names(self): A python list containing pairs. Each pair is a (tag, group) pair representing a metric name used in some session. """ - session_runs = self._build_session_runs_set() + session_runs = self._build_session_runs_set(experiment_id) metric_names_set = set() - run_to_tag_to_content = self.scalars_metadata() + run_to_tag_to_content = self.scalars_metadata(experiment_id) for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): session = _find_longest_parent_path(session_runs, run) if not session: @@ -288,9 +301,9 @@ def _compute_metric_names(self): metric_names_list.sort() return metric_names_list - def _build_session_runs_set(self): + def _build_session_runs_set(self, experiment_id): result = set() - run_to_tag_to_content = self.hparams_metadata() + run_to_tag_to_content = self.hparams_metadata(experiment_id) for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): if metadata.SESSION_START_INFO_TAG in tag_to_content: result.add(run) diff --git a/tensorboard/plugins/hparams/backend_context_test.py b/tensorboard/plugins/hparams/backend_context_test.py index b73c60d3af..9b963f689a 100644 --- a/tensorboard/plugins/hparams/backend_context_test.py +++ b/tensorboard/plugins/hparams/backend_context_test.py @@ -113,7 +113,7 @@ def test_experiment_with_experiment_tag(self): } } ctxt = backend_context.Context(self._mock_tb_context) - self.assertProtoEquals(experiment, ctxt.experiment()) + self.assertProtoEquals(experiment, ctxt.experiment(experiment_id="123")) def test_experiment_without_experiment_tag(self): self.session_1_start_info_ = """ @@ -168,7 +168,7 @@ def test_experiment_without_experiment_tag(self): } """ ctxt = backend_context.Context(self._mock_tb_context) - actual_exp = ctxt.experiment() + actual_exp = ctxt.experiment(experiment_id="123") _canonicalize_experiment(actual_exp) self.assertProtoEquals(expected_exp, actual_exp) @@ -230,7 +230,7 @@ def test_experiment_without_experiment_tag_different_hparam_types(self): } """ ctxt = backend_context.Context(self._mock_tb_context) - actual_exp = ctxt.experiment() + actual_exp = ctxt.experiment(experiment_id="123") _canonicalize_experiment(actual_exp) self.assertProtoEquals(expected_exp, actual_exp) @@ -285,7 +285,7 @@ def test_experiment_without_experiment_tag_many_distinct_values(self): ctxt = backend_context.Context( self._mock_tb_context, max_domain_discrete_len=1 ) - actual_exp = ctxt.experiment() + actual_exp = ctxt.experiment(experiment_id="123") _canonicalize_experiment(actual_exp) self.assertProtoEquals(expected_exp, actual_exp) diff --git a/tensorboard/plugins/hparams/get_experiment.py b/tensorboard/plugins/hparams/get_experiment.py index 91dbf29f52..c9abcf4ebe 100644 --- a/tensorboard/plugins/hparams/get_experiment.py +++ b/tensorboard/plugins/hparams/get_experiment.py @@ -25,13 +25,15 @@ class Handler(object): """Handles a GetExperiment request.""" - def __init__(self, context): + def __init__(self, context, experiment_id): """Constructor. Args: context: A backend_context.Context instance. + experiment_id: A string, as from `plugin_util.experiment_id`. """ self._context = context + self._experiment_id = experiment_id def run(self): """Handles the request specified on construction. @@ -39,7 +41,7 @@ def run(self): Returns: An Experiment object. """ - experiment = self._context.experiment() + experiment = self._context.experiment(self._experiment_id) if experiment is None: raise error.HParamsError( "Can't find an HParams-plugin experiment data in" diff --git a/tensorboard/plugins/hparams/hparams_plugin.py b/tensorboard/plugins/hparams/hparams_plugin.py index 769d03c3c5..94ba7cb5c3 100644 --- a/tensorboard/plugins/hparams/hparams_plugin.py +++ b/tensorboard/plugins/hparams/hparams_plugin.py @@ -82,6 +82,7 @@ def frontend_metadata(self): # ---- /download_data- ------------------------------------------------------- @wrappers.Request.application def download_data_route(self, request): + experiment_id = plugin_util.experiment_id(request.environ) try: response_format = request.args.get("format") columns_visibility = json.loads( @@ -91,9 +92,11 @@ def download_data_route(self, request): request, api_pb2.ListSessionGroupsRequest ) session_groups = list_session_groups.Handler( - self._context, request_proto + self._context, experiment_id, request_proto + ).run() + experiment = get_experiment.Handler( + self._context, experiment_id ).run() - experiment = get_experiment.Handler(self._context).run() body, mime_type = download_data.Handler( self._context, experiment, @@ -109,6 +112,7 @@ def download_data_route(self, request): # ---- /experiment ----------------------------------------------------------- @wrappers.Request.application def get_experiment_route(self, request): + experiment_id = plugin_util.experiment_id(request.environ) try: # This backend currently ignores the request parameters, but (for a POST) # we must advance the input stream to skip them -- otherwise the next HTTP @@ -117,7 +121,7 @@ def get_experiment_route(self, request): return http_util.Respond( request, json_format.MessageToJson( - get_experiment.Handler(self._context).run(), + get_experiment.Handler(self._context, experiment_id).run(), including_default_value_fields=True, ), "application/json", @@ -129,6 +133,7 @@ def get_experiment_route(self, request): # ---- /session_groups ------------------------------------------------------- @wrappers.Request.application def list_session_groups_route(self, request): + experiment_id = plugin_util.experiment_id(request.environ) try: request_proto = _parse_request_argument( request, api_pb2.ListSessionGroupsRequest @@ -137,7 +142,7 @@ def list_session_groups_route(self, request): request, json_format.MessageToJson( list_session_groups.Handler( - self._context, request_proto + self._context, experiment_id, request_proto ).run(), including_default_value_fields=True, ), @@ -150,7 +155,7 @@ def list_session_groups_route(self, request): # ---- /metric_evals --------------------------------------------------------- @wrappers.Request.application def list_metric_evals_route(self, request): - experiment = plugin_util.experiment_id(request.environ) + experiment_id = plugin_util.experiment_id(request.environ) try: request_proto = _parse_request_argument( request, api_pb2.ListMetricEvalsRequest @@ -162,7 +167,7 @@ def list_metric_evals_route(self, request): request, json.dumps( list_metric_evals.Handler( - request_proto, scalars_plugin, experiment + request_proto, scalars_plugin, experiment_id ).run() ), "application/json", diff --git a/tensorboard/plugins/hparams/list_session_groups.py b/tensorboard/plugins/hparams/list_session_groups.py index 9aca9d39e8..5d31fa6f52 100644 --- a/tensorboard/plugins/hparams/list_session_groups.py +++ b/tensorboard/plugins/hparams/list_session_groups.py @@ -35,20 +35,22 @@ class Handler(object): """Handles a ListSessionGroups request.""" - def __init__(self, context, request): + def __init__(self, context, experiment_id, request): """Constructor. Args: context: A backend_context.Context instance. + experiment_id: A string, as from `plugin_util.experiment_id`. request: A ListSessionGroupsRequest protobuf. """ self._context = context + self._experiment_id = experiment_id self._request = request self._extractors = _create_extractors(request.col_params) self._filters = _create_filters(request.col_params, self._extractors) # Since an context.experiment() call may search through all the runs, we # cache it here. - self._experiment = context.experiment() + self._experiment = context.experiment(experiment_id) def run(self): """Handles the request specified on construction. @@ -72,7 +74,9 @@ def _build_session_groups(self): # in the 'groups_by_name' dict. We create the SessionGroup object, if this # is the first session of that group we encounter. groups_by_name = {} - run_to_tag_to_content = self._context.hparams_metadata() + run_to_tag_to_content = self._context.hparams_metadata( + self._experiment_id + ) for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): if metadata.SESSION_START_INFO_TAG not in tag_to_content: continue @@ -157,7 +161,10 @@ def _build_session_metric_values(self, session_name): metric_name = metric_info.name try: metric_eval = metrics.last_metric_eval( - self._context, session_name, metric_name + self._context, + self._experiment_id, + session_name, + metric_name, ) except KeyError: # It's ok if we don't find the metric in the session. diff --git a/tensorboard/plugins/hparams/list_session_groups_test.py b/tensorboard/plugins/hparams/list_session_groups_test.py index 8c58e5bde6..8395c96448 100644 --- a/tensorboard/plugins/hparams/list_session_groups_test.py +++ b/tensorboard/plugins/hparams/list_session_groups_test.py @@ -1132,7 +1132,9 @@ def _run_handler(self, request): request_proto = api_pb2.ListSessionGroupsRequest() text_format.Merge(request, request_proto) handler = list_session_groups.Handler( - backend_context.Context(self._mock_tb_context), request_proto + context=backend_context.Context(self._mock_tb_context), + experiment_id="123", + request=request_proto, ) response = handler.run() # Sort the metric values repeated field in each session group to diff --git a/tensorboard/plugins/hparams/metrics.py b/tensorboard/plugins/hparams/metrics.py index be6bffb7e8..83212e676a 100644 --- a/tensorboard/plugins/hparams/metrics.py +++ b/tensorboard/plugins/hparams/metrics.py @@ -47,11 +47,12 @@ def run_tag_from_session_and_metric(session_name, metric_name): return run, tag -def last_metric_eval(context, session_name, metric_name): +def last_metric_eval(context, experiment_id, session_name, metric_name): """Returns the last evaluations of the given metric at the given session. Args: context: A `backend_context.Context` value. + experiment_id: String, as from `plugin_util.experiment_id`. session_name: String. The session name for which to get the metric evaluations. metric_name: api_pb2.MetricName proto. The name of the metric to use. @@ -68,7 +69,7 @@ def last_metric_eval(context, session_name, metric_name): """ try: run, tag = run_tag_from_session_and_metric(session_name, metric_name) - tensor_events = context.read_scalars(run, tag) + tensor_events = context.read_scalars(experiment_id, run, tag) except KeyError as e: raise KeyError( "Can't find metric %s for session: %s. Underlying error message: %s"