Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorboard/uploader/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ py_library(
deps = [
":util",
"//tensorboard/uploader/proto:protos_all_py_pb2",
"//tensorboard/util:grpc_util",
],
)

Expand All @@ -32,6 +33,7 @@ py_test(
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/uploader/proto:protos_all_py_pb2",
"//tensorboard/uploader/proto:protos_all_py_pb2_grpc",
"//tensorboard/util:grpc_util",
"@org_pythonhosted_mock",
],
)
Expand Down
8 changes: 6 additions & 2 deletions tensorboard/uploader/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from tensorboard.uploader.proto import export_service_pb2
from tensorboard.uploader import util
from tensorboard.util import grpc_util

# Characters that are assumed to be safe in filenames. Note that the
# server's experiment IDs are base64 encodings of 16-byte blobs, so they
Expand Down Expand Up @@ -120,7 +121,8 @@ def _request_experiment_ids(self, read_time):
"""Yields all of the calling user's experiment IDs, as strings."""
request = export_service_pb2.StreamExperimentsRequest(limit=_MAX_INT64)
util.set_timestamp(request.read_timestamp, read_time)
stream = self._api.StreamExperiments(request)
stream = self._api.StreamExperiments(
request, metadata=grpc_util.version_metadata())
for response in stream:
for experiment_id in response.experiment_ids:
yield experiment_id
Expand All @@ -136,7 +138,9 @@ def _request_scalar_data(self, experiment_id, read_time):
# IDs). Any non-transient errors would be internal, and we have no
# way to efficiently resume from transient errors because the server
# does not support pagination.
for response in self._api.StreamExperimentData(request):
stream = self._api.StreamExperimentData(
request, metadata=grpc_util.version_metadata())
for response in stream:
metadata = base64.b64encode(
response.tag_metadata.SerializeToString()).decode("ascii")
wall_times = [t.ToNanoseconds() / 1e9 for t in response.points.wall_times]
Expand Down
19 changes: 11 additions & 8 deletions tensorboard/uploader/exporter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from tensorboard.uploader.proto import export_service_pb2_grpc
from tensorboard.uploader import exporter as exporter_lib
from tensorboard.uploader import test_util
from tensorboard.util import grpc_util
from tensorboard import test as tb_test
from tensorboard.compat.proto import summary_pb2

Expand All @@ -62,13 +63,15 @@ def test_e2e_success_case(self):
export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"]),
])

def stream_experiments(request):
def stream_experiments(request, **kwargs):
del request # unused
self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
yield export_service_pb2.StreamExperimentsResponse(
experiment_ids=["123", "456"])
yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"])

def stream_experiment_data(request):
def stream_experiment_data(request, **kwargs):
self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
for run in ("train", "test"):
for tag in ("accuracy", "loss"):
response = export_service_pb2.StreamExperimentDataResponse()
Expand Down Expand Up @@ -110,13 +113,13 @@ def stream_experiment_data(request):
expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
expected_eids_request.limit = 2**63 - 1
mock_api_client.StreamExperiments.assert_called_once_with(
expected_eids_request)
expected_eids_request, metadata=grpc_util.version_metadata())

expected_data_request = export_service_pb2.StreamExperimentDataRequest()
expected_data_request.experiment_id = "123"
expected_data_request.read_timestamp.CopyFrom(start_time_pb)
mock_api_client.StreamExperimentData.assert_called_once_with(
expected_data_request)
expected_data_request, metadata=grpc_util.version_metadata())

# The next iteration should just request data for the next experiment.
mock_api_client.StreamExperiments.reset_mock()
Expand All @@ -128,7 +131,7 @@ def stream_experiment_data(request):
mock_api_client.StreamExperiments.assert_not_called()
expected_data_request.experiment_id = "456"
mock_api_client.StreamExperimentData.assert_called_once_with(
expected_data_request)
expected_data_request, metadata=grpc_util.version_metadata())

# Again, request data for the next experiment; this experiment ID
# was in the second response batch in the list of IDs.
Expand All @@ -141,7 +144,7 @@ def stream_experiment_data(request):
mock_api_client.StreamExperiments.assert_not_called()
expected_data_request.experiment_id = "789"
mock_api_client.StreamExperimentData.assert_called_once_with(
expected_data_request)
expected_data_request, metadata=grpc_util.version_metadata())

# The final continuation shouldn't need to send any RPCs.
mock_api_client.StreamExperiments.reset_mock()
Expand Down Expand Up @@ -176,7 +179,7 @@ def stream_experiment_data(request):
def test_rejects_dangerous_experiment_ids(self):
mock_api_client = self._create_mock_api_client()

def stream_experiments(request):
def stream_experiments(request, **kwargs):
del request # unused
yield export_service_pb2.StreamExperimentsResponse(
experiment_ids=["../authorized_keys"])
Expand Down Expand Up @@ -212,7 +215,7 @@ def test_rejects_existing_directory(self):
def test_rejects_existing_file(self):
mock_api_client = self._create_mock_api_client()

def stream_experiments(request):
def stream_experiments(request, **kwargs):
del request # unused
yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["123"])

