Skip to content
154 changes: 148 additions & 6 deletions tensorboard/util/grpc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


import enum
import functools
import random
import threading
import time

import grpc
Expand Down Expand Up @@ -51,6 +53,148 @@
_VERSION_METADATA_KEY = "tensorboard-version"


class AsyncCallFuture:
"""Encapsulates the future value of a retriable async gRPC request.

Abstracts over the set of futures returned by a set of gRPC calls
comprising a single logical gRPC request with retries. Communicates
to the caller the result or exception resulting from the request.

Args:
completion_event: The constructor should provide a `threding.Event` which
will be used to communicate when the set of gRPC requests is complete.
"""

def __init__(self, completion_event):
self._active_grpc_future = None
self._active_grpc_future_lock = threading.Lock()
self._completion_event = completion_event

def _set_active_future(self, grpc_future):
if grpc_future is None:
raise RuntimeError(
"_set_active_future invoked with grpc_future=None."
)
with self._active_grpc_future_lock:
self._active_grpc_future = grpc_future

def result(self, timeout):
"""Analogous to `grpc.Future.result`. Returns the value or exception.

This method will wait until the full set of gRPC requests is complete
and then act as `grpc.Future.result` for the single gRPC invocation
corresponding to the first successful call or final failure, as
appropriate.

Args:
timeout: How long to wait in seconds before giving up and raising.

Returns:
The result of the future corresponding to the single gRPC
corresponding to the successful call.

Raises:
* `grpc.FutureTimeoutError` if timeout seconds elapse before the gRPC
calls could complete, including waits and retries.
* The exception corresponding to the last non-retryable gRPC request
in the case that a successful gRPC request was not made.
"""
if not self._completion_event.wait(timeout):
raise grpc.FutureTimeoutError(
f"AsyncCallFuture timed out after {timeout} seconds"
)
with self._active_grpc_future_lock:
if self._active_grpc_future is None:
raise RuntimeError("AsyncFuture never had an active future set")
return self._active_grpc_future.result()


def async_call_with_retries(api_method, request, clock=None):
"""Initiate an asynchronous call to a gRPC stub, with retry logic.

This is similar to the `async_call` API, except that the call is handled
asynchronously, and the completion may be handled by another thread. The
caller must provide a `done_callback` argument which will handle the
result or exception rising from the gRPC completion.

Retries are handled with jittered exponential backoff to spread out failures
due to request spikes.

This only supports unary-unary RPCs: i.e., no streaming on either end.

Args:
api_method: Callable for the API method to invoke.
request: Request protocol buffer to pass to the API method.
clock: an interface object supporting `time()` and `sleep()` methods
like the standard `time` module; if not passed, uses the normal module.

Returns:
An `AsyncCallFuture` which will encapsulate the `grpc.Future`
corresponding to the gRPC call which either completes successfully or
represents the final try.
"""
if clock is None:
clock = time
logger.debug("Async RPC call %s with request: %r", api_method, request)

completion_event = threading.Event()
async_future = AsyncCallFuture(completion_event)

def async_call(handler):
"""Invokes the gRPC future and orchestrates it via the AsyncCallFuture."""
future = api_method.future(
request,
timeout=_GRPC_DEFAULT_TIMEOUT_SECS,
metadata=version_metadata(),
)
# Ensure we set the active future before invoking the done callback, to
# avoid the case where the done callback completes immediately and
# triggers completion event while async_future still holds the old
# future.
async_future._set_active_future(future)
future.add_done_callback(handler)

# retry_handler is the continuation of the `async_call`. It should:
# * If the grpc call succeeds: trigger the `completion_event`.
# * If there are no more retries: trigger the `completion_event`.
# * Otherwise, invoke a new async_call with the same
# retry_handler.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"...same retry_handler, but incrementing the number of attempts." ?

def retry_handler(future, num_attempts):
e = future.exception()
if e is None:
completion_event.set()
return
else:
logger.info("RPC call %s got error %s", api_method, e)
# If unable to retry, proceed to completion.
if e.code() not in _GRPC_RETRYABLE_STATUS_CODES:
completion_event.set()
return
if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS:
completion_event.set()
return
# If able to retry, wait then do so.
backoff_secs = _compute_backoff_seconds(num_attempts)
clock.sleep(backoff_secs)
async_call(
functools.partial(retry_handler, num_attempts=num_attempts + 1)
)

async_call(functools.partial(retry_handler, num_attempts=1))
return async_future


def _compute_backoff_seconds(num_attempts):
"""Compute appropriate wait time between RPC attempts."""
jitter_factor = random.uniform(
_GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX
)
backoff_secs = (
_GRPC_RETRY_EXPONENTIAL_BASE ** num_attempts
) * jitter_factor
return backoff_secs


