diff --git a/tensorboard/BUILD b/tensorboard/BUILD index 5161a82518..c3cbfe3bc7 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -399,6 +399,14 @@ py_library( visibility = ["//visibility:public"], ) +py_library( + name = "expect_requests_installed", + # This is a dummy rule used as a requests dependency in open-source. + # We expect requests to already be installed on the system, e.g., via + # `pip install requests`. + visibility = ["//visibility:public"], +) + filegroup( name = "tf_web_library_default_typings", srcs = [ diff --git a/tensorboard/pip_package/setup.py b/tensorboard/pip_package/setup.py index b7b883b8c9..7e6b56375c 100644 --- a/tensorboard/pip_package/setup.py +++ b/tensorboard/pip_package/setup.py @@ -32,6 +32,7 @@ 'markdown >= 2.6.8', 'numpy >= 1.12.0', 'protobuf >= 3.6.0', + 'requests >= 2.22.0, < 3', 'setuptools >= 41.0.0', 'six >= 1.10.0', 'werkzeug >= 0.11.15', diff --git a/tensorboard/uploader/BUILD b/tensorboard/uploader/BUILD index a0b25720ba..babe61f2d6 100644 --- a/tensorboard/uploader/BUILD +++ b/tensorboard/uploader/BUILD @@ -201,3 +201,29 @@ py_test( "//tensorboard:test", ], ) + +py_library( + name = "server_info", + srcs = ["server_info.py"], + deps = [ + "//tensorboard:expect_requests_installed", + "//tensorboard:version", + "//tensorboard/uploader/proto:protos_all_py_pb2", + "@com_google_protobuf//:protobuf_python", + ], +) + +py_test( + name = "server_info_test", + size = "medium", # local network requests + timeout = "short", + srcs = ["server_info_test.py"], + deps = [ + ":server_info", + "//tensorboard:expect_futures_installed", + "//tensorboard:test", + "//tensorboard:version", + "//tensorboard/uploader/proto:protos_all_py_pb2", + "@org_pocoo_werkzeug", + ], +) diff --git a/tensorboard/uploader/proto/BUILD b/tensorboard/uploader/proto/BUILD index 9353b897e6..bb702af388 100644 --- a/tensorboard/uploader/proto/BUILD +++ b/tensorboard/uploader/proto/BUILD @@ -6,11 +6,13 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +# TODO(@wchargin): Split more granularly. tb_proto_library( name = "protos_all", srcs = [ "export_service.proto", "scalar.proto", + "server_info.proto", "write_service.proto", ], has_services = True, diff --git a/tensorboard/uploader/proto/server_info.proto b/tensorboard/uploader/proto/server_info.proto new file mode 100644 index 0000000000..ba2a592f24 --- /dev/null +++ b/tensorboard/uploader/proto/server_info.proto @@ -0,0 +1,61 @@ +syntax = "proto3"; + +package tensorboard.service; + +// Request sent by uploader clients at the start of an upload session. Used to +// determine whether the client is recent enough to communicate with the +// server, and to receive any metadata needed for the upload session. +message ServerInfoRequest { + // Client-side TensorBoard version, per `tensorboard.version.VERSION`. + string version = 1; +} + +message ServerInfoResponse { + // Primary bottom-line: is the server compatible with the client, and is + // there anything that the end user should be aware of? + Compatibility compatibility = 1; + // Identifier for a gRPC server providing the `TensorBoardExporterService` and + // `TensorBoardWriterService` services (under the `tensorboard.service` proto + // package). + ApiServer api_server = 2; + // How to generate URLs to experiment pages. + ExperimentUrlFormat url_format = 3; +} + +enum CompatibilityVerdict { + VERDICT_UNKNOWN = 0; + // All is well. The client may proceed. + VERDICT_OK = 1; + // The client may proceed, but should heed the accompanying message. This + // may be the case if the user is on a version of TensorBoard that will + // soon be unsupported, or if the server is experiencing transient issues. + VERDICT_WARN = 2; + // The client should cease further communication with the server and abort + // operation after printing the accompanying `details` message. + VERDICT_ERROR = 3; +} + +message Compatibility { + CompatibilityVerdict verdict = 1; + // Human-readable message to display. When non-empty, will be displayed in + // all cases, even when the client may proceed. + string details = 2; +} + +message ApiServer { + // gRPC server URI: . + // For example: "api.tensorboard.dev:443". + string endpoint = 1; +} + +message ExperimentUrlFormat { + // Template string for experiment URLs. All occurrences of the value of the + // `id_placeholder` field in this template string should be replaced with an + // experiment ID. For example, if `id_placeholder` is "{{EID}}", then + // `template` might be "https://tensorboard.dev/experiment/{{EID}}/". + // Should be absolute. + string template = 1; + // Placeholder string that should be replaced with an actual experiment ID. + // (See docs for `template` field.) + string id_placeholder = 2; +} diff --git a/tensorboard/uploader/server_info.py b/tensorboard/uploader/server_info.py new file mode 100644 index 0000000000..7ea7f75d51 --- /dev/null +++ b/tensorboard/uploader/server_info.py @@ -0,0 +1,100 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Initial server communication to determine session parameters.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from google.protobuf import message +import requests + +from tensorboard import version +from tensorboard.uploader.proto import server_info_pb2 + + +# Request timeout for communicating with remote server. +_REQUEST_TIMEOUT_SECONDS = 10 + + +def _server_info_request(): + request = server_info_pb2.ServerInfoRequest() + request.version = version.VERSION + return request + + +def fetch_server_info(origin): + """Fetches server info from a remote server. + + Args: + origin: The server with which to communicate. Should be a string + like "https://tensorboard.dev", including protocol, host, and (if + needed) port. + + Returns: + A `server_info_pb2.ServerInfoResponse` message. + + Raises: + CommunicationError: Upon failure to connect to or successfully + communicate with the remote server. + """ + endpoint = "%s/api/uploader" % origin + post_body = _server_info_request().SerializeToString() + try: + response = requests.post( + endpoint, data=post_body, timeout=_REQUEST_TIMEOUT_SECONDS + ) + except requests.RequestException as e: + raise CommunicationError("Failed to connect to backend: %s" % e) + if not response.ok: + raise CommunicationError( + "Non-OK status from backend (%d %s): %r" + % (response.status_code, response.reason, response.content) + ) + try: + return server_info_pb2.ServerInfoResponse.FromString(response.content) + except message.DecodeError as e: + raise CommunicationError( + "Corrupt response from backend (%s): %r" % (e, response.content) + ) + + +def create_server_info(frontend_origin, api_endpoint): + """Manually creates server info given a frontend and backend. + + Args: + frontend_origin: The origin of the TensorBoard.dev frontend, like + "https://tensorboard.dev" or "http://localhost:8000". + api_endpoint: As to `server_info_pb2.ApiServer.endpoint`. + + Returns: + A `server_info_pb2.ServerInfoResponse` message. + """ + result = server_info_pb2.ServerInfoResponse() + result.compatibility.verdict = server_info_pb2.VERDICT_OK + result.api_server.endpoint = api_endpoint + url_format = result.url_format + placeholder = "{{EID}}" + while placeholder in frontend_origin: + placeholder = "{%s}" % placeholder + url_format.template = "%s/experiment/%s/" % (frontend_origin, placeholder) + url_format.id_placeholder = placeholder + return result + + +class CommunicationError(RuntimeError): + """Raised upon failure to communicate with the server.""" + + pass diff --git a/tensorboard/uploader/server_info_test.py b/tensorboard/uploader/server_info_test.py new file mode 100644 index 0000000000..8ce2537d29 --- /dev/null +++ b/tensorboard/uploader/server_info_test.py @@ -0,0 +1,156 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorboard.uploader.server_info.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import errno +import os +import socket +from wsgiref import simple_server + +from concurrent import futures +from werkzeug import wrappers + +from tensorboard import test as tb_test +from tensorboard import version +from tensorboard.uploader import server_info +from tensorboard.uploader.proto import server_info_pb2 + + +class FetchServerInfoTest(tb_test.TestCase): + """Tests for `fetch_server_info`.""" + + def _start_server(self, app): + """Starts a server and returns its origin ("http://localhost:PORT").""" + (_, localhost) = _localhost() + server_class = _make_ipv6_compatible_wsgi_server() + server = simple_server.make_server(localhost, 0, app, server_class) + executor = futures.ThreadPoolExecutor() + future = executor.submit(server.serve_forever, poll_interval=0.01) + + def cleanup(): + server.shutdown() # stop handling requests + server.server_close() # release port + future.result(timeout=3) # wait for server termination + + self.addCleanup(cleanup) + return "http://localhost:%d" % server.server_port + + def test_fetches_response(self): + expected_result = server_info_pb2.ServerInfoResponse() + expected_result.compatibility.verdict = server_info_pb2.VERDICT_OK + expected_result.compatibility.details = "all clear" + expected_result.api_server.endpoint = "api.example.com:443" + expected_result.url_format.template = "http://localhost:8080/{{eid}}" + expected_result.url_format.id_placeholder = "{{eid}}" + + @wrappers.BaseRequest.application + def app(request): + self.assertEqual(request.method, "POST") + self.assertEqual(request.path, "/api/uploader") + body = request.get_data() + request_pb = server_info_pb2.ServerInfoRequest.FromString(body) + self.assertEqual(request_pb.version, version.VERSION) + return wrappers.BaseResponse(expected_result.SerializeToString()) + + origin = self._start_server(app) + result = server_info.fetch_server_info(origin) + self.assertEqual(result, expected_result) + + def test_econnrefused(self): + (family, localhost) = _localhost() + s = socket.socket(family) + s.bind((localhost, 0)) + self.addCleanup(s.close) + port = s.getsockname()[1] + with self.assertRaises(server_info.CommunicationError) as cm: + server_info.fetch_server_info("http://localhost:%d" % port) + msg = str(cm.exception) + self.assertIn("Failed to connect to backend", msg) + if os.name != "nt": + self.assertIn(os.strerror(errno.ECONNREFUSED), msg) + + def test_non_ok_response(self): + @wrappers.BaseRequest.application + def app(request): + del request # unused + return wrappers.BaseResponse(b"very sad", status="502 Bad Gateway") + + origin = self._start_server(app) + with self.assertRaises(server_info.CommunicationError) as cm: + server_info.fetch_server_info(origin) + msg = str(cm.exception) + self.assertIn("Non-OK status from backend (502 Bad Gateway)", msg) + self.assertIn("very sad", msg) + + def test_corrupt_response(self): + @wrappers.BaseRequest.application + def app(request): + del request # unused + return wrappers.BaseResponse(b"an unlikely proto") + + origin = self._start_server(app) + with self.assertRaises(server_info.CommunicationError) as cm: + server_info.fetch_server_info(origin) + msg = str(cm.exception) + self.assertIn("Corrupt response from backend", msg) + self.assertIn("an unlikely proto", msg) + + +class CreateServerInfoTest(tb_test.TestCase): + """Tests for `create_server_info`.""" + + def test(self): + frontend = "http://localhost:8080" + backend = "localhost:10000" + result = server_info.create_server_info(frontend, backend) + + expected_compatibility = server_info_pb2.Compatibility() + expected_compatibility.verdict = server_info_pb2.VERDICT_OK + expected_compatibility.details = "" + self.assertEqual(result.compatibility, expected_compatibility) + + expected_api_server = server_info_pb2.ApiServer() + expected_api_server.endpoint = backend + self.assertEqual(result.api_server, expected_api_server) + + url_format = result.url_format + actual_url = url_format.template.replace(url_format.id_placeholder, "123") + expected_url = "http://localhost:8080/experiment/123/" + self.assertEqual(actual_url, expected_url) + + +def _localhost(): + """Gets family and nodename for a loopback address.""" + s = socket + infos = s.getaddrinfo(None, 0, s.AF_UNSPEC, s.SOCK_STREAM, 0, s.AI_ADDRCONFIG) + (family, _, _, _, address) = infos[0] + nodename = address[0] + return (family, nodename) + + +def _make_ipv6_compatible_wsgi_server(): + """Creates a `WSGIServer` subclass that works on IPv6-only machines.""" + address_family = _localhost()[0] + attrs = {"address_family": address_family} + bases = (simple_server.WSGIServer, object) # `object` needed for py2 + return type("_Ipv6CompatibleWsgiServer", bases, attrs) + + +if __name__ == "__main__": + tb_test.main()