Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
311b3b5
backend: always provide data provider on context
wchargin Mar 23, 2020
00bb2a0
hparams: follow standard `is_active` pattern
wchargin Mar 23, 2020
68b80ec
hparams: encapsulate direct multiplexer access
wchargin Mar 23, 2020
7d1481e
hparams: thread experiment ID through data access
wchargin Mar 23, 2020
c33b665
hparams: read from generic data APIs
wchargin Mar 23, 2020
eaaad6d
data: optimize `read_scalars` by skipping scans
wchargin Mar 26, 2020
b06127d
[update diffbase]
wchargin Mar 26, 2020
de0aa14
[update diffbase]
wchargin Mar 26, 2020
412ddfa
[update diffbase]
wchargin Mar 26, 2020
8332731
[update patch]
wchargin Mar 26, 2020
7195a4f
data: add tests for blob sequence handling
wchargin Mar 26, 2020
b5062f9
[update diffbase]
wchargin Mar 26, 2020
c4c8b13
[update patch]
wchargin Mar 26, 2020
0d99676
[update diffbase]
wchargin Mar 26, 2020
3e555dc
[update diffbase]
wchargin Mar 26, 2020
1cc53f7
[update diffbase]
wchargin Mar 26, 2020
7335259
[update patch]
wchargin Mar 26, 2020
8813cc7
[update diffbase]
wchargin Mar 26, 2020
43453e9
[update diffbase]
wchargin Mar 26, 2020
022bb0d
[update diffbase]
wchargin Mar 26, 2020
d3a59f6
[update diffbase]
wchargin Mar 26, 2020
6f4e7cd
[update patch]
wchargin Mar 26, 2020
8873bc8
[update diffbase]
wchargin Mar 26, 2020
946f408
[update diffbase]
wchargin Mar 26, 2020
7f938c5
[update diffbase]
wchargin Mar 26, 2020
c1b9af5
[update patch]
wchargin Mar 26, 2020
bbd0841
[update patch]
wchargin Mar 26, 2020
7ff57f9
[update diffbase]
wchargin Mar 26, 2020
77165dd
[update diffbase]
wchargin Mar 26, 2020
61c77dc
[update diffbase]
wchargin Mar 26, 2020
aa30fe0
[update patch]
wchargin Mar 26, 2020
e6341f9
[update diffbase]
wchargin Mar 26, 2020
9a5d325
[update patch]
wchargin Mar 26, 2020
20df297
[update diffbase]
wchargin Mar 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tensorboard/plugins/hparams/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ py_library(
":protos_all_py_pb2",
"//tensorboard:plugin_util",
"//tensorboard/backend:http_util",
"//tensorboard/data:provider",
"//tensorboard/plugins:base_plugin",
"//tensorboard/plugins/scalar:metadata",
"//tensorboard/plugins/scalar:scalars_plugin",
Expand Down Expand Up @@ -72,14 +73,18 @@ py_test(
py_test(
name = "list_session_groups_test",
size = "small",
timeout = "moderate",
srcs = [
"list_session_groups_test.py",
],
deps = [
":hparams_plugin",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/backend/event_processing:data_provider",
"//tensorboard/backend/event_processing:event_accumulator",
"//tensorboard/backend/event_processing:event_multiplexer",
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/plugins/scalar:metadata",
"@org_pythonhosted_mock",
],
)
Expand All @@ -106,8 +111,11 @@ py_test(
deps = [
":hparams_plugin",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/backend/event_processing:data_provider",
"//tensorboard/backend/event_processing:event_accumulator",
"//tensorboard/backend/event_processing:event_multiplexer",
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/plugins/scalar:metadata",
"@org_pythonhosted_mock",
],
)
Expand Down
91 changes: 59 additions & 32 deletions tensorboard/plugins/hparams/backend_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@

import six

from tensorboard.backend.event_processing import (
plugin_event_accumulator as event_accumulator,
)
from tensorboard.data import provider
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import metadata
from google.protobuf import json_format
from tensorboard.plugins.scalar import metadata as scalar_metadata
from tensorboard.util import tensor_util


class Context(object):
Expand Down Expand Up @@ -75,34 +80,37 @@ def experiment(self, experiment_id):
return self._compute_experiment_from_runs(experiment_id)
return experiment

@property
def _deprecated_multiplexer(self):
return self._tb_context.multiplexer

@property
def multiplexer(self):
raise NotImplementedError("Do not read `Context.multiplexer` directly")

@property
def tb_context(self):
return self._tb_context

def hparams_metadata(self, experiment_id):
def _convert_plugin_metadata(self, data_provider_output):
return {
run: {
tag: time_series.plugin_content
for (tag, time_series) in tag_to_time_series.items()
}
for (run, tag_to_time_series) in data_provider_output.items()
}

def hparams_metadata(self, experiment_id, run_tag_filter=None):
"""Reads summary metadata for all hparams time series.

Args:
experiment_id: String, from `plugin_util.experiment_id`.
run_tag_filter: Optional `data.provider.RunTagFilter`, with
the semantics as in `list_tensors`.

Returns:
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
assert isinstance(experiment_id, str), (
experiment_id,
type(experiment_id),
)
return self._deprecated_multiplexer.PluginRunToTagToContent(
metadata.PLUGIN_NAME
return self._convert_plugin_metadata(
self._tb_context.data_provider.list_tensors(
experiment_id,
plugin_name=metadata.PLUGIN_NAME,
run_tag_filter=run_tag_filter,
)
)

def scalars_metadata(self, experiment_id):
Expand All @@ -115,12 +123,10 @@ def scalars_metadata(self, experiment_id):
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
assert isinstance(experiment_id, str), (
experiment_id,
type(experiment_id),
)
return self._deprecated_multiplexer.PluginRunToTagToContent(
scalar_metadata.PLUGIN_NAME
return self._convert_plugin_metadata(
self._tb_context.data_provider.list_scalars(
experiment_id, plugin_name=scalar_metadata.PLUGIN_NAME
)
)

def read_scalars(self, experiment_id, run, tag):
Expand All @@ -134,11 +140,27 @@ def read_scalars(self, experiment_id, run, tag):
Returns:
A list of `plugin_event_accumulator.TensorEvent` values.
"""
assert isinstance(experiment_id, str), (
data_provider_output = self._tb_context.data_provider.read_scalars(
experiment_id,
type(experiment_id),
plugin_name=scalar_metadata.PLUGIN_NAME,
run_tag_filter=provider.RunTagFilter([run], [tag]),
downsample=(self._tb_context.sampling_hints or {}).get(
scalar_metadata.PLUGIN_NAME, 1000
),
)
return self._deprecated_multiplexer.Tensors(run, tag)
data = data_provider_output.get(run, {}).get(tag)
if data is None:
raise KeyError("No scalar data for run=%r, tag=%r" % (run, tag))
return [
# TODO(#3425): Change clients to depend on data provider
# APIs natively and remove this post-processing step.
event_accumulator.TensorEvent(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we maybe file a cleanup github issue to remove the vestigial bits of the multiplexer API from the hparams plugin and/or from plugins generally after they're migrated? Would be nice to at least leave a TODO here, since it's kind of weird to still have TensorEvent used for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, definitely. I was planning to do so right after these changes
(switch these functions to directly return the data provider results and
simultaneously migrate callers), but wanted to get these in first to
unblock some internal work. Filed #3425.

For what it’s worth, this is not a performance bottleneck. The
post-processing takes on the order of 1–5 microseconds. The call to
data_provider.read_scalars is slow because it calls list_scalars
internally, and list_scalars itself iterates over all the tensors to
generate metadata that we just throw away. It currently takes on the
order of 10 milliseconds, so that’s not good. And the hparams usage of
read_scalars is slow, too, because we call it once for each of the
O(|metrics| × |sessions|) time series instead of passing a single
RunTagFilter that represents the cross product, so that’s not good,
either. The data-munging here is negligible in comparison.

(I will see if I can speed up _list_scalars before merging this PR.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I wasn't concerned at all about performance. Just seemed like a vestigial weirdness to leave behind - thanks for filing the issue!

wall_time=e.wall_time,
step=e.step,
tensor_proto=tensor_util.make_tensor_proto(e.value),
)
for e in data
]

def _find_experiment_tag(self, experiment_id):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Above in this PR hparams_metadata() will resolve now to an unfiltered list_tensors() call, but in this particular callsite within _find_experiment_tag() we only want a single tag, so IMO it'd be nice if we just call the data provider with an appropriate tag filter rather than postfiltering all tags. (Or optionally I guess hparams_metadata() could learn a tags argument, if you prefer.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point; thanks. Added run_tag_filter argument, which is a nice
simplification.

"""Finds the experiment associcated with the metadata.EXPERIMENT_TAG
Expand All @@ -147,14 +169,19 @@ def _find_experiment_tag(self, experiment_id):
Returns:
The experiment or None if no such experiment is found.
"""
mapping = self.hparams_metadata(experiment_id)
for tag_to_content in mapping.values():
if metadata.EXPERIMENT_TAG in tag_to_content:
experiment = metadata.parse_experiment_plugin_data(
tag_to_content[metadata.EXPERIMENT_TAG]
)
return experiment
return None
mapping = self.hparams_metadata(
experiment_id,
run_tag_filter=provider.RunTagFilter(
tags=[metadata.EXPERIMENT_TAG]
),
)
if not mapping:
return None
# We expect only one run to have an `EXPERIMENT_TAG`; pick
# arbitrarily.
tag_to_content = next(iter(mapping.values()))
content = next(iter(tag_to_content.values()))
return metadata.parse_experiment_plugin_data(content)

def _compute_experiment_from_runs(self, experiment_id):
"""Computes a minimal Experiment protocol buffer by scanning the
Expand Down
133 changes: 87 additions & 46 deletions tensorboard/plugins/hparams/backend_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
import tensorflow as tf

from google.protobuf import text_format
from tensorboard.backend.event_processing import data_provider
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.backend.event_processing import plugin_event_multiplexer
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins import base_plugin
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import backend_context
from tensorboard.plugins.hparams import metadata
from tensorboard.plugins.hparams import plugin_data_pb2
from tensorboard.plugins.scalar import metadata as scalars_metadata

DATA_TYPE_EXPERIMENT = "experiment"
DATA_TYPE_SESSION_START_INFO = "session_start_info"
Expand All @@ -46,56 +49,92 @@ class BackendContextTest(tf.test.TestCase):
maxDiff = None # pylint: disable=invalid-name

def setUp(self):
self._mock_tb_context = mock.create_autospec(base_plugin.TBContext)
self._mock_tb_context = base_plugin.TBContext()
# TODO(#3425): Remove mocking or switch to mocking data provider
# APIs directly.
self._mock_multiplexer = mock.create_autospec(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto here re: having a TODO to clean up the test code here so that if it's going to use mocks it at least mocks the DataProvider APIs directly rather than mocking EventMultiplexer and then wrapping that in MultiplexerDataProvider.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

plugin_event_multiplexer.EventMultiplexer
)
self._mock_tb_context.multiplexer = self._mock_multiplexer
self._mock_multiplexer.PluginRunToTagToContent.side_effect = (
self._mock_plugin_run_to_tag_to_content
)
self._mock_multiplexer.AllSummaryMetadata.side_effect = (
self._mock_all_summary_metadata
)
self._mock_multiplexer.SummaryMetadata.side_effect = (
self._mock_summary_metadata
)
self._mock_tb_context.data_provider = data_provider.MultiplexerDataProvider(
self._mock_multiplexer, "/path/to/logs"
)
self.session_1_start_info_ = ""
self.session_2_start_info_ = ""
self.session_3_start_info_ = ""

def _mock_all_summary_metadata(self):
result = {}
hparams_content = {
"exp/session_1": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO, self.session_1_start_info_
),
},
"exp/session_2": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO, self.session_2_start_info_
),
},
"exp/session_3": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO, self.session_3_start_info_
),
},
}
scalars_content = {
"exp/session_1": {"loss": b"", "accuracy": b""},
"exp/session_1/eval": {"loss": b"",},
"exp/session_1/train": {"loss": b"",},
"exp/session_2": {"loss": b"", "accuracy": b"",},
"exp/session_2/eval": {"loss": b"",},
"exp/session_2/train": {"loss": b"",},
"exp/session_3": {"loss": b"", "accuracy": b"",},
"exp/session_3/eval": {"loss": b"",},
"exp/session_3xyz/": {"loss2": b"",},
}
for (run, tag_to_content) in hparams_content.items():
result.setdefault(run, {})
for (tag, content) in tag_to_content.items():
m = summary_pb2.SummaryMetadata()
m.data_class = summary_pb2.DATA_CLASS_TENSOR
m.plugin_data.plugin_name = metadata.PLUGIN_NAME
m.plugin_data.content = content
result[run][tag] = m
for (run, tag_to_content) in scalars_content.items():
result.setdefault(run, {})
for (tag, content) in tag_to_content.items():
m = summary_pb2.SummaryMetadata()
m.data_class = summary_pb2.DATA_CLASS_SCALAR
m.plugin_data.plugin_name = scalars_metadata.PLUGIN_NAME
m.plugin_data.content = content
result[run][tag] = m
return result

def _mock_plugin_run_to_tag_to_content(self, plugin_name):
if plugin_name == metadata.PLUGIN_NAME:
return {
"exp/session_1": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO, self.session_1_start_info_
),
},
"exp/session_2": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO, self.session_2_start_info_
),
},
"exp/session_3": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO, self.session_3_start_info_
),
},
}
SCALARS = event_accumulator.SCALARS # pylint: disable=invalid-name
if plugin_name == SCALARS:
return {
# We use None as the content here, since the content is not
# used in the test.
"exp/session_1": {"loss": None, "accuracy": None},
"exp/session_1/eval": {"loss": None,},
"exp/session_1/train": {"loss": None,},
"exp/session_2": {"loss": None, "accuracy": None,},
"exp/session_2/eval": {"loss": None,},
"exp/session_2/train": {"loss": None,},
"exp/session_3": {"loss": None, "accuracy": None,},
"exp/session_3/eval": {"loss": None,},
"exp/session_3xyz/": {"loss2": None,},
}
self.fail(
"Unexpected plugin_name '%s' passed to"
" EventMultiplexer.PluginRunToTagToContent" % plugin_name
)
result = {}
for (
run,
tag_to_metadata,
) in self._mock_multiplexer.AllSummaryMetadata().items():
for (tag, metadata) in tag_to_metadata.items():
if metadata.plugin_data.plugin_name != plugin_name:
continue
result.setdefault(run, {})
result[run][tag] = metadata.plugin_data.content
return result

def _mock_summary_metadata(self, run, tag):
return self._mock_multiplexer.AllSummaryMetadata()[run][tag]

def test_experiment_with_experiment_tag(self):
experiment = """
Expand All @@ -104,14 +143,16 @@ def test_experiment_with_experiment_tag(self):
{ name: { tag: 'current_temp' } }
]
"""
self._mock_multiplexer.PluginRunToTagToContent.side_effect = None
self._mock_multiplexer.PluginRunToTagToContent.return_value = {
"exp": {
metadata.EXPERIMENT_TAG: self._serialized_plugin_data(
DATA_TYPE_EXPERIMENT, experiment
)
}
}
run = "exp"
tag = metadata.EXPERIMENT_TAG
m = summary_pb2.SummaryMetadata()
m.data_class = summary_pb2.DATA_CLASS_TENSOR
m.plugin_data.plugin_name = metadata.PLUGIN_NAME
m.plugin_data.content = self._serialized_plugin_data(
DATA_TYPE_EXPERIMENT, experiment
)
self._mock_multiplexer.AllSummaryMetadata.side_effect = None
self._mock_multiplexer.AllSummaryMetadata.return_value = {run: {tag: m}}
ctxt = backend_context.Context(self._mock_tb_context)
self.assertProtoEquals(experiment, ctxt.experiment(experiment_id="123"))

Expand Down
Loading