Expand Down
3 changes: 3 additions & 0 deletions tensorboard/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorboard:expect_grpc_installed",
"//tensorboard:version",
"//tensorboard/util:tb_logging",
],
)
Expand All @@ -70,7 +71,9 @@ py_test(
"//tensorboard:expect_futures_installed",
"//tensorboard:expect_grpc_installed",
"//tensorboard:test",
"//tensorboard:version",
"@org_pythonhosted_mock",
"@org_pythonhosted_six",
],
)

Expand Down
34 changes: 33 additions & 1 deletion tensorboard/util/grpc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import grpc

from tensorboard import version
from tensorboard.util import tb_logging

logger = tb_logging.get_logger()
Expand All @@ -46,6 +47,9 @@
grpc.StatusCode.UNAVAILABLE,
])

# gRPC metadata key whose value contains the client version.
_VERSION_METADATA_KEY = "tensorboard-version"


def call_with_retries(api_method, request, clock=None):
"""Call a gRPC stub API method, with automatic retry logic.
Expand Down Expand Up @@ -78,7 +82,10 @@ def call_with_retries(api_method, request, clock=None):
while True:
num_attempts += 1
try:
return api_method(request, timeout=_GRPC_DEFAULT_TIMEOUT_SECS)
return api_method(
request,
timeout=_GRPC_DEFAULT_TIMEOUT_SECS,
metadata=version_metadata())
except grpc.RpcError as e:
logger.info("RPC call %s got error %s", rpc_name, e)
if e.code() not in _GRPC_RETRYABLE_STATUS_CODES:
Expand All @@ -92,3 +99,28 @@ def call_with_retries(api_method, request, clock=None):
"RPC call %s attempted %d times, retrying in %.1f seconds",
rpc_name, num_attempts, backoff_secs)
clock.sleep(backoff_secs)


def version_metadata():
    """Builds gRPC invocation metadata carrying the TensorBoard version.

    Intended to be passed straight to a stub call, e.g.
    `stub.MyRpc(request, metadata=version_metadata())`.

    Returns:
      A one-element tuple whose single entry is a `(key, value)` 2-tuple
      pairing the version metadata key with `version.VERSION`, suitable
      for the `metadata` kwarg of gRPC stub API methods.
    """
    version_entry = (_VERSION_METADATA_KEY, version.VERSION)
    return (version_entry,)


def extract_version(metadata):
    """Extracts the TensorBoard version from gRPC invocation metadata.

    Args:
      metadata: An iterable of key-value pairs, such as the result of a
        prior call to `version_metadata` or the result of combining such
        a result with other metadata.

    Returns:
      The TensorBoard version listed in this metadata, or `None` if none
      is listed. If the version key occurs more than once, the value of
      its last occurrence is returned.
    """
    # Building a dict collapses repeated keys, keeping the last value seen.
    return dict(metadata).get(_VERSION_METADATA_KEY)
37 changes: 36 additions & 1 deletion tensorboard/util/grpc_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@
from __future__ import print_function

import contextlib
import hashlib
import threading

from concurrent import futures
import grpc
import six

from tensorboard.util import grpc_util
from tensorboard.util import grpc_util_test_pb2
from tensorboard.util import grpc_util_test_pb2_grpc
from tensorboard.util import test_util
from tensorboard import test as tb_test
from tensorboard import version


def make_request(nonce):
Expand Down Expand Up @@ -69,7 +72,7 @@ def launch_server():
thread.join()


class GrpcUtilTest(tb_test.TestCase):
class CallWithRetriesTest(tb_test.TestCase):

def test_call_with_retries_succeeds(self):
def handler(request, _):
Expand Down Expand Up @@ -124,6 +127,38 @@ def handler(request, context):
self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4)
self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8)

def test_call_with_retries_includes_version_metadata(self):
    """Checks that `call_with_retries` attaches the version metadata."""

    def digest(s):
        """Hashes a string into a 32-bit integer."""
        sha = hashlib.sha256(s.encode("utf-8"))
        return int(sha.hexdigest(), 16) & 0xffffffff

    def handler(request, context):
        # Reply with a digest of whatever version the client sent, so the
        # client side of the test can verify what the server observed.
        received = grpc_util.extract_version(context.invocation_metadata())
        return make_response(digest(received))

    server = TestGrpcServer(handler)
    with server.run() as client:
        response = grpc_util.call_with_retries(client.TestRpc, make_request(0))
    sent_version = grpc_util.extract_version(grpc_util.version_metadata())
    expected_response = make_response(digest(sent_version))
    self.assertEqual(expected_response, response)


class VersionMetadataTest(tb_test.TestCase):
    """Tests for `grpc_util.version_metadata` and `grpc_util.extract_version`."""

    def test_structure(self):
        """The metadata is a tuple of 2-tuples with string keys and values."""
        metadata = grpc_util.version_metadata()
        self.assertIsInstance(metadata, tuple)
        for entry in metadata:
            self.assertIsInstance(entry, tuple)
            self.assertLen(entry, 2)
            key, value = entry
            self.assertIsInstance(key, str)
            self.assertIsInstance(value, six.string_types)

    def test_roundtrip(self):
        """Extracting from freshly built metadata yields the current version."""
        extracted = grpc_util.extract_version(grpc_util.version_metadata())
        self.assertEqual(extracted, version.VERSION)


if __name__ == "__main__":
tb_test.main()