def call_with_retries(api_method, request, clock=None):
"""Call a gRPC stub API method, with automatic retry logic.

Expand All @@ -59,6 +203,9 @@ def call_with_retries(api_method, request, clock=None):
because after a gRPC error one must retry the entire request; there is no
"retry-resume" functionality.

Retries are handled with jittered exponential backoff to spread out failures
due to request spikes.

Args:
api_method: Callable for the API method to invoke.
request: Request protocol buffer to pass to the API method.
Expand Down Expand Up @@ -93,12 +240,7 @@ def call_with_retries(api_method, request, clock=None):
raise
if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS:
raise
jitter_factor = random.uniform(
_GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX
)
backoff_secs = (
_GRPC_RETRY_EXPONENTIAL_BASE ** num_attempts
) * jitter_factor
backoff_secs = _compute_backoff_seconds(num_attempts)
logger.info(
"RPC call %s attempted %d times, retrying in %.1f seconds",
rpc_name,
Expand Down
115 changes: 115 additions & 0 deletions tensorboard/util/grpc_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import contextlib
import hashlib
import threading
import time

from concurrent import futures
import grpc
Expand Down Expand Up @@ -161,6 +162,120 @@ def handler(request, context):
self.assertEqual(make_response(expected_nonce), response)


class AsyncCallWithRetriesTest(tb_test.TestCase):
def test_aync_call_with_retries_succeeds(self):
# Setup: Basic server, echos input.
def handler(request, _):
return make_response(request.nonce)

server = TestGrpcServer(handler)
with server.run() as client:
# Execute `async_call_with_retries` with the callback.
future = grpc_util.async_call_with_retries(
client.TestRpc, make_request(42)
)
# Verify the correct value has been returned in the future.
self.assertEqual(make_response(42), future.result(2))

def test_aync_call_raises_at_timeout(self):
# Setup: Server waits 0.5 seconds before echoing.
def handler(request, _):
time.sleep(0.5)
return make_response(request.nonce)

server = TestGrpcServer(handler)
with server.run() as client:
# Execute `async_call_with_retries` with the callback.
future = grpc_util.async_call_with_retries(
client.TestRpc, make_request(42)
)
# Request the result in 0.01s, critically less time than the server
# will take to respond. Verify that request will cause an
# appropriate exception.
with self.assertRaisesRegex(grpc.FutureTimeoutError, "timed out"):
future.result(0.01)

def test_async_call_with_retries_fails_immediately_on_permanent_error(self):
# Setup: Server which fails with an ISE.
def handler(_, context):
context.abort(grpc.StatusCode.INTERNAL, "death_ray")

server = TestGrpcServer(handler)
with server.run() as client:
# Execute `async_call_with_retries`
future = grpc_util.async_call_with_retries(
client.TestRpc,
make_request(42),
)
# Expect that the future raises an Exception which is the
# right type and carries the right message.
with self.assertRaises(grpc.RpcError) as raised:
future.result(2)
self.assertEqual(grpc.StatusCode.INTERNAL, raised.exception.code())
self.assertEqual("death_ray", raised.exception.details())

def test_async_with_retries_fails_after_backoff_on_nonpermanent_error(self):
attempt_times = []
fake_time = test_util.FakeTime()

# Setup: Server which always fails with an UNAVAILABLE error.
def handler(_, context):
attempt_times.append(fake_time.time())
context.abort(
grpc.StatusCode.UNAVAILABLE, f"just a sec {len(attempt_times)}."
)

server = TestGrpcServer(handler)
with server.run() as client:
# Execute `async_call_with_retries` against the scripted server.
future = grpc_util.async_call_with_retries(
client.TestRpc,
make_request(42),
clock=fake_time,
)
# Expect that the future raises an Exception which is the right
# type and carries the right message.
with self.assertRaises(grpc.RpcError) as raised:
future.result(2)
self.assertEqual(
grpc.StatusCode.UNAVAILABLE, raised.exception.code()
)
self.assertEqual("just a sec 5.", raised.exception.details())
# Verify the number of attempts and delays between them.
self.assertLen(attempt_times, 5)
self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4)
self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8)
self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16)
self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32)

def test_async_with_retries_succeeds_after_backoff_on_transient_error(self):
attempt_times = []
fake_time = test_util.FakeTime()

# Setup: Server which passes on the third attempt.
def handler(request, context):
attempt_times.append(fake_time.time())
if len(attempt_times) < 3:
context.abort(grpc.StatusCode.UNAVAILABLE, "foo")
return make_response(request.nonce)

server = TestGrpcServer(handler)
with server.run() as client:
# Execute `async_call_with_retries` against the scripted server.
future = grpc_util.async_call_with_retries(
client.TestRpc,
make_request(42),
clock=fake_time,
)
# Expect:
# 1) The response contains the expected value.
# 2) The number of attempts and delays between them.
self.assertEqual(make_response(42), future.result(2))
self.assertLen(attempt_times, 3)
self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4)
self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8)


class VersionMetadataTest(tb_test.TestCase):
def test_structure(self):
result = grpc_util.version_metadata()
Expand Down