diff --git a/tensorboard/BUILD b/tensorboard/BUILD index df1c16a0f2..16dc32db50 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -302,6 +302,14 @@ py_library( visibility = ["//visibility:public"], ) +py_library( + name = "expect_grpc_installed", + # This is a dummy rule used as a grpc dependency in open-source. + # We expect grpc to already be installed on the system, e.g. via + # `pip install grpcio` + visibility = ["//visibility:public"], +) + py_library( name = "expect_sqlite3_installed", # This is a dummy rule used as a sqlite3 dependency in open-source. diff --git a/tensorboard/util/BUILD b/tensorboard/util/BUILD index c30c548ecb..32882c37b8 100644 --- a/tensorboard/util/BUILD +++ b/tensorboard/util/BUILD @@ -1,5 +1,7 @@ package(default_visibility = ["//tensorboard:internal"]) +load("//tensorboard/defs:protos.bzl", "tb_proto_library") + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) # Needed for internal repo. @@ -45,6 +47,40 @@ py_test( ], ) +py_library( + name = "grpc_util", + srcs = ["grpc_util.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorboard:expect_grpc_installed", + "//tensorboard/util:tb_logging", + ], +) + +py_test( + name = "grpc_util_test", + size = "small", + srcs = ["grpc_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":grpc_util", + ":grpc_util_test_proto_py_pb2", + ":grpc_util_test_proto_py_pb2_grpc", + ":test_util", + "//tensorboard:expect_futures_installed", + "//tensorboard:expect_grpc_installed", + "//tensorboard:test", + "@org_pythonhosted_mock", + ], +) + +tb_proto_library( + name = "grpc_util_test_proto", + has_services = True, + srcs = ["grpc_util_test.proto"], + testonly = True, +) + py_library( name = "op_evaluator", srcs = ["op_evaluator.py"], diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py new file mode 100644 index 0000000000..caa66a933d --- /dev/null +++ b/tensorboard/util/grpc_util.py @@ -0,0 +1,94 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for working with python gRPC stubs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import time + +import grpc + +from tensorboard.util import tb_logging + +logger = tb_logging.get_logger() + +# Default RPC timeout. +_GRPC_DEFAULT_TIMEOUT_SECS = 30 + +# Max number of times to attempt an RPC, retrying on transient failures. +_GRPC_RETRY_MAX_ATTEMPTS = 5 + +# Parameters to control the exponential backoff behavior. +_GRPC_RETRY_EXPONENTIAL_BASE = 2 +_GRPC_RETRY_JITTER_FACTOR_MIN = 1.1 +_GRPC_RETRY_JITTER_FACTOR_MAX = 1.5 + +# Status codes from gRPC for which it's reasonable to retry the RPC. +_GRPC_RETRYABLE_STATUS_CODES = frozenset([ + grpc.StatusCode.ABORTED, + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.UNAVAILABLE, +]) + + +def call_with_retries(api_method, request, clock=None): + """Call a gRPC stub API method, with automatic retry logic. + + This only supports unary-unary RPCs: i.e., no streaming on either end. + Streamed RPCs will generally need application-level pagination support, + because after a gRPC error one must retry the entire request; there is no + "retry-resume" functionality. + + Args: + api_method: Callable for the API method to invoke. + request: Request protocol buffer to pass to the API method. + clock: an interface object supporting `time()` and `sleep()` methods + like the standard `time` module; if not passed, uses the normal module. + + Returns: + Response protocol buffer returned by the API method. + + Raises: + grpc.RpcError: if a non-retryable error is returned, or if all retry + attempts have been exhausted. + """ + if clock is None: + clock = time + # We can't actually use api_method.__name__ because it's not a real method, + # it's a special gRPC callable instance that doesn't expose the method name. + rpc_name = request.__class__.__name__.replace("Request", "") + logger.debug("RPC call %s with request: %r", rpc_name, request) + num_attempts = 0 + while True: + num_attempts += 1 + try: + return api_method(request, timeout=_GRPC_DEFAULT_TIMEOUT_SECS) + except grpc.RpcError as e: + logger.info("RPC call %s got error %s", rpc_name, e) + if e.code() not in _GRPC_RETRYABLE_STATUS_CODES: + raise + if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS: + raise + jitter_factor = random.uniform( + _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX) + backoff_secs = (_GRPC_RETRY_EXPONENTIAL_BASE**num_attempts) * jitter_factor + logger.info( + "RPC call %s attempted %d times, retrying in %.1f seconds", + rpc_name, num_attempts, backoff_secs) + clock.sleep(backoff_secs) diff --git a/tensorboard/util/grpc_util_test.proto b/tensorboard/util/grpc_util_test.proto new file mode 100644 index 0000000000..3f4bb5c976 --- /dev/null +++ b/tensorboard/util/grpc_util_test.proto @@ -0,0 +1,18 @@ +// Minimal example RPC service definition. See grpc_util_test.py for usage. +syntax = "proto3"; + +package tensorboard.util; + +// Test service for grpc_util_test.py. +service TestService { + // Test RPC. + rpc TestRpc(TestRpcRequest) returns (TestRpcResponse); +} + +message TestRpcRequest { + int32 nonce = 1; +} + +message TestRpcResponse { + int32 nonce = 1; +} diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py new file mode 100644 index 0000000000..df0f8e4d03 --- /dev/null +++ b/tensorboard/util/grpc_util_test.py @@ -0,0 +1,129 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `tensorboard.util.grpc_util`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import threading + +from concurrent import futures +import grpc + +from tensorboard.util import grpc_util +from tensorboard.util import grpc_util_test_pb2 +from tensorboard.util import grpc_util_test_pb2_grpc +from tensorboard.util import test_util +from tensorboard import test as tb_test + + +def make_request(nonce): + return grpc_util_test_pb2.TestRpcRequest(nonce=nonce) + + +def make_response(nonce): + return grpc_util_test_pb2.TestRpcResponse(nonce=nonce) + + +class TestGrpcServer(grpc_util_test_pb2_grpc.TestServiceServicer): + """Helper for testing gRPC client logic with a dummy gRPC server.""" + + def __init__(self, handler): + super(TestGrpcServer, self).__init__() + self._handler = handler + + def TestRpc(self, request, context): + return self._handler(request, context) + + @contextlib.contextmanager + def run(self): + """Context manager to run the gRPC server and yield a client for it.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) + grpc_util_test_pb2_grpc.add_TestServiceServicer_to_server(self, server) + port = server.add_secure_port( + "localhost:0", grpc.local_server_credentials()) + def launch_server(): + server.start() + server.wait_for_termination() + thread = threading.Thread(target=launch_server, name="TestGrpcServer") + thread.daemon = True + thread.start() + with grpc.secure_channel( + "localhost:%d" % port, grpc.local_channel_credentials()) as channel: + yield grpc_util_test_pb2_grpc.TestServiceStub(channel) + server.stop(grace=None) + thread.join() + + +class GrpcUtilTest(tb_test.TestCase): + + def test_call_with_retries_succeeds(self): + def handler(request, _): + return make_response(request.nonce) + server = TestGrpcServer(handler) + with server.run() as client: + response = grpc_util.call_with_retries(client.TestRpc, make_request(42)) + self.assertEqual(make_response(42), response) + + def test_call_with_retries_fails_immediately_on_permanent_error(self): + def handler(_, context): + context.abort(grpc.StatusCode.INTERNAL, "foo") + server = TestGrpcServer(handler) + with server.run() as client: + with self.assertRaises(grpc.RpcError) as raised: + grpc_util.call_with_retries(client.TestRpc, make_request(42)) + self.assertEqual(grpc.StatusCode.INTERNAL, raised.exception.code()) + self.assertEqual("foo", raised.exception.details()) + + def test_call_with_retries_fails_after_backoff_on_nonpermanent_error(self): + attempt_times = [] + fake_time = test_util.FakeTime() + def handler(_, context): + attempt_times.append(fake_time.time()) + context.abort(grpc.StatusCode.UNAVAILABLE, "foo") + server = TestGrpcServer(handler) + with server.run() as client: + with self.assertRaises(grpc.RpcError) as raised: + grpc_util.call_with_retries(client.TestRpc, make_request(42), fake_time) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.exception.code()) + self.assertEqual("foo", raised.exception.details()) + self.assertLen(attempt_times, 5) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) + self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16) + self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32) + + def test_call_with_retries_succeeds_after_backoff_on_transient_error(self): + attempt_times = [] + fake_time = test_util.FakeTime() + def handler(request, context): + attempt_times.append(fake_time.time()) + if len(attempt_times) < 3: + context.abort(grpc.StatusCode.UNAVAILABLE, "foo") + return make_response(request.nonce) + server = TestGrpcServer(handler) + with server.run() as client: + response = grpc_util.call_with_retries( + client.TestRpc, make_request(42), fake_time) + self.assertEqual(make_response(42), response) + self.assertLen(attempt_times, 3) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/util/test_util.py b/tensorboard/util/test_util.py index c970dcf472..44d56fa164 100644 --- a/tensorboard/util/test_util.py +++ b/tensorboard/util/test_util.py @@ -128,6 +128,22 @@ def get(logdir): return FileWriterCache._cache[logdir] +class FakeTime(object): + """Thread-safe fake replacement for the `time` module.""" + + def __init__(self, current=0.0): + self._time = float(current) + self._lock = threading.Lock() + + def time(self): + with self._lock: + return self._time + + def sleep(self, secs): + with self._lock: + self._time += secs + + def ensure_tb_summary_proto(summary): """Ensures summary is TensorBoard Summary proto.