Skip to content

Commit 9d3d367

Browse files
committed
uploader: add ServerInfo protos and logic (#2878)
Summary: This commit adds an RPC definition by which the uploader can connect to the frontend web server at the start of an upload session. This resolves a number of outstanding issues: - The frontend can tell the uploader which backend server to connect to, rather than requiring a hard-coded endpoint in the uploader. - The frontend can tell the uploader how to generate experiment URLs, rather than requiring the backend server to provide this information (which it can’t, really, in general). - The frontend can check whether the uploader client is recent enough and instruct the end user to update if it’s not. - The frontend can warn the user about transient issues in case the service is down, degraded, under maintenance, etc. An endpoint `https://tensorboard.dev/api/uploader` on the server will provide this information. Test Plan: Unit tests suffice. wchargin-branch: uploader-serverinfo-protos
1 parent 77f2eb7 commit 9d3d367

File tree

7 files changed

+354
-0
lines changed

7 files changed

+354
-0
lines changed

tensorboard/BUILD

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,14 @@ py_library(
373373
visibility = ["//visibility:public"],
374374
)
375375

376+
py_library(
377+
name = "expect_requests_installed",
378+
# This is a dummy rule used as a requests dependency in open-source.
379+
# We expect requests to already be installed on the system, e.g., via
380+
# `pip install requests`.
381+
visibility = ["//visibility:public"],
382+
)
383+
376384
filegroup(
377385
name = "tf_web_library_default_typings",
378386
srcs = [

tensorboard/pip_package/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
'markdown >= 2.6.8',
3333
'numpy >= 1.12.0',
3434
'protobuf >= 3.6.0',
35+
'requests >= 2.22.0, < 3',
3536
'setuptools >= 41.0.0',
3637
'six >= 1.10.0',
3738
'werkzeug >= 0.11.15',

tensorboard/uploader/BUILD

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,29 @@ py_test(
201201
"//tensorboard:test",
202202
],
203203
)
204+
205+
py_library(
206+
name = "server_info",
207+
srcs = ["server_info.py"],
208+
deps = [
209+
"//tensorboard:expect_requests_installed",
210+
"//tensorboard:version",
211+
"//tensorboard/uploader/proto:protos_all_py_pb2",
212+
"@com_google_protobuf//:protobuf_python",
213+
],
214+
)
215+
216+
py_test(
217+
name = "server_info_test",
218+
size = "medium", # local network requests
219+
timeout = "short",
220+
srcs = ["server_info_test.py"],
221+
deps = [
222+
":server_info",
223+
"//tensorboard:expect_futures_installed",
224+
"//tensorboard:test",
225+
"//tensorboard:version",
226+
"//tensorboard/uploader/proto:protos_all_py_pb2",
227+
"@org_pocoo_werkzeug",
228+
],
229+
)

tensorboard/uploader/proto/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ licenses(["notice"]) # Apache 2.0
66

77
exports_files(["LICENSE"])
88

9+
# TODO(@wchargin): Split more granularly.
910
tb_proto_library(
1011
name = "protos_all",
1112
srcs = [
1213
"export_service.proto",
1314
"scalar.proto",
15+
"server_info.proto",
1416
"write_service.proto",
1517
],
1618
has_services = True,
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
syntax = "proto3";
2+
3+
package tensorboard.service;
4+
5+
// Request sent by uploader clients at the start of an upload session. Used to
6+
// determine whether the client is recent enough to communicate with the
7+
// server, and to receive any metadata needed for the upload session.
8+
message ServerInfoRequest {
9+
// Client-side TensorBoard version, per `tensorboard.version.VERSION`.
10+
string version = 1;
11+
}
12+
13+
message ServerInfoResponse {
14+
// Primary bottom-line: is the server compatible with the client, and is
15+
// there anything that the end user should be aware of?
16+
Compatibility compatibility = 1;
17+
// Identifier for a gRPC server providing the `TensorBoardExporterService` and
18+
// `TensorBoardWriterService` services (under the `tensorboard.service` proto
19+
// package).
20+
ApiServer api_server = 2;
21+
// How to generate URLs to experiment pages.
22+
ExperimentUrlFormat url_format = 3;
23+
}
24+
25+
enum CompatibilityVerdict {
26+
VERDICT_UNKNOWN = 0;
27+
// All is well. The client may proceed.
28+
VERDICT_OK = 1;
29+
// The client may proceed, but should heed the accompanying message. This
30+
// may be the case if the user is on a version of TensorBoard that will
31+
// soon be unsupported, or if the server is experiencing transient issues.
32+
VERDICT_WARN = 2;
33+
// The client should cease further communication with the server and abort
34+
// operation after printing the accompanying `details` message.
35+
VERDICT_ERROR = 3;
36+
}
37+
38+
message Compatibility {
39+
CompatibilityVerdict verdict = 1;
40+
// Human-readable message to display. When non-empty, will be displayed in
41+
// all cases, even when the client may proceed.
42+
string details = 2;
43+
}
44+
45+
message ApiServer {
46+
// gRPC server URI: <https://github.com/grpc/grpc/blob/master/doc/naming.md>.
47+
// For example: "api.tensorboard.dev:443".
48+
string endpoint = 1;
49+
}
50+
51+
message ExperimentUrlFormat {
52+
// Template string for experiment URLs. All occurrences of the value of the
53+
// `id_placeholder` field in this template string should be replaced with an
54+
// experiment ID. For example, if `id_placeholder` is "{{EID}}", then
55+
// `template` might be "https://tensorboard.dev/experiment/{{EID}}/".
56+
// Should be absolute.
57+
string template = 1;
58+
// Placeholder string that should be replaced with an actual experiment ID.
59+
// (See docs for `template` field.)
60+
string id_placeholder = 2;
61+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Initial server communication to determine session parameters."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from google.protobuf import message
22+
import requests
23+
24+
from tensorboard import version
25+
from tensorboard.uploader.proto import server_info_pb2
26+
27+
28+
# Request timeout for communicating with remote server.
29+
_REQUEST_TIMEOUT_SECONDS = 10
30+
31+
32+
def _server_info_request():
33+
request = server_info_pb2.ServerInfoRequest()
34+
request.version = version.VERSION
35+
return request
36+
37+
38+
def fetch_server_info(origin):
39+
"""Fetches server info from a remote server.
40+
41+
Args:
42+
origin: The server with which to communicate. Should be a string
43+
like "https://tensorboard.dev", including protocol, host, and (if
44+
needed) port.
45+
46+
Returns:
47+
A `server_info_pb2.ServerInfoResponse` message.
48+
49+
Raises:
50+
CommunicationError: Upon failure to connect to or successfully
51+
communicate with the remote server.
52+
"""
53+
endpoint = "%s/api/uploader" % origin
54+
post_body = _server_info_request().SerializeToString()
55+
try:
56+
response = requests.post(
57+
endpoint, data=post_body, timeout=_REQUEST_TIMEOUT_SECONDS
58+
)
59+
except requests.RequestException as e:
60+
raise CommunicationError("Failed to connect to backend: %s" % e)
61+
if not response.ok:
62+
raise CommunicationError(
63+
"Non-OK status from backend (%d %s): %r"
64+
% (response.status_code, response.reason, response.content)
65+
)
66+
try:
67+
return server_info_pb2.ServerInfoResponse.FromString(response.content)
68+
except message.DecodeError as e:
69+
raise CommunicationError(
70+
"Corrupt response from backend (%s): %r" % (e, response.content)
71+
)
72+
73+
74+
def create_server_info(frontend_origin, api_endpoint):
75+
"""Manually creates server info given a frontend and backend.
76+
77+
Args:
78+
frontend_origin: The origin of the TensorBoard.dev frontend, like
79+
"https://tensorboard.dev" or "http://localhost:8000".
80+
api_endpoint: As to `server_info_pb2.ApiServer.endpoint`.
81+
82+
Returns:
83+
A `server_info_pb2.ServerInfoResponse` message.
84+
"""
85+
result = server_info_pb2.ServerInfoResponse()
86+
result.compatibility.verdict = server_info_pb2.VERDICT_OK
87+
result.api_server.endpoint = api_endpoint
88+
url_format = result.url_format
89+
placeholder = "{{EID}}"
90+
while placeholder in frontend_origin:
91+
placeholder = "{%s}" % placeholder
92+
url_format.template = "%s/experiment/%s/" % (frontend_origin, placeholder)
93+
url_format.id_placeholder = placeholder
94+
return result
95+
96+
97+
class CommunicationError(RuntimeError):
98+
"""Raised upon failure to communicate with the server."""
99+
100+
pass
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for tensorboard.uploader.server_info."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import errno
22+
import os
23+
import socket
24+
from wsgiref import simple_server
25+
26+
from concurrent import futures
27+
from werkzeug import wrappers
28+
29+
from tensorboard import test as tb_test
30+
from tensorboard import version
31+
from tensorboard.uploader import server_info
32+
from tensorboard.uploader.proto import server_info_pb2
33+
34+
35+
class FetchServerInfoTest(tb_test.TestCase):
36+
"""Tests for `fetch_server_info`."""
37+
38+
def _start_server(self, app):
39+
"""Starts a server and returns its origin ("http://localhost:PORT")."""
40+
(_, localhost) = _localhost()
41+
server_class = _make_ipv6_compatible_wsgi_server()
42+
server = simple_server.make_server(localhost, 0, app, server_class)
43+
executor = futures.ThreadPoolExecutor()
44+
future = executor.submit(server.serve_forever, poll_interval=0.01)
45+
46+
def cleanup():
47+
server.shutdown() # stop handling requests
48+
server.server_close() # release port
49+
future.result(timeout=3) # wait for server termination
50+
51+
self.addCleanup(cleanup)
52+
return "http://localhost:%d" % server.server_port
53+
54+
def test_fetches_response(self):
55+
expected_result = server_info_pb2.ServerInfoResponse()
56+
expected_result.compatibility.verdict = server_info_pb2.VERDICT_OK
57+
expected_result.compatibility.details = "all clear"
58+
expected_result.api_server.endpoint = "api.example.com:443"
59+
expected_result.url_format.template = "http://localhost:8080/{{eid}}"
60+
expected_result.url_format.id_placeholder = "{{eid}}"
61+
62+
@wrappers.BaseRequest.application
63+
def app(request):
64+
self.assertEqual(request.method, "POST")
65+
self.assertEqual(request.path, "/api/uploader")
66+
body = request.get_data()
67+
request_pb = server_info_pb2.ServerInfoRequest.FromString(body)
68+
self.assertEqual(request_pb.version, version.VERSION)
69+
return wrappers.BaseResponse(expected_result.SerializeToString())
70+
71+
origin = self._start_server(app)
72+
result = server_info.fetch_server_info(origin)
73+
self.assertEqual(result, expected_result)
74+
75+
def test_econnrefused(self):
76+
(family, localhost) = _localhost()
77+
s = socket.socket(family)
78+
s.bind((localhost, 0))
79+
self.addCleanup(s.close)
80+
port = s.getsockname()[1]
81+
with self.assertRaises(server_info.CommunicationError) as cm:
82+
server_info.fetch_server_info("http://localhost:%d" % port)
83+
msg = str(cm.exception)
84+
self.assertIn("Failed to connect to backend", msg)
85+
if os.name != "nt":
86+
self.assertIn(os.strerror(errno.ECONNREFUSED), msg)
87+
88+
def test_non_ok_response(self):
89+
@wrappers.BaseRequest.application
90+
def app(request):
91+
del request # unused
92+
return wrappers.BaseResponse(b"very sad", status="502 Bad Gateway")
93+
94+
origin = self._start_server(app)
95+
with self.assertRaises(server_info.CommunicationError) as cm:
96+
server_info.fetch_server_info(origin)
97+
msg = str(cm.exception)
98+
self.assertIn("Non-OK status from backend (502 Bad Gateway)", msg)
99+
self.assertIn("very sad", msg)
100+
101+
def test_corrupt_response(self):
102+
@wrappers.BaseRequest.application
103+
def app(request):
104+
del request # unused
105+
return wrappers.BaseResponse(b"an unlikely proto")
106+
107+
origin = self._start_server(app)
108+
with self.assertRaises(server_info.CommunicationError) as cm:
109+
server_info.fetch_server_info(origin)
110+
msg = str(cm.exception)
111+
self.assertIn("Corrupt response from backend", msg)
112+
self.assertIn("an unlikely proto", msg)
113+
114+
115+
class CreateServerInfoTest(tb_test.TestCase):
116+
"""Tests for `create_server_info`."""
117+
118+
def test(self):
119+
frontend = "http://localhost:8080"
120+
backend = "localhost:10000"
121+
result = server_info.create_server_info(frontend, backend)
122+
123+
expected_compatibility = server_info_pb2.Compatibility()
124+
expected_compatibility.verdict = server_info_pb2.VERDICT_OK
125+
expected_compatibility.details = ""
126+
self.assertEqual(result.compatibility, expected_compatibility)
127+
128+
expected_api_server = server_info_pb2.ApiServer()
129+
expected_api_server.endpoint = backend
130+
self.assertEqual(result.api_server, expected_api_server)
131+
132+
url_format = result.url_format
133+
actual_url = url_format.template.replace(url_format.id_placeholder, "123")
134+
expected_url = "http://localhost:8080/experiment/123/"
135+
self.assertEqual(actual_url, expected_url)
136+
137+
138+
def _localhost():
139+
"""Gets family and nodename for a loopback address."""
140+
s = socket
141+
infos = s.getaddrinfo(None, 0, s.AF_UNSPEC, s.SOCK_STREAM, 0, s.AI_ADDRCONFIG)
142+
(family, _, _, _, address) = infos[0]
143+
nodename = address[0]
144+
return (family, nodename)
145+
146+
147+
def _make_ipv6_compatible_wsgi_server():
148+
"""Creates a `WSGIServer` subclass that works on IPv6-only machines."""
149+
address_family = _localhost()[0]
150+
attrs = {"address_family": address_family}
151+
bases = (simple_server.WSGIServer, object) # `object` needed for py2
152+
return type("_Ipv6CompatibleWsgiServer", bases, attrs)
153+
154+
155+
if __name__ == "__main__":
156+
tb_test.main()

0 commit comments

Comments
 (0)