diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py
index 45a55d3ed9..7e5fce2179 100644
--- a/tensorboard/util/grpc_util.py
+++ b/tensorboard/util/grpc_util.py
@@ -16,7 +16,9 @@
 import enum
+import functools
 import random
+import threading
 import time
 
 import grpc
@@ -51,6 +53,148 @@
 _VERSION_METADATA_KEY = "tensorboard-version"
 
 
+class AsyncCallFuture:
+    """Encapsulates the future value of a retriable async gRPC request.
+
+    Abstracts over the set of futures returned by a set of gRPC calls
+    comprising a single logical gRPC request with retries. Communicates
+    to the caller the result or exception resulting from the request.
+
+    Args:
+      completion_event: A `threading.Event` used to signal when the set of
+        gRPC requests is complete.
+    """
+
+    def __init__(self, completion_event):
+        self._active_grpc_future = None
+        self._active_grpc_future_lock = threading.Lock()
+        self._completion_event = completion_event
+
+    def _set_active_future(self, grpc_future):
+        if grpc_future is None:
+            raise RuntimeError(
+                "_set_active_future invoked with grpc_future=None."
+            )
+        with self._active_grpc_future_lock:
+            self._active_grpc_future = grpc_future
+
+    def result(self, timeout):
+        """Analogous to `grpc.Future.result`. Returns the value or exception.
+
+        This method will wait until the full set of gRPC requests is complete
+        and then act as `grpc.Future.result` for the single gRPC invocation
+        corresponding to the first successful call or final failure, as
+        appropriate.
+
+        Args:
+          timeout: How long to wait in seconds before giving up and raising.
+
+        Returns:
+          The result of the future corresponding to the single gRPC
+          invocation that completed successfully.
+
+        Raises:
+          * `grpc.FutureTimeoutError` if `timeout` seconds elapse before the
+            gRPC calls complete, including waits and retries.
+          * The exception from the final gRPC attempt if no attempt
+            succeeded.
+        """
+        if not self._completion_event.wait(timeout):
+            raise grpc.FutureTimeoutError(
+                f"AsyncCallFuture timed out after {timeout} seconds"
+            )
+        with self._active_grpc_future_lock:
+            if self._active_grpc_future is None:
+                raise RuntimeError(
+                    "AsyncCallFuture never had an active future set"
+                )
+            return self._active_grpc_future.result()
+
+
+def async_call_with_retries(api_method, request, clock=None):
+    """Initiate an asynchronous call to a gRPC stub, with retry logic.
+
+    This is similar to `call_with_retries`, except that the call is made
+    asynchronously and completion may be handled by another thread. The
+    caller receives an `AsyncCallFuture` which provides the result or
+    exception arising from the request once it completes.
+
+    Retries are handled with jittered exponential backoff to spread out failures
+    due to request spikes.
+
+    This only supports unary-unary RPCs: i.e., no streaming on either end.
+
+    Args:
+      api_method: Callable for the API method to invoke.
+      request: Request protocol buffer to pass to the API method.
+      clock: an interface object supporting `time()` and `sleep()` methods
+        like the standard `time` module; if not passed, uses the normal module.
+
+    Returns:
+      An `AsyncCallFuture` which encapsulates the `grpc.Future` for the gRPC
+      call that either completes successfully or represents the final
+      attempt.
+ """ + if clock is None: + clock = time + logger.debug("Async RPC call %s with request: %r", api_method, request) + + completion_event = threading.Event() + async_future = AsyncCallFuture(completion_event) + + def async_call(handler): + """Invokes the gRPC future and orchestrates it via the AsyncCallFuture.""" + future = api_method.future( + request, + timeout=_GRPC_DEFAULT_TIMEOUT_SECS, + metadata=version_metadata(), + ) + # Ensure we set the active future before invoking the done callback, to + # avoid the case where the done callback completes immediately and + # triggers completion event while async_future still holds the old + # future. + async_future._set_active_future(future) + future.add_done_callback(handler) + + # retry_handler is the continuation of the `async_call`. It should: + # * If the grpc call succeeds: trigger the `completion_event`. + # * If there are no more retries: trigger the `completion_event`. + # * Otherwise, invoke a new async_call with the same + # retry_handler. + def retry_handler(future, num_attempts): + e = future.exception() + if e is None: + completion_event.set() + return + else: + logger.info("RPC call %s got error %s", api_method, e) + # If unable to retry, proceed to completion. + if e.code() not in _GRPC_RETRYABLE_STATUS_CODES: + completion_event.set() + return + if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS: + completion_event.set() + return + # If able to retry, wait then do so. + backoff_secs = _compute_backoff_seconds(num_attempts) + clock.sleep(backoff_secs) + async_call( + functools.partial(retry_handler, num_attempts=num_attempts + 1) + ) + + async_call(functools.partial(retry_handler, num_attempts=1)) + return async_future + + +def _compute_backoff_seconds(num_attempts): + """Compute appropriate wait time between RPC attempts.""" + jitter_factor = random.uniform( + _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX + ) + backoff_secs = ( + _GRPC_RETRY_EXPONENTIAL_BASE ** num_attempts + ) * jitter_factor + return backoff_secs + + def call_with_retries(api_method, request, clock=None): """Call a gRPC stub API method, with automatic retry logic. @@ -59,6 +203,9 @@ def call_with_retries(api_method, request, clock=None): because after a gRPC error one must retry the entire request; there is no "retry-resume" functionality. + Retries are handled with jittered exponential backoff to spread out failures + due to request spikes. + Args: api_method: Callable for the API method to invoke. request: Request protocol buffer to pass to the API method. 
@@ -93,12 +240,7 @@ def call_with_retries(api_method, request, clock=None):
                 raise
             if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS:
                 raise
-            jitter_factor = random.uniform(
-                _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX
-            )
-            backoff_secs = (
-                _GRPC_RETRY_EXPONENTIAL_BASE ** num_attempts
-            ) * jitter_factor
+            backoff_secs = _compute_backoff_seconds(num_attempts)
             logger.info(
                 "RPC call %s attempted %d times, retrying in %.1f seconds",
                 rpc_name,
diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py
index efec8a41a0..dc566401cf 100644
--- a/tensorboard/util/grpc_util_test.py
+++ b/tensorboard/util/grpc_util_test.py
@@ -18,6 +18,7 @@
 import contextlib
 import hashlib
 import threading
+import time
 from concurrent import futures
 
 import grpc
@@ -161,6 +162,120 @@ def handler(request, context):
         self.assertEqual(make_response(expected_nonce), response)
 
 
+class AsyncCallWithRetriesTest(tb_test.TestCase):
+    def test_async_call_with_retries_succeeds(self):
+        # Setup: Basic server that echoes its input.
+        def handler(request, _):
+            return make_response(request.nonce)
+
+        server = TestGrpcServer(handler)
+        with server.run() as client:
+            # Execute `async_call_with_retries` against the test server.
+            future = grpc_util.async_call_with_retries(
+                client.TestRpc, make_request(42)
+            )
+            # Verify the correct value has been returned in the future.
+            self.assertEqual(make_response(42), future.result(2))
+
+    def test_async_call_raises_at_timeout(self):
+        # Setup: Server waits 0.5 seconds before echoing.
+        def handler(request, _):
+            time.sleep(0.5)
+            return make_response(request.nonce)
+
+        server = TestGrpcServer(handler)
+        with server.run() as client:
+            # Execute `async_call_with_retries` against the test server.
+            future = grpc_util.async_call_with_retries(
+                client.TestRpc, make_request(42)
+            )
+            # Request the result within 0.01s, much less time than the
+            # server will take to respond, and verify that the request
+            # raises an appropriate exception.
+            with self.assertRaisesRegex(grpc.FutureTimeoutError, "timed out"):
+                future.result(0.01)
+
+    def test_async_call_with_retries_fails_immediately_on_permanent_error(self):
+        # Setup: Server which fails with an INTERNAL error.
+        def handler(_, context):
+            context.abort(grpc.StatusCode.INTERNAL, "death_ray")
+
+        server = TestGrpcServer(handler)
+        with server.run() as client:
+            # Execute `async_call_with_retries`.
+            future = grpc_util.async_call_with_retries(
+                client.TestRpc,
+                make_request(42),
+            )
+            # Expect that the future raises an exception of the right type
+            # carrying the right message.
+            with self.assertRaises(grpc.RpcError) as raised:
+                future.result(2)
+            self.assertEqual(grpc.StatusCode.INTERNAL, raised.exception.code())
+            self.assertEqual("death_ray", raised.exception.details())
+
+    def test_async_with_retries_fails_after_backoff_on_nonpermanent_error(self):
+        attempt_times = []
+        fake_time = test_util.FakeTime()
+
+        # Setup: Server which always fails with an UNAVAILABLE error.
+        def handler(_, context):
+            attempt_times.append(fake_time.time())
+            context.abort(
+                grpc.StatusCode.UNAVAILABLE, f"just a sec {len(attempt_times)}."
+            )
+
+        server = TestGrpcServer(handler)
+        with server.run() as client:
+            # Execute `async_call_with_retries` against the scripted server.
+            future = grpc_util.async_call_with_retries(
+                client.TestRpc,
+                make_request(42),
+                clock=fake_time,
+            )
+            # Expect that the future raises an exception of the right type
+            # carrying the right message.
+            with self.assertRaises(grpc.RpcError) as raised:
+                future.result(2)
+            self.assertEqual(
+                grpc.StatusCode.UNAVAILABLE, raised.exception.code()
+            )
+            self.assertEqual("just a sec 5.", raised.exception.details())
+            # Verify the number of attempts and delays between them.
+            self.assertLen(attempt_times, 5)
+            self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4)
+            self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8)
+            self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16)
+            self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32)
+
+    def test_async_with_retries_succeeds_after_backoff_on_transient_error(self):
+        attempt_times = []
+        fake_time = test_util.FakeTime()
+
+        # Setup: Server which succeeds on the third attempt.
+        def handler(request, context):
+            attempt_times.append(fake_time.time())
+            if len(attempt_times) < 3:
+                context.abort(grpc.StatusCode.UNAVAILABLE, "foo")
+            return make_response(request.nonce)
+
+        server = TestGrpcServer(handler)
+        with server.run() as client:
+            # Execute `async_call_with_retries` against the scripted server.
+            future = grpc_util.async_call_with_retries(
+                client.TestRpc,
+                make_request(42),
+                clock=fake_time,
+            )
+            # Expect:
+            # 1) The response contains the expected value.
+            # 2) The expected number of attempts, with the expected delays
+            #    between them.
+            self.assertEqual(make_response(42), future.result(2))
+            self.assertLen(attempt_times, 3)
+            self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4)
+            self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8)
+
+
 class VersionMetadataTest(tb_test.TestCase):
     def test_structure(self):
         result = grpc_util.version_metadata()
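
For reference (not part of the patch): a minimal usage sketch of the new asynchronous API. The service stub `SomeServiceStub`, its unary-unary method `SomeRpc`, the request proto `SomeRequest`, and the channel address are hypothetical placeholders introduced here for illustration; the tests above exercise the same flow against `TestGrpcServer`.

import grpc

from tensorboard.util import grpc_util

# Hypothetical stub and request; substitute any unary-unary RPC and its
# matching request proto.
channel = grpc.insecure_channel("localhost:8080")
stub = SomeServiceStub(channel)
request = SomeRequest(nonce=42)

# Returns immediately: the attempts, including jittered exponential backoff
# between retries, run on gRPC's callback threads rather than this thread.
future = grpc_util.async_call_with_retries(stub.SomeRpc, request)

# Block for up to 30 seconds for the terminal outcome. Raises
# grpc.FutureTimeoutError if the attempts (including backoff waits) have not
# finished in time, and re-raises the final attempt's RpcError on failure.
response = future.result(30)

The caller-facing contract mirrors `grpc.Future.result`: a single `result()` call surfaces either the first successful response or the exception from the final attempt.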