Skip to content
3 changes: 3 additions & 0 deletions tensorboard/uploader/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ py_test(
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/plugins/histogram:summary_v2",
"//tensorboard/plugins/scalar:metadata",
"//tensorboard/plugins/scalar:summary_v2",
"//tensorboard/summary:summary_v1",
"//tensorboard/uploader/proto:protos_all_py_pb2",
Expand Down Expand Up @@ -193,6 +194,7 @@ py_library(
deps = [
"//tensorboard:expect_requests_installed",
"//tensorboard:version",
"//tensorboard/plugins/scalar:metadata",
"//tensorboard/uploader/proto:protos_all_py_pb2",
"@com_google_protobuf//:protobuf_python",
],
Expand All @@ -208,6 +210,7 @@ py_test(
"//tensorboard:expect_futures_installed",
"//tensorboard:test",
"//tensorboard:version",
"//tensorboard/plugins/scalar:metadata",
"//tensorboard/uploader/proto:protos_all_py_pb2",
"@org_pocoo_werkzeug",
],
Expand Down
23 changes: 23 additions & 0 deletions tensorboard/uploader/server_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import requests

from tensorboard import version
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.uploader.proto import server_info_pb2


Expand Down Expand Up @@ -111,6 +112,28 @@ def experiment_url(server_info, experiment_id):
return url_format.template.replace(url_format.id_placeholder, experiment_id)


def allowed_plugins(server_info):
    """Computes the set of plugins that the server permits for upload.

    Reads `plugin_control.allowed_plugins` from `server_info` when the
    `plugin_control` submessage is present, and otherwise falls back to
    a scalars-only default.

    Args:
      server_info: A `server_info_pb2.ServerInfoResponse` message.

    Returns:
      A `frozenset` of plugin names.
    """
    if not server_info.HasField("plugin_control"):
        # Old server that predates `plugin_control`: gracefully degrade
        # to scalars only, which have been supported since launch.
        # TODO(@wchargin): Promote this branch to an error once we're
        # confident that we won't roll back to old server versions.
        return frozenset((scalars_metadata.PLUGIN_NAME,))
    return frozenset(server_info.plugin_control.allowed_plugins)


class CommunicationError(RuntimeError):
"""Raised upon failure to communicate with the server."""

Expand Down
30 changes: 30 additions & 0 deletions tensorboard/uploader/server_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from tensorboard import test as tb_test
from tensorboard import version
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.uploader import server_info
from tensorboard.uploader.proto import server_info_pb2

Expand Down Expand Up @@ -163,6 +164,35 @@ def test(self):
self.assertEqual(actual, "https://unittest.tensorboard.dev/x/123")


class AllowedPluginsTest(tb_test.TestCase):
    """Tests for `allowed_plugins`."""

    def test_old_server_no_plugins(self):
        # A response with no `plugin_control` falls back to scalars only.
        response = server_info_pb2.ServerInfoResponse()
        result = server_info.allowed_plugins(response)
        self.assertEqual(result, frozenset({scalars_metadata.PLUGIN_NAME}))

    def test_provided_but_no_plugins(self):
        # An explicitly present but empty `plugin_control` allows nothing.
        response = server_info_pb2.ServerInfoResponse()
        response.plugin_control.SetInParent()
        result = server_info.allowed_plugins(response)
        self.assertEqual(result, frozenset())

    def test_scalars_only(self):
        response = server_info_pb2.ServerInfoResponse()
        response.plugin_control.allowed_plugins.append(
            scalars_metadata.PLUGIN_NAME
        )
        result = server_info.allowed_plugins(response)
        self.assertEqual(result, frozenset({scalars_metadata.PLUGIN_NAME}))

    def test_more_plugins(self):
        # Duplicates in the response collapse into the frozenset.
        response = server_info_pb2.ServerInfoResponse()
        response.plugin_control.allowed_plugins.extend(["foo", "bar", "foo"])
        result = server_info.allowed_plugins(response)
        self.assertEqual(result, frozenset({"foo", "bar"}))


def _localhost():
"""Gets family and nodename for a loopback address."""
s = socket
Expand Down
34 changes: 26 additions & 8 deletions tensorboard/uploader/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
self,
writer_client,
logdir,
allowed_plugins,
rpc_rate_limiter=None,
name=None,
description=None,
Expand All @@ -78,6 +79,9 @@ def __init__(
Args:
writer_client: a TensorBoardWriterService stub instance
logdir: path of the log directory to upload
allowed_plugins: collection of string plugin names; events will only
be uploaded if their time series's metadata specifies one of these
plugin names
rpc_rate_limiter: a `RateLimiter` to use to limit write RPC frequency.
Note this limit applies at the level of single RPCs in the Scalar
and Tensor case, but at the level of an entire blob upload in the
Expand All @@ -90,6 +94,7 @@ def __init__(
"""
self._api = writer_client
self._logdir = logdir
self._allowed_plugins = frozenset(allowed_plugins)
self._name = name
self._description = description
self._request_sender = None
Expand Down Expand Up @@ -122,7 +127,10 @@ def create_experiment(self):
self._api.CreateExperiment, request
)
self._request_sender = _BatchedRequestSender(
response.experiment_id, self._api, self._rpc_rate_limiter
response.experiment_id,
self._api,
allowed_plugins=self._allowed_plugins,
rpc_rate_limiter=self._rpc_rate_limiter,
)
return response.experiment_id

Expand Down Expand Up @@ -261,12 +269,13 @@ class _BatchedRequestSender(object):
calling its methods concurrently.
"""

def __init__(self, experiment_id, api, rpc_rate_limiter):
def __init__(self, experiment_id, api, allowed_plugins, rpc_rate_limiter):
# Map from `(run_name, tag_name)` to `SummaryMetadata` if the time
# series is a scalar time series, else to `_NON_SCALAR_TIME_SERIES`.
self._tag_metadata = {}
self._allowed_plugins = frozenset(allowed_plugins)
self._scalar_request_sender = _ScalarBatchedRequestSender(
experiment_id, api, rpc_rate_limiter
experiment_id, api, rpc_rate_limiter,
)

# TODO(nielsene): add tensor case here
Expand Down Expand Up @@ -295,23 +304,32 @@ def send_requests(self, run_to_events):
# If later events arrive with a mismatching plugin_name, they are
# ignored with a warning.
metadata = self._tag_metadata.get(time_series_key)
first_in_time_series = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mechanism of detecting 'first time' is a little entangled and hard to follow. How about maintaining a collection of plugin_names_seen and testing if plugin_name in plugin_names_seen:?

Related, the 'first time' gets reset for every _upload_once(). Is that desirable, or should we maintain a singleton plugin_names_seen (or similar) that persists over cycles? In this case you'd change the error message too, of course, so it's not time-series-specific but just appears once per plugin.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mechanism of detecting 'first time' is a little entangled hard to
follow. How about maintaining a collection of plugin_names_seen and
testing if plugin_name in plugin_names_seen:?

We warn once per time series, not once per plugin. So rather than a
collection of “plugin names seen”, we maintain a collection of “time
series seen”—but this is just the key set of self._tag_metadata.

Related, the 'first time' gets reset for every _upload_once().

Hmm; I’m not able to reproduce that? We create a _BatchedRequestSender
only once, at create_experiment time, and the “already seen?” source
of truth is on its self._tag_metadata instance attribute, which is
never cleared.

It may help to look at the log output to address both questions. The
output looks like:

directory_loader.py:131] Loading data from path /HOMEDIR/tensorboard_data/mnist/lr_1E-03,conv=1,fc=2/events.out.tfevents.1563406327.HOSTNAME
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'input/image/0') with unsupported plugin name 'images'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'input/image/1') with unsupported plugin name 'images'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'input/image/2') with unsupported plugin name 'images'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'conv/weights') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'conv/biases') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'conv/activations') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'fc1/weights') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'fc1/biases') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'fc1/activations') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'fc1/relu') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'fc2/weights') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'fc2/biases') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=1,fc=2', 'fc2/activations') with unsupported plugin name 'histograms'
uploader.py:548] Trying request of 8154 bytes
uploader.py:555] Upload for 1 runs (8154 bytes) took 0.254 seconds
uploader.py:548] Trying request of 8166 bytes
uploader.py:555] Upload for 1 runs (8166 bytes) took 0.311 seconds
directory_loader.py:131] Loading data from path /HOMEDIR/tensorboard_data/mnist/lr_1E-03,conv=2,fc=2/events.out.tfevents.1563406405.HOSTNAME
uploader.py:345] Skipping time series ('lr_1E-03,conv=2,fc=2', 'input/image/0') with unsupported plugin name 'images'
uploader.py:345] Skipping time series ('lr_1E-03,conv=2,fc=2', 'input/image/1') with unsupported plugin name 'images'
uploader.py:345] Skipping time series ('lr_1E-03,conv=2,fc=2', 'input/image/2') with unsupported plugin name 'images'
uploader.py:345] Skipping time series ('lr_1E-03,conv=2,fc=2', 'conv1/weights') with unsupported plugin name 'histograms'
uploader.py:345] Skipping time series ('lr_1E-03,conv=2,fc=2', 'conv1/biases') with unsupported plugin name 'histograms'

