-
Notifications
You must be signed in to change notification settings - Fork 1.7k
hparams: read from generic data APIs #3419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
311b3b5
00bb2a0
68b80ec
7d1481e
c33b665
eaaad6d
b06127d
de0aa14
412ddfa
8332731
7195a4f
b5062f9
c4c8b13
0d99676
3e555dc
1cc53f7
7335259
8813cc7
43453e9
022bb0d
d3a59f6
6f4e7cd
8873bc8
946f408
7f938c5
c1b9af5
bbd0841
7ff57f9
77165dd
61c77dc
aa30fe0
e6341f9
9a5d325
20df297
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,10 +25,15 @@ | |
|
|
||
| import six | ||
|
|
||
| from tensorboard.backend.event_processing import ( | ||
| plugin_event_accumulator as event_accumulator, | ||
| ) | ||
| from tensorboard.data import provider | ||
| from tensorboard.plugins.hparams import api_pb2 | ||
| from tensorboard.plugins.hparams import metadata | ||
| from google.protobuf import json_format | ||
| from tensorboard.plugins.scalar import metadata as scalar_metadata | ||
| from tensorboard.util import tensor_util | ||
|
|
||
|
|
||
| class Context(object): | ||
|
|
@@ -75,34 +80,37 @@ def experiment(self, experiment_id): | |
| return self._compute_experiment_from_runs(experiment_id) | ||
| return experiment | ||
|
|
||
| @property | ||
| def _deprecated_multiplexer(self): | ||
| return self._tb_context.multiplexer | ||
|
|
||
| @property | ||
| def multiplexer(self): | ||
| raise NotImplementedError("Do not read `Context.multiplexer` directly") | ||
|
|
||
| @property | ||
| def tb_context(self): | ||
| return self._tb_context | ||
|
|
||
| def hparams_metadata(self, experiment_id): | ||
| def _convert_plugin_metadata(self, data_provider_output): | ||
| return { | ||
| run: { | ||
| tag: time_series.plugin_content | ||
| for (tag, time_series) in tag_to_time_series.items() | ||
| } | ||
| for (run, tag_to_time_series) in data_provider_output.items() | ||
| } | ||
|
|
||
| def hparams_metadata(self, experiment_id, run_tag_filter=None): | ||
| """Reads summary metadata for all hparams time series. | ||
|
|
||
| Args: | ||
| experiment_id: String, from `plugin_util.experiment_id`. | ||
| run_tag_filter: Optional `data.provider.RunTagFilter`, with | ||
| the semantics as in `list_tensors`. | ||
|
|
||
| Returns: | ||
| A dict `d` such that `d[run][tag]` is a `bytes` value with the | ||
| summary metadata content for the keyed time series. | ||
| """ | ||
| assert isinstance(experiment_id, str), ( | ||
| experiment_id, | ||
| type(experiment_id), | ||
| ) | ||
| return self._deprecated_multiplexer.PluginRunToTagToContent( | ||
| metadata.PLUGIN_NAME | ||
| return self._convert_plugin_metadata( | ||
| self._tb_context.data_provider.list_tensors( | ||
| experiment_id, | ||
| plugin_name=metadata.PLUGIN_NAME, | ||
| run_tag_filter=run_tag_filter, | ||
| ) | ||
| ) | ||
|
|
||
| def scalars_metadata(self, experiment_id): | ||
|
|
@@ -115,12 +123,10 @@ def scalars_metadata(self, experiment_id): | |
| A dict `d` such that `d[run][tag]` is a `bytes` value with the | ||
| summary metadata content for the keyed time series. | ||
| """ | ||
| assert isinstance(experiment_id, str), ( | ||
| experiment_id, | ||
| type(experiment_id), | ||
| ) | ||
| return self._deprecated_multiplexer.PluginRunToTagToContent( | ||
| scalar_metadata.PLUGIN_NAME | ||
| return self._convert_plugin_metadata( | ||
| self._tb_context.data_provider.list_scalars( | ||
| experiment_id, plugin_name=scalar_metadata.PLUGIN_NAME | ||
| ) | ||
| ) | ||
|
|
||
| def read_scalars(self, experiment_id, run, tag): | ||
|
|
@@ -134,11 +140,27 @@ def read_scalars(self, experiment_id, run, tag): | |
| Returns: | ||
| A list of `plugin_event_accumulator.TensorEvent` values. | ||
| """ | ||
| assert isinstance(experiment_id, str), ( | ||
| data_provider_output = self._tb_context.data_provider.read_scalars( | ||
| experiment_id, | ||
| type(experiment_id), | ||
| plugin_name=scalar_metadata.PLUGIN_NAME, | ||
| run_tag_filter=provider.RunTagFilter([run], [tag]), | ||
| downsample=(self._tb_context.sampling_hints or {}).get( | ||
| scalar_metadata.PLUGIN_NAME, 1000 | ||
| ), | ||
| ) | ||
| return self._deprecated_multiplexer.Tensors(run, tag) | ||
| data = data_provider_output.get(run, {}).get(tag) | ||
| if data is None: | ||
| raise KeyError("No scalar data for run=%r, tag=%r" % (run, tag)) | ||
| return [ | ||
| # TODO(#3425): Change clients to depend on data provider | ||
| # APIs natively and remove this post-processing step. | ||
| event_accumulator.TensorEvent( | ||
| wall_time=e.wall_time, | ||
| step=e.step, | ||
| tensor_proto=tensor_util.make_tensor_proto(e.value), | ||
| ) | ||
| for e in data | ||
| ] | ||
|
|
||
| def _find_experiment_tag(self, experiment_id): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Above in this PR
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point; thanks. Added |
||
| """Finds the experiment associcated with the metadata.EXPERIMENT_TAG | ||
|
|
@@ -147,14 +169,19 @@ def _find_experiment_tag(self, experiment_id): | |
| Returns: | ||
| The experiment or None if no such experiment is found. | ||
| """ | ||
| mapping = self.hparams_metadata(experiment_id) | ||
| for tag_to_content in mapping.values(): | ||
| if metadata.EXPERIMENT_TAG in tag_to_content: | ||
| experiment = metadata.parse_experiment_plugin_data( | ||
| tag_to_content[metadata.EXPERIMENT_TAG] | ||
| ) | ||
| return experiment | ||
| return None | ||
| mapping = self.hparams_metadata( | ||
| experiment_id, | ||
| run_tag_filter=provider.RunTagFilter( | ||
| tags=[metadata.EXPERIMENT_TAG] | ||
| ), | ||
| ) | ||
| if not mapping: | ||
| return None | ||
| # We expect only one run to have an `EXPERIMENT_TAG`; pick | ||
| # arbitrarily. | ||
| tag_to_content = next(iter(mapping.values())) | ||
| content = next(iter(tag_to_content.values())) | ||
| return metadata.parse_experiment_plugin_data(content) | ||
|
|
||
| def _compute_experiment_from_runs(self, experiment_id): | ||
| """Computes a minimal Experiment protocol buffer by scanning the | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,13 +28,16 @@ | |
| import tensorflow as tf | ||
|
|
||
| from google.protobuf import text_format | ||
| from tensorboard.backend.event_processing import data_provider | ||
| from tensorboard.backend.event_processing import event_accumulator | ||
| from tensorboard.backend.event_processing import plugin_event_multiplexer | ||
| from tensorboard.compat.proto import summary_pb2 | ||
| from tensorboard.plugins import base_plugin | ||
| from tensorboard.plugins.hparams import api_pb2 | ||
| from tensorboard.plugins.hparams import backend_context | ||
| from tensorboard.plugins.hparams import metadata | ||
| from tensorboard.plugins.hparams import plugin_data_pb2 | ||
| from tensorboard.plugins.scalar import metadata as scalars_metadata | ||
|
|
||
| DATA_TYPE_EXPERIMENT = "experiment" | ||
| DATA_TYPE_SESSION_START_INFO = "session_start_info" | ||
|
|
@@ -46,56 +49,92 @@ class BackendContextTest(tf.test.TestCase): | |
| maxDiff = None # pylint: disable=invalid-name | ||
|
|
||
| def setUp(self): | ||
| self._mock_tb_context = mock.create_autospec(base_plugin.TBContext) | ||
| self._mock_tb_context = base_plugin.TBContext() | ||
| # TODO(#3425): Remove mocking or switch to mocking data provider | ||
| # APIs directly. | ||
| self._mock_multiplexer = mock.create_autospec( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto here re: having a TODO to clean up the test code here so that if it's going to use mocks it at least mocks the DataProvider APIs directly rather than mocking EventMultiplexer and then wrapping that in MultiplexerDataProvider.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will do. |
||
| plugin_event_multiplexer.EventMultiplexer | ||
| ) | ||
| self._mock_tb_context.multiplexer = self._mock_multiplexer | ||
| self._mock_multiplexer.PluginRunToTagToContent.side_effect = ( | ||
| self._mock_plugin_run_to_tag_to_content | ||
| ) | ||
| self._mock_multiplexer.AllSummaryMetadata.side_effect = ( | ||
| self._mock_all_summary_metadata | ||
| ) | ||
| self._mock_multiplexer.SummaryMetadata.side_effect = ( | ||
| self._mock_summary_metadata | ||
| ) | ||
| self._mock_tb_context.data_provider = data_provider.MultiplexerDataProvider( | ||
| self._mock_multiplexer, "/path/to/logs" | ||
| ) | ||
| self.session_1_start_info_ = "" | ||
| self.session_2_start_info_ = "" | ||
| self.session_3_start_info_ = "" | ||
|
|
||
| def _mock_all_summary_metadata(self): | ||
| result = {} | ||
| hparams_content = { | ||
| "exp/session_1": { | ||
| metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( | ||
| DATA_TYPE_SESSION_START_INFO, self.session_1_start_info_ | ||
| ), | ||
| }, | ||
| "exp/session_2": { | ||
| metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( | ||
| DATA_TYPE_SESSION_START_INFO, self.session_2_start_info_ | ||
| ), | ||
| }, | ||
| "exp/session_3": { | ||
| metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( | ||
| DATA_TYPE_SESSION_START_INFO, self.session_3_start_info_ | ||
| ), | ||
| }, | ||
| } | ||
| scalars_content = { | ||
| "exp/session_1": {"loss": b"", "accuracy": b""}, | ||
| "exp/session_1/eval": {"loss": b"",}, | ||
| "exp/session_1/train": {"loss": b"",}, | ||
| "exp/session_2": {"loss": b"", "accuracy": b"",}, | ||
| "exp/session_2/eval": {"loss": b"",}, | ||
| "exp/session_2/train": {"loss": b"",}, | ||
| "exp/session_3": {"loss": b"", "accuracy": b"",}, | ||
| "exp/session_3/eval": {"loss": b"",}, | ||
| "exp/session_3xyz/": {"loss2": b"",}, | ||
| } | ||
| for (run, tag_to_content) in hparams_content.items(): | ||
| result.setdefault(run, {}) | ||
| for (tag, content) in tag_to_content.items(): | ||
| m = summary_pb2.SummaryMetadata() | ||
| m.data_class = summary_pb2.DATA_CLASS_TENSOR | ||
| m.plugin_data.plugin_name = metadata.PLUGIN_NAME | ||
| m.plugin_data.content = content | ||
| result[run][tag] = m | ||
| for (run, tag_to_content) in scalars_content.items(): | ||
| result.setdefault(run, {}) | ||
| for (tag, content) in tag_to_content.items(): | ||
| m = summary_pb2.SummaryMetadata() | ||
| m.data_class = summary_pb2.DATA_CLASS_SCALAR | ||
| m.plugin_data.plugin_name = scalars_metadata.PLUGIN_NAME | ||
| m.plugin_data.content = content | ||
| result[run][tag] = m | ||
| return result | ||
|
|
||
| def _mock_plugin_run_to_tag_to_content(self, plugin_name): | ||
| if plugin_name == metadata.PLUGIN_NAME: | ||
| return { | ||
| "exp/session_1": { | ||
| metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( | ||
| DATA_TYPE_SESSION_START_INFO, self.session_1_start_info_ | ||
| ), | ||
| }, | ||
| "exp/session_2": { | ||
| metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( | ||
| DATA_TYPE_SESSION_START_INFO, self.session_2_start_info_ | ||
| ), | ||
| }, | ||
| "exp/session_3": { | ||
| metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( | ||
| DATA_TYPE_SESSION_START_INFO, self.session_3_start_info_ | ||
| ), | ||
| }, | ||
| } | ||
| SCALARS = event_accumulator.SCALARS # pylint: disable=invalid-name | ||
| if plugin_name == SCALARS: | ||
| return { | ||
| # We use None as the content here, since the content is not | ||
| # used in the test. | ||
| "exp/session_1": {"loss": None, "accuracy": None}, | ||
| "exp/session_1/eval": {"loss": None,}, | ||
| "exp/session_1/train": {"loss": None,}, | ||
| "exp/session_2": {"loss": None, "accuracy": None,}, | ||
| "exp/session_2/eval": {"loss": None,}, | ||
| "exp/session_2/train": {"loss": None,}, | ||
| "exp/session_3": {"loss": None, "accuracy": None,}, | ||
| "exp/session_3/eval": {"loss": None,}, | ||
| "exp/session_3xyz/": {"loss2": None,}, | ||
| } | ||
| self.fail( | ||
| "Unexpected plugin_name '%s' passed to" | ||
| " EventMultiplexer.PluginRunToTagToContent" % plugin_name | ||
| ) | ||
| result = {} | ||
| for ( | ||
| run, | ||
| tag_to_metadata, | ||
| ) in self._mock_multiplexer.AllSummaryMetadata().items(): | ||
| for (tag, metadata) in tag_to_metadata.items(): | ||
| if metadata.plugin_data.plugin_name != plugin_name: | ||
| continue | ||
| result.setdefault(run, {}) | ||
| result[run][tag] = metadata.plugin_data.content | ||
| return result | ||
|
|
||
| def _mock_summary_metadata(self, run, tag): | ||
| return self._mock_multiplexer.AllSummaryMetadata()[run][tag] | ||
|
|
||
| def test_experiment_with_experiment_tag(self): | ||
| experiment = """ | ||
|
|
@@ -104,14 +143,16 @@ def test_experiment_with_experiment_tag(self): | |
| { name: { tag: 'current_temp' } } | ||
| ] | ||
| """ | ||
| self._mock_multiplexer.PluginRunToTagToContent.side_effect = None | ||
| self._mock_multiplexer.PluginRunToTagToContent.return_value = { | ||
| "exp": { | ||
| metadata.EXPERIMENT_TAG: self._serialized_plugin_data( | ||
| DATA_TYPE_EXPERIMENT, experiment | ||
| ) | ||
| } | ||
| } | ||
| run = "exp" | ||
| tag = metadata.EXPERIMENT_TAG | ||
| m = summary_pb2.SummaryMetadata() | ||
| m.data_class = summary_pb2.DATA_CLASS_TENSOR | ||
| m.plugin_data.plugin_name = metadata.PLUGIN_NAME | ||
| m.plugin_data.content = self._serialized_plugin_data( | ||
| DATA_TYPE_EXPERIMENT, experiment | ||
| ) | ||
| self._mock_multiplexer.AllSummaryMetadata.side_effect = None | ||
| self._mock_multiplexer.AllSummaryMetadata.return_value = {run: {tag: m}} | ||
| ctxt = backend_context.Context(self._mock_tb_context) | ||
| self.assertProtoEquals(experiment, ctxt.experiment(experiment_id="123")) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we maybe file a cleanup github issue to remove the vestigial bits of the multiplexer API from the hparams plugin and/or from plugins generally after they're migrated? Would be nice to at least leave a TODO here, since it's kind of weird to still have TensorEvent used for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, definitely. I was planning to do so right after these changes
(switch these functions to directly return the data provider results and
simultaneously migrate callers), but wanted to get these in first to
unblock some internal work. Filed #3425.
For what it’s worth, this is not a performance bottleneck. The
post-processing takes on the order of 1–5 microseconds. The call to
`data_provider.read_scalars` is slow because it calls `list_scalars` internally, and
`list_scalars` itself iterates over all the tensors to generate metadata that we just throw away. It currently takes on the
order of 10 milliseconds, so that’s not good. And the hparams usage of
`read_scalars` is slow, too, because we call it once for each of the O(|metrics| × |sessions|) time series instead of passing a single
`RunTagFilter` that represents the cross product, so that’s not good, either. The data-munging here is negligible in comparison.
(I will see if I can speed up
`_list_scalars` before merging this PR.)

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I wasn't concerned at all about performance. Just seemed like a vestigial weirdness to leave behind - thanks for filing the issue!