diff --git a/tensorboard/uploader/BUILD b/tensorboard/uploader/BUILD index f845a0e0cb..e3ac1fe762 100644 --- a/tensorboard/uploader/BUILD +++ b/tensorboard/uploader/BUILD @@ -103,6 +103,7 @@ py_test( "//tensorboard:expect_tensorflow_installed", "//tensorboard/compat/proto:protos_all_py_pb2", "//tensorboard/plugins/histogram:summary_v2", + "//tensorboard/plugins/scalar:metadata", "//tensorboard/plugins/scalar:summary_v2", "//tensorboard/summary:summary_v1", "//tensorboard/uploader/proto:protos_all_py_pb2", @@ -193,6 +194,7 @@ py_library( deps = [ "//tensorboard:expect_requests_installed", "//tensorboard:version", + "//tensorboard/plugins/scalar:metadata", "//tensorboard/uploader/proto:protos_all_py_pb2", "@com_google_protobuf//:protobuf_python", ], @@ -208,6 +210,7 @@ py_test( "//tensorboard:expect_futures_installed", "//tensorboard:test", "//tensorboard:version", + "//tensorboard/plugins/scalar:metadata", "//tensorboard/uploader/proto:protos_all_py_pb2", "@org_pocoo_werkzeug", ], diff --git a/tensorboard/uploader/server_info.py b/tensorboard/uploader/server_info.py index e381f38bb8..f031830dcc 100644 --- a/tensorboard/uploader/server_info.py +++ b/tensorboard/uploader/server_info.py @@ -22,6 +22,7 @@ import requests from tensorboard import version +from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.uploader.proto import server_info_pb2 @@ -111,6 +112,28 @@ def experiment_url(server_info, experiment_id): return url_format.template.replace(url_format.id_placeholder, experiment_id) +def allowed_plugins(server_info): + """Determines which plugins may upload data. + + This pulls from the `plugin_control` on the `server_info` when that + submessage is set, else falls back to a default. + + Args: + server_info: A `server_info_pb2.ServerInfoResponse` message. + + Returns: + A `frozenset` of plugin names. 
+ """ + if server_info.HasField("plugin_control"): + return frozenset(server_info.plugin_control.allowed_plugins) + else: + # Old server: gracefully degrade to scalars only, which have + # been supported since launch. TODO(@wchargin): Promote this + # branch to an error once we're confident that we won't roll + # back to old server versions. + return frozenset((scalars_metadata.PLUGIN_NAME,)) + + class CommunicationError(RuntimeError): """Raised upon failure to communicate with the server.""" diff --git a/tensorboard/uploader/server_info_test.py b/tensorboard/uploader/server_info_test.py index d891a150ff..129bf36b2a 100644 --- a/tensorboard/uploader/server_info_test.py +++ b/tensorboard/uploader/server_info_test.py @@ -28,6 +28,7 @@ from tensorboard import test as tb_test from tensorboard import version +from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.uploader import server_info from tensorboard.uploader.proto import server_info_pb2 @@ -163,6 +164,35 @@ def test(self): self.assertEqual(actual, "https://unittest.tensorboard.dev/x/123") +class AllowedPluginsTest(tb_test.TestCase): + """Tests for `allowed_plugins`.""" + + def test_old_server_no_plugins(self): + info = server_info_pb2.ServerInfoResponse() + actual = server_info.allowed_plugins(info) + self.assertEqual(actual, frozenset([scalars_metadata.PLUGIN_NAME])) + + def test_provided_but_no_plugins(self): + info = server_info_pb2.ServerInfoResponse() + info.plugin_control.SetInParent() + actual = server_info.allowed_plugins(info) + self.assertEqual(actual, frozenset([])) + + def test_scalars_only(self): + info = server_info_pb2.ServerInfoResponse() + info.plugin_control.allowed_plugins.append(scalars_metadata.PLUGIN_NAME) + actual = server_info.allowed_plugins(info) + self.assertEqual(actual, frozenset([scalars_metadata.PLUGIN_NAME])) + + def test_more_plugins(self): + info = server_info_pb2.ServerInfoResponse() + info.plugin_control.allowed_plugins.append("foo") + 
info.plugin_control.allowed_plugins.append("bar") + info.plugin_control.allowed_plugins.append("foo") + actual = server_info.allowed_plugins(info) + self.assertEqual(actual, frozenset(["foo", "bar"])) + + def _localhost(): """Gets family and nodename for a loopback address.""" s = socket diff --git a/tensorboard/uploader/uploader.py b/tensorboard/uploader/uploader.py index 05c1344390..2f50ec2410 100644 --- a/tensorboard/uploader/uploader.py +++ b/tensorboard/uploader/uploader.py @@ -69,6 +69,7 @@ def __init__( self, writer_client, logdir, + allowed_plugins, rpc_rate_limiter=None, name=None, description=None, @@ -78,6 +79,9 @@ def __init__( Args: writer_client: a TensorBoardWriterService stub instance logdir: path of the log directory to upload + allowed_plugins: collection of string plugin names; events will only + be uploaded if their time series's metadata specifies one of these + plugin names rpc_rate_limiter: a `RateLimiter` to use to limit write RPC frequency. Note this limit applies at the level of single RPCs in the Scalar and Tensor case, but at the level of an entire blob upload in the @@ -90,6 +94,7 @@ def __init__( """ self._api = writer_client self._logdir = logdir + self._allowed_plugins = frozenset(allowed_plugins) self._name = name self._description = description self._request_sender = None @@ -122,7 +127,10 @@ def create_experiment(self): self._api.CreateExperiment, request ) self._request_sender = _BatchedRequestSender( - response.experiment_id, self._api, self._rpc_rate_limiter + response.experiment_id, + self._api, + allowed_plugins=self._allowed_plugins, + rpc_rate_limiter=self._rpc_rate_limiter, ) return response.experiment_id @@ -261,12 +269,13 @@ class _BatchedRequestSender(object): calling its methods concurrently. 
""" - def __init__(self, experiment_id, api, rpc_rate_limiter): + def __init__(self, experiment_id, api, allowed_plugins, rpc_rate_limiter): # Map from `(run_name, tag_name)` to `SummaryMetadata` if the time # series is a scalar time series, else to `_NON_SCALAR_TIME_SERIES`. self._tag_metadata = {} + self._allowed_plugins = frozenset(allowed_plugins) self._scalar_request_sender = _ScalarBatchedRequestSender( - experiment_id, api, rpc_rate_limiter + experiment_id, api, rpc_rate_limiter, ) # TODO(nielsene): add tensor case here @@ -295,13 +304,15 @@ def send_requests(self, run_to_events): # If later events arrive with a mismatching plugin_name, they are # ignored with a warning. metadata = self._tag_metadata.get(time_series_key) + first_in_time_series = False if metadata is None: + first_in_time_series = True metadata = value.metadata self._tag_metadata[time_series_key] = metadata + plugin_name = metadata.plugin_data.plugin_name if value.HasField("metadata") and ( - value.metadata.plugin_data.plugin_name - != metadata.plugin_data.plugin_name + plugin_name != value.metadata.plugin_data.plugin_name ): logger.warning( "Mismatching plugin names for %s. 
Expected %s, found %s.", @@ -309,9 +320,16 @@ def send_requests(self, run_to_events): metadata.plugin_data.plugin_name, value.metadata.plugin_data.plugin_name, ) - elif ( - metadata.plugin_data.plugin_name == scalar_metadata.PLUGIN_NAME - ): + continue + if plugin_name not in self._allowed_plugins: + if first_in_time_series: + logger.info( + "Skipping time series %r with unsupported plugin name %r", + time_series_key, + plugin_name, + ) + continue + if plugin_name == scalar_metadata.PLUGIN_NAME: self._scalar_request_sender.add_event( run_name, event, value, metadata ) diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py index 2693691dda..5d157f1e76 100644 --- a/tensorboard/uploader/uploader_main.py +++ b/tensorboard/uploader/uploader_main.py @@ -303,6 +303,7 @@ def _run(flags): except server_info_lib.CommunicationError as e: _die(str(e)) _handle_server_info(server_info) + logging.info("Received server info: <%r>", server_info) if not server_info.api_server.endpoint: logging.error("Server info response: %s", server_info) @@ -593,6 +594,7 @@ def execute(self, server_info, channel): uploader = uploader_lib.TensorBoardUploader( api_client, self.logdir, + allowed_plugins=server_info_lib.allowed_plugins(server_info), name=self.name, description=self.description, ) diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py index 3b926736cf..760b7d8269 100644 --- a/tensorboard/uploader/uploader_test.py +++ b/tensorboard/uploader/uploader_test.py @@ -43,6 +43,7 @@ from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto import summary_pb2 from tensorboard.plugins.histogram import summary_v2 as histogram_v2 +from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.plugins.scalar import summary_v2 as scalar_v2 from tensorboard.summary import v1 as summary_v1 from tensorboard.util import test_util as tb_test_util @@ -68,6 +69,9 @@ def _create_mock_client(): 
return mock_client +_SCALARS_ONLY = frozenset((scalars_metadata.PLUGIN_NAME,)) + + # Sentinel for `_create_*` helpers, for arguments for which we want to # supply a default other than the `None` used by the code under test. _USE_DEFAULT = object() @@ -76,17 +80,21 @@ def _create_mock_client(): def _create_uploader( writer_client=_USE_DEFAULT, logdir=None, + allowed_plugins=_USE_DEFAULT, rpc_rate_limiter=_USE_DEFAULT, name=None, description=None, ): if writer_client is _USE_DEFAULT: writer_client = _create_mock_client() + if allowed_plugins is _USE_DEFAULT: + allowed_plugins = _SCALARS_ONLY if rpc_rate_limiter is _USE_DEFAULT: rpc_rate_limiter = util.RateLimiter(0) return uploader_lib.TensorBoardUploader( writer_client, logdir, + allowed_plugins=allowed_plugins, rpc_rate_limiter=rpc_rate_limiter, name=name, description=description, @@ -94,14 +102,22 @@ def _create_uploader( def _create_request_sender( - experiment_id=None, api=None, rpc_rate_limiter=_USE_DEFAULT + experiment_id=None, + api=None, + allowed_plugins=_USE_DEFAULT, + rpc_rate_limiter=_USE_DEFAULT, ): if api is _USE_DEFAULT: api = _create_mock_client() + if allowed_plugins is _USE_DEFAULT: + allowed_plugins = _SCALARS_ONLY if rpc_rate_limiter is _USE_DEFAULT: rpc_rate_limiter = util.RateLimiter(0) return uploader_lib._BatchedRequestSender( - experiment_id=experiment_id, api=api, rpc_rate_limiter=rpc_rate_limiter + experiment_id=experiment_id, + api=api, + allowed_plugins=allowed_plugins, + rpc_rate_limiter=rpc_rate_limiter, ) @@ -376,9 +392,15 @@ def test_upload_full_logdir(self): class BatchedRequestSenderTest(tf.test.TestCase): - def _populate_run_from_events(self, run_proto, events): + def _populate_run_from_events( + self, run_proto, events, allowed_plugins=_USE_DEFAULT + ): mock_client = _create_mock_client() - builder = _create_request_sender(experiment_id="123", api=mock_client) + builder = _create_request_sender( + experiment_id="123", + api=mock_client, + allowed_plugins=allowed_plugins, + ) 
builder.send_requests({"": events}) requests = [c[0][0] for c in mock_client.WriteScalar.call_args_list] if requests: @@ -471,6 +493,17 @@ def test_skips_non_scalar_events_in_scalar_time_series(self): tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1}) + def test_skips_events_from_disallowed_plugins(self): + event = event_pb2.Event( + step=1, wall_time=123.456, summary=scalar_v2.scalar_pb("foo", 5.0) + ) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events( + run_proto, [event], allowed_plugins=frozenset(["not-scalars"]) + ) + expected_run_proto = write_service_pb2.WriteScalarRequest.Run() + self.assertProtoEquals(run_proto, expected_run_proto) + def test_remembers_first_metadata_in_scalar_time_series(self): scalar_1 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 4.0)) scalar_2 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 3.0))