diff --git a/tensorboard/plugins/hparams/BUILD b/tensorboard/plugins/hparams/BUILD index eac0b8bfed..ddbc80422b 100644 --- a/tensorboard/plugins/hparams/BUILD +++ b/tensorboard/plugins/hparams/BUILD @@ -42,6 +42,7 @@ py_library( ":protos_all_py_pb2", "//tensorboard:plugin_util", "//tensorboard/backend:http_util", + "//tensorboard/data:provider", "//tensorboard/plugins:base_plugin", "//tensorboard/plugins/scalar:metadata", "//tensorboard/plugins/scalar:scalars_plugin", @@ -72,14 +73,18 @@ py_test( py_test( name = "list_session_groups_test", size = "small", + timeout = "moderate", srcs = [ "list_session_groups_test.py", ], deps = [ ":hparams_plugin", "//tensorboard:expect_tensorflow_installed", + "//tensorboard/backend/event_processing:data_provider", "//tensorboard/backend/event_processing:event_accumulator", "//tensorboard/backend/event_processing:event_multiplexer", + "//tensorboard/compat/proto:protos_all_py_pb2", + "//tensorboard/plugins/scalar:metadata", "@org_pythonhosted_mock", ], ) @@ -106,8 +111,11 @@ py_test( deps = [ ":hparams_plugin", "//tensorboard:expect_tensorflow_installed", + "//tensorboard/backend/event_processing:data_provider", "//tensorboard/backend/event_processing:event_accumulator", "//tensorboard/backend/event_processing:event_multiplexer", + "//tensorboard/compat/proto:protos_all_py_pb2", + "//tensorboard/plugins/scalar:metadata", "@org_pythonhosted_mock", ], ) diff --git a/tensorboard/plugins/hparams/backend_context.py b/tensorboard/plugins/hparams/backend_context.py index 5ba2db7e7b..dc48589e65 100644 --- a/tensorboard/plugins/hparams/backend_context.py +++ b/tensorboard/plugins/hparams/backend_context.py @@ -25,10 +25,15 @@ import six +from tensorboard.backend.event_processing import ( + plugin_event_accumulator as event_accumulator, +) +from tensorboard.data import provider from tensorboard.plugins.hparams import api_pb2 from tensorboard.plugins.hparams import metadata from google.protobuf import json_format from tensorboard.plugins.scalar import metadata as scalar_metadata +from tensorboard.util import tensor_util class Context(object): @@ -75,34 +80,37 @@ def experiment(self, experiment_id): return self._compute_experiment_from_runs(experiment_id) return experiment - @property - def _deprecated_multiplexer(self): - return self._tb_context.multiplexer - - @property - def multiplexer(self): - raise NotImplementedError("Do not read `Context.multiplexer` directly") - @property def tb_context(self): return self._tb_context - def hparams_metadata(self, experiment_id): + def _convert_plugin_metadata(self, data_provider_output): + return { + run: { + tag: time_series.plugin_content + for (tag, time_series) in tag_to_time_series.items() + } + for (run, tag_to_time_series) in data_provider_output.items() + } + + def hparams_metadata(self, experiment_id, run_tag_filter=None): """Reads summary metadata for all hparams time series. Args: experiment_id: String, from `plugin_util.experiment_id`. + run_tag_filter: Optional `data.provider.RunTagFilter`, with + the semantics as in `list_tensors`. Returns: A dict `d` such that `d[run][tag]` is a `bytes` value with the summary metadata content for the keyed time series. """ - assert isinstance(experiment_id, str), ( - experiment_id, - type(experiment_id), - ) - return self._deprecated_multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME + return self._convert_plugin_metadata( + self._tb_context.data_provider.list_tensors( + experiment_id, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=run_tag_filter, + ) ) def scalars_metadata(self, experiment_id): @@ -115,12 +123,10 @@ def scalars_metadata(self, experiment_id): A dict `d` such that `d[run][tag]` is a `bytes` value with the summary metadata content for the keyed time series. """ - assert isinstance(experiment_id, str), ( - experiment_id, - type(experiment_id), - ) - return self._deprecated_multiplexer.PluginRunToTagToContent( - scalar_metadata.PLUGIN_NAME + return self._convert_plugin_metadata( + self._tb_context.data_provider.list_scalars( + experiment_id, plugin_name=scalar_metadata.PLUGIN_NAME + ) ) def read_scalars(self, experiment_id, run, tag): @@ -134,11 +140,27 @@ def read_scalars(self, experiment_id, run, tag): Returns: A list of `plugin_event_accumulator.TensorEvent` values. """ - assert isinstance(experiment_id, str), ( + data_provider_output = self._tb_context.data_provider.read_scalars( experiment_id, - type(experiment_id), + plugin_name=scalar_metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter([run], [tag]), + downsample=(self._tb_context.sampling_hints or {}).get( + scalar_metadata.PLUGIN_NAME, 1000 + ), ) - return self._deprecated_multiplexer.Tensors(run, tag) + data = data_provider_output.get(run, {}).get(tag) + if data is None: + raise KeyError("No scalar data for run=%r, tag=%r" % (run, tag)) + return [ + # TODO(#3425): Change clients to depend on data provider + # APIs natively and remove this post-processing step. + event_accumulator.TensorEvent( + wall_time=e.wall_time, + step=e.step, + tensor_proto=tensor_util.make_tensor_proto(e.value), + ) + for e in data + ] def _find_experiment_tag(self, experiment_id): """Finds the experiment associcated with the metadata.EXPERIMENT_TAG @@ -147,14 +169,19 @@ def _find_experiment_tag(self, experiment_id): Returns: The experiment or None if no such experiment is found. """ - mapping = self.hparams_metadata(experiment_id) - for tag_to_content in mapping.values(): - if metadata.EXPERIMENT_TAG in tag_to_content: - experiment = metadata.parse_experiment_plugin_data( - tag_to_content[metadata.EXPERIMENT_TAG] - ) - return experiment - return None + mapping = self.hparams_metadata( + experiment_id, + run_tag_filter=provider.RunTagFilter( + tags=[metadata.EXPERIMENT_TAG] + ), + ) + if not mapping: + return None + # We expect only one run to have an `EXPERIMENT_TAG`; pick + # arbitrarily. + tag_to_content = next(iter(mapping.values())) + content = next(iter(tag_to_content.values())) + return metadata.parse_experiment_plugin_data(content) def _compute_experiment_from_runs(self, experiment_id): """Computes a minimal Experiment protocol buffer by scanning the diff --git a/tensorboard/plugins/hparams/backend_context_test.py b/tensorboard/plugins/hparams/backend_context_test.py index 9b963f689a..799b6863a8 100644 --- a/tensorboard/plugins/hparams/backend_context_test.py +++ b/tensorboard/plugins/hparams/backend_context_test.py @@ -28,13 +28,16 @@ import tensorflow as tf from google.protobuf import text_format +from tensorboard.backend.event_processing import data_provider from tensorboard.backend.event_processing import event_accumulator from tensorboard.backend.event_processing import plugin_event_multiplexer +from tensorboard.compat.proto import summary_pb2 from tensorboard.plugins import base_plugin from tensorboard.plugins.hparams import api_pb2 from tensorboard.plugins.hparams import backend_context from tensorboard.plugins.hparams import metadata from tensorboard.plugins.hparams import plugin_data_pb2 +from tensorboard.plugins.scalar import metadata as scalars_metadata DATA_TYPE_EXPERIMENT = "experiment" DATA_TYPE_SESSION_START_INFO = "session_start_info" @@ -46,7 +49,9 @@ class BackendContextTest(tf.test.TestCase): maxDiff = None # pylint: disable=invalid-name def setUp(self): - self._mock_tb_context = mock.create_autospec(base_plugin.TBContext) + self._mock_tb_context = base_plugin.TBContext() + # TODO(#3425): Remove mocking or switch to mocking data provider + # APIs directly. self._mock_multiplexer = mock.create_autospec( plugin_event_multiplexer.EventMultiplexer ) @@ -54,48 +59,82 @@ def setUp(self): self._mock_multiplexer.PluginRunToTagToContent.side_effect = ( self._mock_plugin_run_to_tag_to_content ) + self._mock_multiplexer.AllSummaryMetadata.side_effect = ( + self._mock_all_summary_metadata + ) + self._mock_multiplexer.SummaryMetadata.side_effect = ( + self._mock_summary_metadata + ) + self._mock_tb_context.data_provider = data_provider.MultiplexerDataProvider( + self._mock_multiplexer, "/path/to/logs" + ) self.session_1_start_info_ = "" self.session_2_start_info_ = "" self.session_3_start_info_ = "" + def _mock_all_summary_metadata(self): + result = {} + hparams_content = { + "exp/session_1": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, self.session_1_start_info_ + ), + }, + "exp/session_2": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, self.session_2_start_info_ + ), + }, + "exp/session_3": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, self.session_3_start_info_ + ), + }, + } + scalars_content = { + "exp/session_1": {"loss": b"", "accuracy": b""}, + "exp/session_1/eval": {"loss": b"",}, + "exp/session_1/train": {"loss": b"",}, + "exp/session_2": {"loss": b"", "accuracy": b"",}, + "exp/session_2/eval": {"loss": b"",}, + "exp/session_2/train": {"loss": b"",}, + "exp/session_3": {"loss": b"", "accuracy": b"",}, + "exp/session_3/eval": {"loss": b"",}, + "exp/session_3xyz/": {"loss2": b"",}, + } + for (run, tag_to_content) in hparams_content.items(): + result.setdefault(run, {}) + for (tag, content) in tag_to_content.items(): + m = summary_pb2.SummaryMetadata() + m.data_class = summary_pb2.DATA_CLASS_TENSOR + m.plugin_data.plugin_name = metadata.PLUGIN_NAME + m.plugin_data.content = content + result[run][tag] = m + for (run, tag_to_content) in scalars_content.items(): + result.setdefault(run, {}) + for (tag, content) in tag_to_content.items(): + m = summary_pb2.SummaryMetadata() + m.data_class = summary_pb2.DATA_CLASS_SCALAR + m.plugin_data.plugin_name = scalars_metadata.PLUGIN_NAME + m.plugin_data.content = content + result[run][tag] = m + return result + def _mock_plugin_run_to_tag_to_content(self, plugin_name): - if plugin_name == metadata.PLUGIN_NAME: - return { - "exp/session_1": { - metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, self.session_1_start_info_ - ), - }, - "exp/session_2": { - metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, self.session_2_start_info_ - ), - }, - "exp/session_3": { - metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, self.session_3_start_info_ - ), - }, - } - SCALARS = event_accumulator.SCALARS # pylint: disable=invalid-name - if plugin_name == SCALARS: - return { - # We use None as the content here, since the content is not - # used in the test. - "exp/session_1": {"loss": None, "accuracy": None}, - "exp/session_1/eval": {"loss": None,}, - "exp/session_1/train": {"loss": None,}, - "exp/session_2": {"loss": None, "accuracy": None,}, - "exp/session_2/eval": {"loss": None,}, - "exp/session_2/train": {"loss": None,}, - "exp/session_3": {"loss": None, "accuracy": None,}, - "exp/session_3/eval": {"loss": None,}, - "exp/session_3xyz/": {"loss2": None,}, - } - self.fail( - "Unexpected plugin_name '%s' passed to" - " EventMultiplexer.PluginRunToTagToContent" % plugin_name - ) + result = {} + for ( + run, + tag_to_metadata, + ) in self._mock_multiplexer.AllSummaryMetadata().items(): + for (tag, metadata) in tag_to_metadata.items(): + if metadata.plugin_data.plugin_name != plugin_name: + continue + result.setdefault(run, {}) + result[run][tag] = metadata.plugin_data.content + return result + + def _mock_summary_metadata(self, run, tag): + return self._mock_multiplexer.AllSummaryMetadata()[run][tag] def test_experiment_with_experiment_tag(self): experiment = """ @@ -104,14 +143,16 @@ def test_experiment_with_experiment_tag(self): { name: { tag: 'current_temp' } } ] """ - self._mock_multiplexer.PluginRunToTagToContent.side_effect = None - self._mock_multiplexer.PluginRunToTagToContent.return_value = { - "exp": { - metadata.EXPERIMENT_TAG: self._serialized_plugin_data( - DATA_TYPE_EXPERIMENT, experiment - ) - } - } + run = "exp" + tag = metadata.EXPERIMENT_TAG + m = summary_pb2.SummaryMetadata() + m.data_class = summary_pb2.DATA_CLASS_TENSOR + m.plugin_data.plugin_name = metadata.PLUGIN_NAME + m.plugin_data.content = self._serialized_plugin_data( + DATA_TYPE_EXPERIMENT, experiment + ) + self._mock_multiplexer.AllSummaryMetadata.side_effect = None + self._mock_multiplexer.AllSummaryMetadata.return_value = {run: {tag: m}} ctxt = backend_context.Context(self._mock_tb_context) self.assertProtoEquals(experiment, ctxt.experiment(experiment_id="123")) diff --git a/tensorboard/plugins/hparams/list_session_groups_test.py b/tensorboard/plugins/hparams/list_session_groups_test.py index 8395c96448..313f413661 100644 --- a/tensorboard/plugins/hparams/list_session_groups_test.py +++ b/tensorboard/plugins/hparams/list_session_groups_test.py @@ -29,14 +29,17 @@ import mock # pylint: disable=unused-import from google.protobuf import text_format +from tensorboard.backend.event_processing import data_provider from tensorboard.backend.event_processing import event_accumulator from tensorboard.backend.event_processing import plugin_event_multiplexer +from tensorboard.compat.proto import summary_pb2 from tensorboard.plugins import base_plugin from tensorboard.plugins.hparams import api_pb2 from tensorboard.plugins.hparams import backend_context from tensorboard.plugins.hparams import list_session_groups from tensorboard.plugins.hparams import metadata from tensorboard.plugins.hparams import plugin_data_pb2 +from tensorboard.plugins.scalar import metadata as scalars_metadata DATA_TYPE_EXPERIMENT = "experiment" @@ -53,12 +56,30 @@ class ListSessionGroupsTest(tf.test.TestCase): maxDiff = None # pylint: disable=invalid-name def setUp(self): - self._mock_tb_context = mock.create_autospec(base_plugin.TBContext) + self._mock_tb_context = base_plugin.TBContext() + # TODO(#3425): Remove mocking or switch to mocking data provider + # APIs directly. self._mock_multiplexer = mock.create_autospec( plugin_event_multiplexer.EventMultiplexer ) self._mock_tb_context.multiplexer = self._mock_multiplexer - self._mock_multiplexer.PluginRunToTagToContent.return_value = { + self._mock_multiplexer.PluginRunToTagToContent.side_effect = ( + self._mock_plugin_run_to_tag_to_content + ) + self._mock_multiplexer.AllSummaryMetadata.side_effect = ( + self._mock_all_summary_metadata + ) + self._mock_multiplexer.SummaryMetadata.side_effect = ( + self._mock_summary_metadata + ) + self._mock_multiplexer.Tensors.side_effect = self._mock_tensors + self._mock_tb_context.data_provider = data_provider.MultiplexerDataProvider( + self._mock_multiplexer, "/path/to/logs" + ) + + def _mock_all_summary_metadata(self): + result = {} + hparams_content = { "": { metadata.EXPERIMENT_TAG: self._serialized_plugin_data( DATA_TYPE_EXPERIMENT, @@ -200,12 +221,60 @@ def setUp(self): ), }, } - self._mock_multiplexer.Tensors.side_effect = self._mock_tensors + scalars_content = { + "session_1": { + "current_temp": b"", + "delta_temp": b"", + "optional_metric": b"", + }, + "session_2": {"current_temp": b"", "delta_temp": b""}, + "session_3": {"current_temp": b"", "delta_temp": b""}, + "session_4": {"current_temp": b"", "delta_temp": b""}, + "session_5": {"current_temp": b"", "delta_temp": b""}, + } + for (run, tag_to_content) in hparams_content.items(): + result.setdefault(run, {}) + for (tag, content) in tag_to_content.items(): + m = summary_pb2.SummaryMetadata() + m.data_class = summary_pb2.DATA_CLASS_TENSOR + m.plugin_data.plugin_name = metadata.PLUGIN_NAME + m.plugin_data.content = content + result[run][tag] = m + for (run, tag_to_content) in scalars_content.items(): + result.setdefault(run, {}) + for (tag, content) in tag_to_content.items(): + m = summary_pb2.SummaryMetadata() + m.data_class = summary_pb2.DATA_CLASS_SCALAR + m.plugin_data.plugin_name = scalars_metadata.PLUGIN_NAME + m.plugin_data.content = content + result[run][tag] = m + return result + + def _mock_plugin_run_to_tag_to_content(self, plugin_name): + result = {} + for (run, tag_to_metadata) in self._mock_all_summary_metadata().items(): + for (tag, metadata) in tag_to_metadata.items(): + if metadata.plugin_data.plugin_name != plugin_name: + continue + result.setdefault(run, {}) + result[run][tag] = metadata.plugin_data.content + return result + + def _mock_summary_metadata(self, run, tag): + return self._mock_all_summary_metadata()[run][tag] # A mock version of EventMultiplexer.Tensors def _mock_tensors(self, run, tag): + hparams_time_series = [ + TensorEvent( + wall_time=123.75, step=0, tensor_proto=metadata.NULL_TENSOR + ) + ] result_dict = { + "": {metadata.EXPERIMENT_TAG: hparams_time_series[:],}, "session_1": { + metadata.SESSION_START_INFO_TAG: hparams_time_series[:], + metadata.SESSION_END_INFO_TAG: hparams_time_series[:], "current_temp": [ TensorEvent( wall_time=1, @@ -239,6 +308,8 @@ def _mock_tensors(self, run, tag): ], }, "session_2": { + metadata.SESSION_START_INFO_TAG: hparams_time_series[:], + metadata.SESSION_END_INFO_TAG: hparams_time_series[:], "current_temp": [ TensorEvent( wall_time=1, @@ -260,6 +331,8 @@ def _mock_tensors(self, run, tag): ], }, "session_3": { + metadata.SESSION_START_INFO_TAG: hparams_time_series[:], + metadata.SESSION_END_INFO_TAG: hparams_time_series[:], "current_temp": [ TensorEvent( wall_time=1, @@ -281,6 +354,8 @@ def _mock_tensors(self, run, tag): ], }, "session_4": { + metadata.SESSION_START_INFO_TAG: hparams_time_series[:], + metadata.SESSION_END_INFO_TAG: hparams_time_series[:], "current_temp": [ TensorEvent( wall_time=1, @@ -302,6 +377,8 @@ def _mock_tensors(self, run, tag): ], }, "session_5": { + metadata.SESSION_START_INFO_TAG: hparams_time_series[:], + metadata.SESSION_END_INFO_TAG: hparams_time_series[:], "current_temp": [ TensorEvent( wall_time=1,