Note that each time series (run-tag combination) appears only once, even
though we process the same time series in multiple upload cycles. For
instance, the second upload cycle here doesn’t print any “Skipping…”
messages, and the messages in the third upload cycle are for new time
series (conv=1 vs. conv=2).

if metadata is None:
first_in_time_series = True
metadata = value.metadata
self._tag_metadata[time_series_key] = metadata

plugin_name = metadata.plugin_data.plugin_name
if value.HasField("metadata") and (
value.metadata.plugin_data.plugin_name
!= metadata.plugin_data.plugin_name
plugin_name != value.metadata.plugin_data.plugin_name
):
logger.warning(
"Mismatching plugin names for %s. Expected %s, found %s.",
time_series_key,
metadata.plugin_data.plugin_name,
value.metadata.plugin_data.plugin_name,
)
elif (
metadata.plugin_data.plugin_name == scalar_metadata.PLUGIN_NAME
):
continue
if plugin_name not in self._allowed_plugins:
if first_in_time_series:
logger.info(
"Skipping time series %r with unsupported plugin name %r",
time_series_key,
plugin_name,
)
continue
if plugin_name == scalar_metadata.PLUGIN_NAME:
self._scalar_request_sender.add_event(
run_name, event, value, metadata
)
Expand Down
2 changes: 2 additions & 0 deletions tensorboard/uploader/uploader_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def _run(flags):
except server_info_lib.CommunicationError as e:
_die(str(e))
_handle_server_info(server_info)
logging.info("Received server info: <%r>", server_info)

