diff --git a/tensorboard/uploader/BUILD b/tensorboard/uploader/BUILD index babe61f2d6..0b17a7e0dc 100644 --- a/tensorboard/uploader/BUILD +++ b/tensorboard/uploader/BUILD @@ -56,6 +56,7 @@ py_library( ":auth", ":dev_creds", ":exporter_lib", + ":server_info", ":uploader_lib", "//tensorboard:expect_absl_app_installed", "//tensorboard:expect_absl_flags_argparse_flags_installed", diff --git a/tensorboard/uploader/server_info.py b/tensorboard/uploader/server_info.py index e36a7e8ac2..5906bb11b3 100644 --- a/tensorboard/uploader/server_info.py +++ b/tensorboard/uploader/server_info.py @@ -97,6 +97,20 @@ def create_server_info(frontend_origin, api_endpoint): return result +def experiment_url(server_info, experiment_id): + """Formats a URL that will resolve to the provided experiment. + + Args: + server_info: A `server_info_pb2.ServerInfoResponse` message. + experiment_id: A string; the ID of the experiment to link to. + + Returns: + A URL resolving to the given experiment, as a string. + """ + url_format = server_info.url_format + return url_format.template.replace(url_format.id_placeholder, experiment_id) + + class CommunicationError(RuntimeError): """Raised upon failure to communicate with the server.""" diff --git a/tensorboard/uploader/server_info_test.py b/tensorboard/uploader/server_info_test.py index 85b615afbd..64a8bdfb00 100644 --- a/tensorboard/uploader/server_info_test.py +++ b/tensorboard/uploader/server_info_test.py @@ -147,6 +147,17 @@ def test(self): self.assertEqual(actual_url, expected_url) +class ExperimentUrlTest(tb_test.TestCase): + """Tests for `experiment_url`.""" + + def test(self): + info = server_info_pb2.ServerInfoResponse() + info.url_format.template = "https://unittest.tensorboard.dev/x/???" + info.url_format.id_placeholder = "???" + actual = server_info.experiment_url(info, "123") + self.assertEqual(actual, "https://unittest.tensorboard.dev/x/123") + + def _localhost(): """Gets family and nodename for a loopback address.""" s = socket diff --git a/tensorboard/uploader/uploader.py b/tensorboard/uploader/uploader.py index 222a23977c..5cb780b1f5 100644 --- a/tensorboard/uploader/uploader.py +++ b/tensorboard/uploader/uploader.py @@ -89,12 +89,12 @@ def __init__(self, writer_client, logdir, rate_limiter=None): self._logdir, directory_loader_factory) def create_experiment(self): - """Creates an Experiment for this upload session and returns the URL.""" + """Creates an Experiment for this upload session and returns the ID.""" logger.info("Creating experiment") request = write_service_pb2.CreateExperimentRequest() response = grpc_util.call_with_retries(self._api.CreateExperiment, request) self._request_builder = _RequestBuilder(response.experiment_id) - return response.url + return response.experiment_id def start_uploading(self): """Blocks forever to continuously upload data from the logdir. diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py index ad7ee1b63d..4f2c36821c 100644 --- a/tensorboard/uploader/uploader_main.py +++ b/tensorboard/uploader/uploader_main.py @@ -35,7 +35,9 @@ from tensorboard.uploader.proto import write_service_pb2_grpc from tensorboard.uploader import auth from tensorboard.uploader import exporter as exporter_lib +from tensorboard.uploader import server_info as server_info_lib from tensorboard.uploader import uploader as uploader_lib +from tensorboard.uploader.proto import server_info_pb2 from tensorboard import program from tensorboard.plugins import base_plugin @@ -65,6 +67,11 @@ _AUTH_SUBCOMMAND_FLAG = '_uploader__subcommand_auth' _AUTH_SUBCOMMAND_KEY_REVOKE = 'REVOKE' +_DEFAULT_ORIGIN = "https://tensorboard.dev" +# Compatibility measure until server-side /api/uploader support is +# rolled out and stable. +_HARDCODED_API_ENDPOINT = "api.tensorboard.dev:443" + def _prompt_for_user_ack(intent): """Prompts for user consent, exiting the program if they decline.""" @@ -91,10 +98,19 @@ def _define_flags(parser): subparsers = parser.add_subparsers() parser.add_argument( - '--endpoint', + '--origin', type=str, - default='api.tensorboard.dev:443', - help='URL for the API server accepting write requests.') + default='', + help='Experimental. Origin for TensorBoard.dev service to which ' + 'to connect. If not set, defaults to %r.' % _DEFAULT_ORIGIN) + + parser.add_argument( + '--api_endpoint', + type=str, + default='', + help='Experimental. Direct URL for the API server accepting ' + 'write requests. If set, will skip initial server handshake ' + 'unless `--origin` is also set.') parser.add_argument( '--grpc_creds_type', @@ -222,15 +238,26 @@ def _run(flags): msg = 'Invalid --grpc_creds_type %s' % flags.grpc_creds_type raise base_plugin.FlagsError(msg) + try: + server_info = _get_server_info(flags) + except server_info_lib.CommunicationError as e: + _die(str(e)) + _handle_server_info(server_info) + + if not server_info.api_server.endpoint: + logging.error('Server info response: %s', server_info) + _die('Internal error: frontend did not specify an API server') composite_channel_creds = grpc.composite_channel_credentials( channel_creds, auth.id_token_call_credentials(credentials)) # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until # logdir exists to open channel. channel = grpc.secure_channel( - flags.endpoint, composite_channel_creds, options=channel_options) + server_info.api_server.endpoint, + composite_channel_creds, + options=channel_options) with channel: - intent.execute(channel) + intent.execute(server_info, channel) @six.add_metaclass(abc.ABCMeta) @@ -254,10 +281,11 @@ def get_ack_message_body(self): pass @abc.abstractmethod - def execute(self, channel): + def execute(self, server_info, channel): """Carries out this intent with the specified gRPC channel. Args: + server_info: A `server_info_pb2.ServerInfoResponse` value. channel: A connected gRPC channel whose server provides the TensorBoard reader and writer services. """ @@ -271,7 +299,7 @@ def get_ack_message_body(self): """Must not be called.""" raise AssertionError('No user ack needed to revoke credentials') - def execute(self, channel): + def execute(self, server_info, channel): """Execute handled specially by `main`. Must not be called.""" raise AssertionError('_AuthRevokeIntent should not be directly executed') @@ -296,7 +324,7 @@ def __init__(self, experiment_id): def get_ack_message_body(self): return self._MESSAGE_TEMPLATE.format(experiment_id=self.experiment_id) - def execute(self, channel): + def execute(self, server_info, channel): api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel) experiment_id = self.experiment_id if not experiment_id: @@ -329,14 +357,13 @@ class _ListIntent(_Intent): def get_ack_message_body(self): return self._MESSAGE - def execute(self, channel): + def execute(self, server_info, channel): api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel) gen = exporter_lib.list_experiments(api_client) count = 0 for experiment_id in gen: count += 1 - # TODO(@wchargin): Once #2879 is in, remove this hard-coded URL pattern. - url = 'https://tensorboard.dev/experiment/%s/' % experiment_id + url = server_info_lib.experiment_url(server_info, experiment_id) print(url) sys.stdout.flush() if not count: @@ -366,10 +393,11 @@ def __init__(self, logdir): def get_ack_message_body(self): return self._MESSAGE_TEMPLATE.format(logdir=self.logdir) - def execute(self, channel): + def execute(self, server_info, channel): api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel) uploader = uploader_lib.TensorBoardUploader(api_client, self.logdir) - url = uploader.create_experiment() + experiment_id = uploader.create_experiment() + url = server_info_lib.experiment_url(server_info, experiment_id) print("Upload started and will continue reading any new data as it's added") print("to the logdir. To stop uploading, press Ctrl-C.") print("View your TensorBoard live at: %s" % url) @@ -407,7 +435,7 @@ def __init__(self, output_dir): def get_ack_message_body(self): return self._MESSAGE_TEMPLATE.format(output_dir=self.output_dir) - def execute(self, channel): + def execute(self, server_info, channel): api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel) outdir = self.output_dir try: @@ -476,6 +504,34 @@ def _get_intent(flags): raise AssertionError('Unknown subcommand %r' % (cmd,)) +def _get_server_info(flags): + origin = flags.origin or _DEFAULT_ORIGIN + if not flags.origin: + # Temporary fallback to hardcoded API endpoint when not specified. + api_endpoint = flags.api_endpoint or _HARDCODED_API_ENDPOINT + return server_info_lib.create_server_info(origin, api_endpoint) + server_info = server_info_lib.fetch_server_info(origin) + # Override with any API server explicitly specified on the command + # line, but only if the server accepted our initial handshake. + if flags.api_endpoint and server_info.api_server.endpoint: + server_info.api_server.endpoint = flags.api_endpoint + return server_info + + +def _handle_server_info(info): + compat = info.compatibility + if compat.verdict == server_info_pb2.VERDICT_WARN: + sys.stderr.write('Warning [from server]: %s\n' % compat.details) + sys.stderr.flush() + elif compat.verdict == server_info_pb2.VERDICT_ERROR: + _die('Error [from server]: %s' % compat.details) + else: + # OK or unknown; assume OK. + if compat.details: + sys.stderr.write('%s\n' % compat.details) + sys.stderr.flush() + + def _die(message): sys.stderr.write('%s\n' % (message,)) sys.stderr.flush() diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py index 427ce45ac8..f1abefb1cc 100644 --- a/tensorboard/uploader/uploader_test.py +++ b/tensorboard/uploader/uploader_test.py @@ -61,7 +61,7 @@ def _create_mock_client(self): stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel) mock_client = mock.create_autospec(stub) fake_exp_response = write_service_pb2.CreateExperimentResponse( - experiment_id="123", url="https://example.com/123") + experiment_id="123", url="should not be used!") mock_client.CreateExperiment.return_value = fake_exp_response return mock_client @@ -69,8 +69,8 @@ def test_create_experiment(self): logdir = "/logs/foo" mock_client = self._create_mock_client() uploader = uploader_lib.TensorBoardUploader(mock_client, logdir) - url = uploader.create_experiment() - self.assertEqual(url, "https://example.com/123") + eid = uploader.create_experiment() + self.assertEqual(eid, "123") def test_start_uploading_without_create_experiment_fails(self): mock_client = self._create_mock_client()