Commit cd8a471

uploader: include TensorBoard version in RPCs (#2843)
Summary: This will enable servers to detect when clients are using an old version and send instructions to show the user how to upgrade, especially in the case of known bugs. We could implement this with gRPC’s client-side interceptor API, but that API is loudly disclaimed as experimental, so I’ve gone with the manual plumbing for now.

Test Plan: Ran the new uploader with an existing server and verified that it still works (extra metadata is safely ignored). Verified that a server can be modified to read this metadata for all three flows (upload, delete, export). Tested in both Python 2 and Python 3.

wchargin-branch: uploader-version
1 parent 1a5d8c3 commit cd8a471
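
To make the server-side scenario in the summary concrete, here is a minimal sketch (not part of this commit) of how a server handler might read the version metadata and tell an out-of-date client how to upgrade. Only `grpc_util.extract_version` and the metadata plumbing come from this change; the helper name, the blocklist of buggy versions, and the error text are hypothetical.

# Hypothetical server-side check; illustration only, not in this commit.
import grpc

from tensorboard.util import grpc_util

# Assumed blocklist of client versions with known uploader bugs.
_VERSIONS_WITH_KNOWN_BUGS = frozenset(["2.0.0a0"])


def check_client_version(context):
  """Aborts the RPC with upgrade instructions if the client is known-bad."""
  client_version = grpc_util.extract_version(context.invocation_metadata())
  if client_version in _VERSIONS_WITH_KNOWN_BUGS:
    context.abort(
        grpc.StatusCode.FAILED_PRECONDITION,
        "Your TensorBoard version (%s) has a known uploader bug; please run "
        "`pip install -U tensorboard` and try again." % client_version)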

File tree: 6 files changed (+91, -12 lines)

tensorboard/uploader/BUILD

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@ py_library(
     deps = [
         ":util",
         "//tensorboard/uploader/proto:protos_all_py_pb2",
+        "//tensorboard/util:grpc_util",
     ],
 )
 
@@ -32,6 +33,7 @@ py_test(
         "//tensorboard/compat/proto:protos_all_py_pb2",
         "//tensorboard/uploader/proto:protos_all_py_pb2",
         "//tensorboard/uploader/proto:protos_all_py_pb2_grpc",
+        "//tensorboard/util:grpc_util",
         "@org_pythonhosted_mock",
     ],
 )

tensorboard/uploader/exporter.py

Lines changed: 6 additions & 2 deletions
@@ -27,6 +27,7 @@
 
 from tensorboard.uploader.proto import export_service_pb2
 from tensorboard.uploader import util
+from tensorboard.util import grpc_util
 
 # Characters that are assumed to be safe in filenames. Note that the
 # server's experiment IDs are base64 encodings of 16-byte blobs, so they
@@ -120,7 +121,8 @@ def _request_experiment_ids(self, read_time):
     """Yields all of the calling user's experiment IDs, as strings."""
     request = export_service_pb2.StreamExperimentsRequest(limit=_MAX_INT64)
     util.set_timestamp(request.read_timestamp, read_time)
-    stream = self._api.StreamExperiments(request)
+    stream = self._api.StreamExperiments(
+        request, metadata=grpc_util.version_metadata())
     for response in stream:
       for experiment_id in response.experiment_ids:
         yield experiment_id
@@ -136,7 +138,9 @@ def _request_scalar_data(self, experiment_id, read_time):
     # IDs). Any non-transient errors would be internal, and we have no
     # way to efficiently resume from transient errors because the server
     # does not support pagination.
-    for response in self._api.StreamExperimentData(request):
+    stream = self._api.StreamExperimentData(
+        request, metadata=grpc_util.version_metadata())
+    for response in stream:
       metadata = base64.b64encode(
           response.tag_metadata.SerializeToString()).decode("ascii")
       wall_times = [t.ToNanoseconds() / 1e9 for t in response.points.wall_times]

tensorboard/uploader/exporter_test.py

Lines changed: 11 additions & 8 deletions
@@ -36,6 +36,7 @@
 from tensorboard.uploader.proto import export_service_pb2_grpc
 from tensorboard.uploader import exporter as exporter_lib
 from tensorboard.uploader import test_util
+from tensorboard.util import grpc_util
 from tensorboard import test as tb_test
 from tensorboard.compat.proto import summary_pb2
 
@@ -62,13 +63,15 @@ def test_e2e_success_case(self):
         export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"]),
     ])
 
-    def stream_experiments(request):
+    def stream_experiments(request, **kwargs):
       del request  # unused
+      self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
       yield export_service_pb2.StreamExperimentsResponse(
           experiment_ids=["123", "456"])
       yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"])
 
-    def stream_experiment_data(request):
+    def stream_experiment_data(request, **kwargs):
+      self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
       for run in ("train", "test"):
         for tag in ("accuracy", "loss"):
           response = export_service_pb2.StreamExperimentDataResponse()
@@ -110,13 +113,13 @@ def stream_experiment_data(request):
     expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
     expected_eids_request.limit = 2**63 - 1
     mock_api_client.StreamExperiments.assert_called_once_with(
-        expected_eids_request)
+        expected_eids_request, metadata=grpc_util.version_metadata())
 
     expected_data_request = export_service_pb2.StreamExperimentDataRequest()
     expected_data_request.experiment_id = "123"
     expected_data_request.read_timestamp.CopyFrom(start_time_pb)
     mock_api_client.StreamExperimentData.assert_called_once_with(
-        expected_data_request)
+        expected_data_request, metadata=grpc_util.version_metadata())
 
     # The next iteration should just request data for the next experiment.
     mock_api_client.StreamExperiments.reset_mock()
@@ -128,7 +131,7 @@ def stream_experiment_data(request):
     mock_api_client.StreamExperiments.assert_not_called()
     expected_data_request.experiment_id = "456"
     mock_api_client.StreamExperimentData.assert_called_once_with(
-        expected_data_request)
+        expected_data_request, metadata=grpc_util.version_metadata())
 
     # Again, request data for the next experiment; this experiment ID
     # was in the second response batch in the list of IDs.
@@ -141,7 +144,7 @@ def stream_experiment_data(request):
     mock_api_client.StreamExperiments.assert_not_called()
    expected_data_request.experiment_id = "789"
     mock_api_client.StreamExperimentData.assert_called_once_with(
-        expected_data_request)
+        expected_data_request, metadata=grpc_util.version_metadata())
 
     # The final continuation shouldn't need to send any RPCs.
     mock_api_client.StreamExperiments.reset_mock()
@@ -176,7 +179,7 @@ def stream_experiment_data(request):
   def test_rejects_dangerous_experiment_ids(self):
     mock_api_client = self._create_mock_api_client()
 
-    def stream_experiments(request):
+    def stream_experiments(request, **kwargs):
       del request  # unused
       yield export_service_pb2.StreamExperimentsResponse(
           experiment_ids=["../authorized_keys"])
@@ -212,7 +215,7 @@ def test_rejects_existing_directory(self):
   def test_rejects_existing_file(self):
     mock_api_client = self._create_mock_api_client()
 
-    def stream_experiments(request):
+    def stream_experiments(request, **kwargs):
       del request  # unused
       yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["123"])

tensorboard/util/BUILD

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorboard:expect_grpc_installed",
+        "//tensorboard:version",
         "//tensorboard/util:tb_logging",
     ],
 )
@@ -70,7 +71,9 @@ py_test(
         "//tensorboard:expect_futures_installed",
         "//tensorboard:expect_grpc_installed",
         "//tensorboard:test",
+        "//tensorboard:version",
         "@org_pythonhosted_mock",
+        "@org_pythonhosted_six",
     ],
 )

tensorboard/util/grpc_util.py

Lines changed: 33 additions & 1 deletion
@@ -23,6 +23,7 @@
 
 import grpc
 
+from tensorboard import version
 from tensorboard.util import tb_logging
 
 logger = tb_logging.get_logger()
@@ -46,6 +47,9 @@
     grpc.StatusCode.UNAVAILABLE,
 ])
 
+# gRPC metadata key whose value contains the client version.
+_VERSION_METADATA_KEY = "tensorboard-version"
+
 
 def call_with_retries(api_method, request, clock=None):
   """Call a gRPC stub API method, with automatic retry logic.
@@ -78,7 +82,10 @@ def call_with_retries(api_method, request, clock=None):
   while True:
     num_attempts += 1
     try:
-      return api_method(request, timeout=_GRPC_DEFAULT_TIMEOUT_SECS)
+      return api_method(
+          request,
+          timeout=_GRPC_DEFAULT_TIMEOUT_SECS,
+          metadata=version_metadata())
     except grpc.RpcError as e:
      logger.info("RPC call %s got error %s", rpc_name, e)
       if e.code() not in _GRPC_RETRYABLE_STATUS_CODES:
@@ -92,3 +99,28 @@ def call_with_retries(api_method, request, clock=None):
         "RPC call %s attempted %d times, retrying in %.1f seconds",
         rpc_name, num_attempts, backoff_secs)
     clock.sleep(backoff_secs)
+
+
+def version_metadata():
+  """Creates gRPC invocation metadata encoding the TensorBoard version.
+
+  Usage: `stub.MyRpc(request, metadata=version_metadata())`.
+
+  Returns:
+    A tuple of key-value pairs (themselves 2-tuples) to be passed as the
+    `metadata` kwarg to gRPC stub API methods.
+  """
+  return ((_VERSION_METADATA_KEY, version.VERSION),)
+
+
+def extract_version(metadata):
+  """Extracts version from invocation metadata.
+
+  The argument should be the result of a prior call to `version_metadata()`,
+  or the result of combining such a result with other metadata.
+
+  Returns:
+    The TensorBoard version listed in this metadata, or `None` if none
+    is listed.
+  """
+  return dict(metadata).get(_VERSION_METADATA_KEY)
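
For reference, a quick round trip of the two new helpers (illustration only; the exact version string depends on the installed TensorBoard):

from tensorboard.util import grpc_util

metadata = grpc_util.version_metadata()
# e.g. (("tensorboard-version", "2.0.0"),), depending on the installed version
print(grpc_util.extract_version(metadata))
# In a stub call this is passed as: stub.MyRpc(request, metadata=metadata)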

tensorboard/util/grpc_util_test.py

Lines changed: 36 additions & 1 deletion
@@ -19,16 +19,19 @@
 from __future__ import print_function
 
 import contextlib
+import hashlib
 import threading
 
 from concurrent import futures
 import grpc
+import six
 
 from tensorboard.util import grpc_util
 from tensorboard.util import grpc_util_test_pb2
 from tensorboard.util import grpc_util_test_pb2_grpc
 from tensorboard.util import test_util
 from tensorboard import test as tb_test
+from tensorboard import version
 
 
 def make_request(nonce):
@@ -69,7 +72,7 @@ def launch_server():
     thread.join()
 
 
-class GrpcUtilTest(tb_test.TestCase):
+class CallWithRetriesTest(tb_test.TestCase):
 
   def test_call_with_retries_succeeds(self):
     def handler(request, _):
@@ -124,6 +127,38 @@ def handler(request, context):
     self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4)
     self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8)
 
+  def test_call_with_retries_includes_version_metadata(self):
+    def digest(s):
+      """Hashes a string into a 32-bit integer."""
+      return int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16) & 0xffffffff
+    def handler(request, context):
+      metadata = context.invocation_metadata()
+      client_version = grpc_util.extract_version(metadata)
+      return make_response(digest(client_version))
+    server = TestGrpcServer(handler)
+    with server.run() as client:
+      response = grpc_util.call_with_retries(client.TestRpc, make_request(0))
+    expected_nonce = digest(
+        grpc_util.extract_version(grpc_util.version_metadata()))
+    self.assertEqual(make_response(expected_nonce), response)
+
+
+class VersionMetadataTest(tb_test.TestCase):
+
+  def test_structure(self):
+    result = grpc_util.version_metadata()
+    self.assertIsInstance(result, tuple)
+    for kv in result:
+      self.assertIsInstance(kv, tuple)
+      self.assertLen(kv, 2)
+      (k, v) = kv
+      self.assertIsInstance(k, str)
+      self.assertIsInstance(v, six.string_types)
+
+  def test_roundtrip(self):
+    result = grpc_util.extract_version(grpc_util.version_metadata())
+    self.assertEqual(result, version.VERSION)
+
 
 if __name__ == "__main__":
   tb_test.main()