if not server_info.api_server.endpoint:
logging.error("Server info response: %s", server_info)
Expand Down Expand Up @@ -593,6 +594,7 @@ def execute(self, server_info, channel):
uploader = uploader_lib.TensorBoardUploader(
api_client,
self.logdir,
allowed_plugins=server_info_lib.allowed_plugins(server_info),
name=self.name,
description=self.description,
)
Expand Down
41 changes: 37 additions & 4 deletions tensorboard/uploader/uploader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.histogram import summary_v2 as histogram_v2
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.plugins.scalar import summary_v2 as scalar_v2
from tensorboard.summary import v1 as summary_v1
from tensorboard.util import test_util as tb_test_util
Expand All @@ -68,6 +69,9 @@ def _create_mock_client():
return mock_client


_SCALARS_ONLY = frozenset((scalars_metadata.PLUGIN_NAME,))


# Sentinel for `_create_*` helpers, for arguments for which we want to
# supply a default other than the `None` used by the code under test.
_USE_DEFAULT = object()
Expand All @@ -76,32 +80,44 @@ def _create_mock_client():
def _create_uploader(
    writer_client=_USE_DEFAULT,
    logdir=None,
    allowed_plugins=_USE_DEFAULT,
    rpc_rate_limiter=_USE_DEFAULT,
    name=None,
    description=None,
):
    """Builds a `TensorBoardUploader`, substituting test defaults for
    any argument left as the `_USE_DEFAULT` sentinel."""
    client = (
        _create_mock_client() if writer_client is _USE_DEFAULT else writer_client
    )
    plugins = (
        _SCALARS_ONLY if allowed_plugins is _USE_DEFAULT else allowed_plugins
    )
    limiter = (
        util.RateLimiter(0) if rpc_rate_limiter is _USE_DEFAULT else rpc_rate_limiter
    )
    return uploader_lib.TensorBoardUploader(
        client,
        logdir,
        allowed_plugins=plugins,
        rpc_rate_limiter=limiter,
        name=name,
        description=description,
    )


def _create_request_sender(
    experiment_id=None,
    api=None,
    allowed_plugins=_USE_DEFAULT,
    rpc_rate_limiter=_USE_DEFAULT,
):
    """Builds a `_BatchedRequestSender`, substituting test defaults for
    any argument left as the `_USE_DEFAULT` sentinel."""
    client = _create_mock_client() if api is _USE_DEFAULT else api
    plugins = (
        _SCALARS_ONLY if allowed_plugins is _USE_DEFAULT else allowed_plugins
    )
    limiter = (
        util.RateLimiter(0) if rpc_rate_limiter is _USE_DEFAULT else rpc_rate_limiter
    )
    return uploader_lib._BatchedRequestSender(
        experiment_id=experiment_id,
        api=client,
        allowed_plugins=plugins,
        rpc_rate_limiter=limiter,
    )


Expand Down Expand Up @@ -376,9 +392,15 @@ def test_upload_full_logdir(self):


class BatchedRequestSenderTest(tf.test.TestCase):
def _populate_run_from_events(self, run_proto, events):
def _populate_run_from_events(
self, run_proto, events, allowed_plugins=_USE_DEFAULT
):
mock_client = _create_mock_client()
builder = _create_request_sender(experiment_id="123", api=mock_client)
builder = _create_request_sender(
experiment_id="123",
api=mock_client,
allowed_plugins=allowed_plugins,
)
builder.send_requests({"": events})
requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list]
if requests:
Expand Down Expand Up @@ -471,6 +493,17 @@ def test_skips_non_scalar_events_in_scalar_time_series(self):
tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags}
self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1})

def test_skips_events_from_disallowed_plugins(self):
    """A scalar event is dropped when scalars are not an allowed plugin."""
    event = event_pb2.Event(
        step=1, wall_time=123.456, summary=scalar_v2.scalar_pb("foo", 5.0)
    )
    run_proto = write_service_pb2.WriteScalarRequest.Run()
    self._populate_run_from_events(
        # Wrap the name in a list: `frozenset("not-scalars")` would build
        # a set of single characters, not a set containing the plugin name.
        run_proto, [event], allowed_plugins=frozenset(["not-scalars"])
    )
    # Nothing was allowed through, so the run proto stays empty.
    expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
    self.assertProtoEquals(run_proto, expected_run_proto)

def test_remembers_first_metadata_in_scalar_time_series(self):
scalar_1 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 4.0))
scalar_2 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 3.0))
Expand Down