From 857dddc86c0fad378d800154ee903e69d7f96fbb Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Tue, 30 Mar 2021 16:08:35 -0400 Subject: [PATCH 01/11] fix timing with custom num retries --- tensorboard/util/grpc_util.py | 105 +++++++++++++++++-- tensorboard/util/grpc_util_test.py | 156 +++++++++++++++++++++++++++++ 2 files changed, 255 insertions(+), 6 deletions(-) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index 45a55d3ed9..9925c9f421 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -51,6 +51,104 @@ _VERSION_METADATA_KEY = "tensorboard-version" +def async_call( + api_method, + request, + completion_handler, + ): + """Call a gRPC stub API method. + + This only supports unary-unary RPCs: i.e., no streaming on either end. + Streamed RPCs will generally need application-level pagination support, + because after a gRPC error one must retry the entire request; there is no + "retry-resume" functionality. + + Args: + api_method: Callable for the API method to invoke. + request: Request protocol buffer to pass to the API method. + completion_handler: A callback which takes the resolved future as an + argument and completes the computation. + + Returns: + None. All computation relying on the return value of the gRPC should + be done in the completion_handler. + """ + # We can't actually use api_method.__name__ because it's not a real method, + # it's a special gRPC callable instance that doesn't expose the method name. + rpc_name = request.__class__.__name__.replace("Request", "") + logger.debug("Async RPC call %s with request: %r", rpc_name, request) + future = api_method.future( + request, + timeout=_GRPC_DEFAULT_TIMEOUT_SECS, + metadata=version_metadata(), + ) + future.add_done_callback(completion_handler) + +def async_call_with_retries( + api_method, + request, + completion_handler, + num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, + clock=None + ): + """ TO DO DO NOT SUBMIT... + """ + print("calling async_call_with_retries") + if num_remaining_tries < 0: + # This should not happen in the course of normal operations and + # indicates a bug in the implementation. + raise ValueError( + "num_remaining_tries=%d. expected >= 0." % num_remaining_tries) + # We can't actually use api_method.__name__ because it's not a real method, + # it's a special gRPC callable instance that doesn't expose the method name. + rpc_name = request.__class__.__name__.replace("Request", "") + logger.debug("Async RPC call %s with request: %r", rpc_name, request) + future = api_method.future( + request, + timeout=_GRPC_DEFAULT_TIMEOUT_SECS, + metadata=version_metadata(), + ) + # The continuation should wrap the completion_handler such that: + # * If the grpc call succeeds, we should invoke the completion_handler. + # * If there are no more retries, we should invoke the completion_handler. + # Otherwise, we should invoke async_call_with_retries with one less + # retry. + # + def retry_handler(future): + e = future.exception() + if e is None: + completion_handler(future) + return + else: + logger.info("RPC call %s got error %s", rpc_name, e) + # If unable to retry, proceed to completion_handler. + if e.code() not in _GRPC_RETRYABLE_STATUS_CODES: + completion_handler(future) + return + if num_remaining_tries <= 0: + completion_handler(future) + return + # If able to retry, wait then do so. + num_attempts = _GRPC_RETRY_MAX_ATTEMPTS - num_remaining_tries + backoff_secs = _compute_backoff_seconds(num_attempts) + clock.sleep(backoff_secs) + async_call_with_retries( + api_method, request, completion_handler, num_remaining_tries - 1, clock) + + future.add_done_callback(retry_handler) + + + +def _compute_backoff_seconds(num_attempts): + """Compute wait time between attempts.""" + jitter_factor = random.uniform( + _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX + ) + backoff_secs = ( + _GRPC_RETRY_EXPONENTIAL_BASE ** num_attempts + ) * jitter_factor + return backoff_secs + def call_with_retries(api_method, request, clock=None): """Call a gRPC stub API method, with automatic retry logic. @@ -93,12 +191,7 @@ def call_with_retries(api_method, request, clock=None): raise if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS: raise - jitter_factor = random.uniform( - _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX - ) - backoff_secs = ( - _GRPC_RETRY_EXPONENTIAL_BASE ** num_attempts - ) * jitter_factor + backoff_secs = _compute_backoff_seconds(num_attempts) logger.info( "RPC call %s attempted %d times, retrying in %.1f seconds", rpc_name, diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py index efec8a41a0..3aed577b39 100644 --- a/tensorboard/util/grpc_util_test.py +++ b/tensorboard/util/grpc_util_test.py @@ -18,6 +18,7 @@ import contextlib import hashlib import threading +from unittest import mock from concurrent import futures import grpc @@ -161,6 +162,161 @@ def handler(request, context): self.assertEqual(make_response(expected_nonce), response) +class AsyncCallWithRetriesTest(tb_test.TestCase): + + def test_aync_call_with_retries_invokes_callback(self): + # Setup: Basic server, echos input. + def handler(request, _): + return make_response(request.nonce) + # Set up a callback which: + # 1) Records that it has been executed (mock_callback) + # 2) Triggers the keep_alive_event, notifying when it is ok + # to kill the server. + keep_alive_event = threading.Event() + mock_callback = mock.Mock() + def wrapped_mock_callback(future): + mock_callback(future) + keep_alive_event.set() + + server = TestGrpcServer(handler) + with server.run() as client: + # Execute `async_call_with_retries` with the callback. + grpc_util.async_call_with_retries( + client.TestRpc, + make_request(42), + completion_handler=wrapped_mock_callback + ) + # Keep the test alive until the event is triggered. + keep_alive_event.wait() + + # Expect: callback is invoked. + mock_callback.assert_called_once() + + def test_aync_call_with_retries_succeeds(self): + # Setup: Basic server, echos input. + def handler(request, _): + return make_response(request.nonce) + # Set up a callback which: + # 1) Verifies the correct value has been returned in the future. + # 2) Triggers the keep_alive_event, notifying when it is ok + # to kill the server. + keep_alive_event = threading.Event() + def check_value(future): + self.assertEqual(make_response(42), future.result()) + keep_alive_event.set() + + server = TestGrpcServer(handler) + with server.run() as client: + # Execute `async_call_with_retries` with the callback. + grpc_util.async_call_with_retries( + client.TestRpc, + make_request(42), + completion_handler=check_value + ) + # Keep the test alive until the event is triggered. + keep_alive_event.wait() + + def test_async_call_with_retries_fails_immediately_on_permanent_error(self): + # Setup: Server which fails with an ISE. + def handler(_, context): + context.abort(grpc.StatusCode.INTERNAL, "foo") + + # Set up a callback which: + # 1) Verifies the Exception against expectations. + # 2) Triggers the keep_alive_event, notifying when it is ok + # to kill the server. + keep_alive_event = threading.Event() + def check_exception(future): + raised = future.exception() + self.assertIsInstance(raised,grpc.RpcError) + self.assertEqual(grpc.StatusCode.INTERNAL, raised.code()) + self.assertEqual("foo", raised.details()) + keep_alive_event.set() + + server = TestGrpcServer(handler) + with server.run() as client: + # Execute `async_call_with_retries` with the callback. + grpc_util.async_call_with_retries( + client.TestRpc, + make_request(42), + completion_handler=check_exception + ) + # Keep the test alive until the event is triggered. + keep_alive_event.wait() + + + def test_async_with_retries_fails_after_backoff_on_nonpermanent_error(self): + attempt_times = [] + fake_time = test_util.FakeTime() + + # Setup: Server which always fails with an UNAVAILABLE error. + def handler(_, context): + attempt_times.append(fake_time.time()) + context.abort(grpc.StatusCode.UNAVAILABLE, "foo") + + # Set up a callback which: + # 1) Verifies the Exception against expectations. + # 2) Verifies the number of attempts and delays between them + # 3) Triggers the keep_alive_event, notifying when it is ok + # to kill the server. + keep_alive_event = threading.Event() + def check_exception(future): + raised = future.exception() + self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.code()) + self.assertEqual("foo", raised.details()) + self.assertLen(attempt_times, 5) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) + self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16) + self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32) + keep_alive_event.set() + + + server = TestGrpcServer(handler) + with server.run() as client: + # Execute `async_call_with_retries` with the callback. + grpc_util.async_call_with_retries( + client.TestRpc, make_request(42), clock=fake_time, completion_handler=check_exception + ) + # Keep the test alive until the event is triggered. + keep_alive_event.wait() + + def test_async_with_retries_succeeds_after_backoff_on_transient_error(self): + attempt_times = [] + fake_time = test_util.FakeTime() + + # Setup: Server which passes on the third attempt. + def handler(request, context): + attempt_times.append(fake_time.time()) + if len(attempt_times) < 3: + context.abort(grpc.StatusCode.UNAVAILABLE, "foo") + return make_response(request.nonce) + + # Set up a callback which: + # 1) Verifies the response contains the expected value + # 2) Verifies the number of attempts and delays between them + # 3) Triggers the keep_alive_event, notifying when it is ok + # to kill the server. + keep_alive_event = threading.Event() + def check_value(future): + self.assertEqual(make_response(42), future.result()) + self.assertLen(attempt_times, 3) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) + keep_alive_event.set() + + server = TestGrpcServer(handler) + with server.run() as client: + grpc_util.async_call_with_retries( + client.TestRpc, + make_request(42), + clock=fake_time, + completion_handler=check_value + ) + # Keep the test alive until the event is triggered. + keep_alive_event.wait() + + class VersionMetadataTest(tb_test.TestCase): def test_structure(self): result = grpc_util.version_metadata() From 9ccc07572ebd6674404a3e8fbcedaf8d2abb0977 Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Tue, 30 Mar 2021 16:12:14 -0400 Subject: [PATCH 02/11] fix timing with custom num retries. Add test --- tensorboard/util/grpc_util.py | 80 ++++++++++++++++-------------- tensorboard/util/grpc_util_test.py | 35 ++++++++----- 2 files changed, 64 insertions(+), 51 deletions(-) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index 9925c9f421..1726254e70 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -50,50 +50,47 @@ # gRPC metadata key whose value contains the client version. _VERSION_METADATA_KEY = "tensorboard-version" - -def async_call( +def async_call_with_retries( api_method, request, completion_handler, + num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, + num_tries_so_far=0, + clock=None ): - """Call a gRPC stub API method. + """Initiate an asynchronous call to a gRPC stub, with retry logic. + + This is similar to the `async_call` API, except that the call is handled + asynchronously, and the completion may be handled by another thread. The + caller must provide a `completion_handler` argument which will handle the + result or exception rising from the gRPC completion. + + Retries are handled by recursively calling into this API with fewer + remaining retries, as controlled through the `num_remaining_retries` + argument. Setting `num_remaining_retries` to zero will make just + one attempt at the gRPC call. + + Retries are handled with jittered exponential backoff to spread out failures + due to request spikes. This only supports unary-unary RPCs: i.e., no streaming on either end. - Streamed RPCs will generally need application-level pagination support, - because after a gRPC error one must retry the entire request; there is no - "retry-resume" functionality. Args: api_method: Callable for the API method to invoke. request: Request protocol buffer to pass to the API method. - completion_handler: A callback which takes the resolved future as an - argument and completes the computation. - - Returns: - None. All computation relying on the return value of the gRPC should - be done in the completion_handler. - """ - # We can't actually use api_method.__name__ because it's not a real method, - # it's a special gRPC callable instance that doesn't expose the method name. - rpc_name = request.__class__.__name__.replace("Request", "") - logger.debug("Async RPC call %s with request: %r", rpc_name, request) - future = api_method.future( - request, - timeout=_GRPC_DEFAULT_TIMEOUT_SECS, - metadata=version_metadata(), - ) - future.add_done_callback(completion_handler) + completion_handler: A function which takes a `grpc.Future` object as an + argument and performs the necessary operations on the gRPC response + or error, as required. + num_remaining_retries: A non-negative integer which indicates how many + more attempts should be made to the gRPC endpoint if this try fails + within an error code which could be recovered from. Set to zero + to call with no retries. + num_tries_so_far: A non-negative integer indicating how many attempts + have been made so far for this gRPC. Used to compute backoff time. + clock: an interface object supporting `time()` and `sleep()` methods + like the standard `time` module; if not passed, uses the normal module. -def async_call_with_retries( - api_method, - request, - completion_handler, - num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, - clock=None - ): - """ TO DO DO NOT SUBMIT... """ - print("calling async_call_with_retries") if num_remaining_tries < 0: # This should not happen in the course of normal operations and # indicates a bug in the implementation. @@ -111,8 +108,8 @@ def async_call_with_retries( # The continuation should wrap the completion_handler such that: # * If the grpc call succeeds, we should invoke the completion_handler. # * If there are no more retries, we should invoke the completion_handler. - # Otherwise, we should invoke async_call_with_retries with one less - # retry. + # * Otherwise, we should invoke async_call_with_retries with one less + # retry. # def retry_handler(future): e = future.exception() @@ -129,18 +126,22 @@ def retry_handler(future): completion_handler(future) return # If able to retry, wait then do so. - num_attempts = _GRPC_RETRY_MAX_ATTEMPTS - num_remaining_tries - backoff_secs = _compute_backoff_seconds(num_attempts) + backoff_secs = _compute_backoff_seconds(num_tries_so_far + 1) clock.sleep(backoff_secs) async_call_with_retries( - api_method, request, completion_handler, num_remaining_tries - 1, clock) + api_method=api_method, + request=request, + completion_handler=completion_handler, + num_remaining_tries=num_remaining_tries - 1, + num_tries_so_far=num_tries_so_far + 1, + clock=clock) future.add_done_callback(retry_handler) def _compute_backoff_seconds(num_attempts): - """Compute wait time between attempts.""" + """Compute appropriate wait time between RPC attempts.""" jitter_factor = random.uniform( _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX ) @@ -157,6 +158,9 @@ def call_with_retries(api_method, request, clock=None): because after a gRPC error one must retry the entire request; there is no "retry-resume" functionality. + Retries are handled with jittered exponential backoff to spread out failures + due to request spikes. + Args: api_method: Callable for the API method to invoke. request: Request protocol buffer to pass to the API method. diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py index 3aed577b39..f2fbfaa711 100644 --- a/tensorboard/util/grpc_util_test.py +++ b/tensorboard/util/grpc_util_test.py @@ -169,7 +169,7 @@ def test_aync_call_with_retries_invokes_callback(self): def handler(request, _): return make_response(request.nonce) # Set up a callback which: - # 1) Records that it has been executed (mock_callback) + # 1) Records that it has been executed (mock_callback). # 2) Triggers the keep_alive_event, notifying when it is ok # to kill the server. keep_alive_event = threading.Event() @@ -219,10 +219,11 @@ def check_value(future): def test_async_call_with_retries_fails_immediately_on_permanent_error(self): # Setup: Server which fails with an ISE. def handler(_, context): - context.abort(grpc.StatusCode.INTERNAL, "foo") + context.abort(grpc.StatusCode.INTERNAL, "death_ray") # Set up a callback which: - # 1) Verifies the Exception against expectations. + # 1) Verifies the future raises an Exception which is the right type and + # carries the right message. # 2) Triggers the keep_alive_event, notifying when it is ok # to kill the server. keep_alive_event = threading.Event() @@ -230,7 +231,7 @@ def check_exception(future): raised = future.exception() self.assertIsInstance(raised,grpc.RpcError) self.assertEqual(grpc.StatusCode.INTERNAL, raised.code()) - self.assertEqual("foo", raised.details()) + self.assertEqual("death_ray", raised.details()) keep_alive_event.set() server = TestGrpcServer(handler) @@ -252,31 +253,39 @@ def test_async_with_retries_fails_after_backoff_on_nonpermanent_error(self): # Setup: Server which always fails with an UNAVAILABLE error. def handler(_, context): attempt_times.append(fake_time.time()) - context.abort(grpc.StatusCode.UNAVAILABLE, "foo") + context.abort( + grpc.StatusCode.UNAVAILABLE, + f"just a sec {len(attempt_times)}.") # Set up a callback which: - # 1) Verifies the Exception against expectations. - # 2) Verifies the number of attempts and delays between them + # 1) Verifies the final raised Exception against expectations. + # 2) Verifies the number of attempts and delays between them. # 3) Triggers the keep_alive_event, notifying when it is ok # to kill the server. keep_alive_event = threading.Event() def check_exception(future): raised = future.exception() self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.code()) - self.assertEqual("foo", raised.details()) - self.assertLen(attempt_times, 5) + self.assertEqual("just a sec 6.", raised.details()) + self.assertLen(attempt_times, 6) self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16) self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32) + self.assertBetween(attempt_times[5] - attempt_times[4], 32, 64) keep_alive_event.set() server = TestGrpcServer(handler) with server.run() as client: - # Execute `async_call_with_retries` with the callback. + # Execute `async_call_with_retries` with the callback. Call + # with explicit num retries (5) instead of default. grpc_util.async_call_with_retries( - client.TestRpc, make_request(42), clock=fake_time, completion_handler=check_exception + client.TestRpc, + make_request(42), + num_remaining_tries=5, + clock=fake_time, + completion_handler=check_exception ) # Keep the test alive until the event is triggered. keep_alive_event.wait() @@ -293,8 +302,8 @@ def handler(request, context): return make_response(request.nonce) # Set up a callback which: - # 1) Verifies the response contains the expected value - # 2) Verifies the number of attempts and delays between them + # 1) Verifies the response contains the expected value. + # 2) Verifies the number of attempts and delays between them. # 3) Triggers the keep_alive_event, notifying when it is ok # to kill the server. keep_alive_event = threading.Event() From 8e11592cb9cf46bc8873dae0e7e701689440d4c8 Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Tue, 30 Mar 2021 16:52:45 -0400 Subject: [PATCH 03/11] black --- tensorboard/util/grpc_util.py | 13 ++++++++----- tensorboard/util/grpc_util_test.py | 28 +++++++++++++++------------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index 1726254e70..49844b64a5 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -50,14 +50,15 @@ # gRPC metadata key whose value contains the client version. _VERSION_METADATA_KEY = "tensorboard-version" + def async_call_with_retries( api_method, request, completion_handler, num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, num_tries_so_far=0, - clock=None - ): + clock=None, +): """Initiate an asynchronous call to a gRPC stub, with retry logic. This is similar to the `async_call` API, except that the call is handled @@ -95,7 +96,8 @@ def async_call_with_retries( # This should not happen in the course of normal operations and # indicates a bug in the implementation. raise ValueError( - "num_remaining_tries=%d. expected >= 0." % num_remaining_tries) + "num_remaining_tries=%d. expected >= 0." % num_remaining_tries + ) # We can't actually use api_method.__name__ because it's not a real method, # it's a special gRPC callable instance that doesn't expose the method name. rpc_name = request.__class__.__name__.replace("Request", "") @@ -134,12 +136,12 @@ def retry_handler(future): completion_handler=completion_handler, num_remaining_tries=num_remaining_tries - 1, num_tries_so_far=num_tries_so_far + 1, - clock=clock) + clock=clock, + ) future.add_done_callback(retry_handler) - def _compute_backoff_seconds(num_attempts): """Compute appropriate wait time between RPC attempts.""" jitter_factor = random.uniform( @@ -150,6 +152,7 @@ def _compute_backoff_seconds(num_attempts): ) * jitter_factor return backoff_secs + def call_with_retries(api_method, request, clock=None): """Call a gRPC stub API method, with automatic retry logic. diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py index f2fbfaa711..0d8f6689c3 100644 --- a/tensorboard/util/grpc_util_test.py +++ b/tensorboard/util/grpc_util_test.py @@ -163,17 +163,18 @@ def handler(request, context): class AsyncCallWithRetriesTest(tb_test.TestCase): - def test_aync_call_with_retries_invokes_callback(self): # Setup: Basic server, echos input. def handler(request, _): return make_response(request.nonce) + # Set up a callback which: # 1) Records that it has been executed (mock_callback). # 2) Triggers the keep_alive_event, notifying when it is ok # to kill the server. keep_alive_event = threading.Event() mock_callback = mock.Mock() + def wrapped_mock_callback(future): mock_callback(future) keep_alive_event.set() @@ -184,7 +185,7 @@ def wrapped_mock_callback(future): grpc_util.async_call_with_retries( client.TestRpc, make_request(42), - completion_handler=wrapped_mock_callback + completion_handler=wrapped_mock_callback, ) # Keep the test alive until the event is triggered. keep_alive_event.wait() @@ -196,11 +197,13 @@ def test_aync_call_with_retries_succeeds(self): # Setup: Basic server, echos input. def handler(request, _): return make_response(request.nonce) + # Set up a callback which: # 1) Verifies the correct value has been returned in the future. # 2) Triggers the keep_alive_event, notifying when it is ok # to kill the server. keep_alive_event = threading.Event() + def check_value(future): self.assertEqual(make_response(42), future.result()) keep_alive_event.set() @@ -209,9 +212,7 @@ def check_value(future): with server.run() as client: # Execute `async_call_with_retries` with the callback. grpc_util.async_call_with_retries( - client.TestRpc, - make_request(42), - completion_handler=check_value + client.TestRpc, make_request(42), completion_handler=check_value ) # Keep the test alive until the event is triggered. keep_alive_event.wait() @@ -227,9 +228,10 @@ def handler(_, context): # 2) Triggers the keep_alive_event, notifying when it is ok # to kill the server. keep_alive_event = threading.Event() + def check_exception(future): raised = future.exception() - self.assertIsInstance(raised,grpc.RpcError) + self.assertIsInstance(raised, grpc.RpcError) self.assertEqual(grpc.StatusCode.INTERNAL, raised.code()) self.assertEqual("death_ray", raised.details()) keep_alive_event.set() @@ -240,12 +242,11 @@ def check_exception(future): grpc_util.async_call_with_retries( client.TestRpc, make_request(42), - completion_handler=check_exception + completion_handler=check_exception, ) # Keep the test alive until the event is triggered. keep_alive_event.wait() - def test_async_with_retries_fails_after_backoff_on_nonpermanent_error(self): attempt_times = [] fake_time = test_util.FakeTime() @@ -254,8 +255,8 @@ def test_async_with_retries_fails_after_backoff_on_nonpermanent_error(self): def handler(_, context): attempt_times.append(fake_time.time()) context.abort( - grpc.StatusCode.UNAVAILABLE, - f"just a sec {len(attempt_times)}.") + grpc.StatusCode.UNAVAILABLE, f"just a sec {len(attempt_times)}." + ) # Set up a callback which: # 1) Verifies the final raised Exception against expectations. @@ -263,6 +264,7 @@ def handler(_, context): # 3) Triggers the keep_alive_event, notifying when it is ok # to kill the server. keep_alive_event = threading.Event() + def check_exception(future): raised = future.exception() self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.code()) @@ -275,7 +277,6 @@ def check_exception(future): self.assertBetween(attempt_times[5] - attempt_times[4], 32, 64) keep_alive_event.set() - server = TestGrpcServer(handler) with server.run() as client: # Execute `async_call_with_retries` with the callback. Call @@ -285,7 +286,7 @@ def check_exception(future): make_request(42), num_remaining_tries=5, clock=fake_time, - completion_handler=check_exception + completion_handler=check_exception, ) # Keep the test alive until the event is triggered. keep_alive_event.wait() @@ -307,6 +308,7 @@ def handler(request, context): # 3) Triggers the keep_alive_event, notifying when it is ok # to kill the server. keep_alive_event = threading.Event() + def check_value(future): self.assertEqual(make_response(42), future.result()) self.assertLen(attempt_times, 3) @@ -320,7 +322,7 @@ def check_value(future): client.TestRpc, make_request(42), clock=fake_time, - completion_handler=check_value + completion_handler=check_value, ) # Keep the test alive until the event is triggered. keep_alive_event.wait() From ec14e18d3af21a926e204e04ddb20becbe9a8f14 Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Thu, 1 Apr 2021 16:03:22 -0400 Subject: [PATCH 04/11] Oops, no time! --- tensorboard/util/grpc_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index 49844b64a5..388c02e3d7 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -92,6 +92,8 @@ def async_call_with_retries( like the standard `time` module; if not passed, uses the normal module. """ + if clock is None: + clock = time if num_remaining_tries < 0: # This should not happen in the course of normal operations and # indicates a bug in the implementation. From 7af8eafedd2dffd6b076d95ed969a8426d992e62 Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Mon, 12 Apr 2021 11:56:01 -0400 Subject: [PATCH 05/11] log.info on RPC retry --- tensorboard/util/grpc_util.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index 388c02e3d7..c65faa698d 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -131,6 +131,12 @@ def retry_handler(future): return # If able to retry, wait then do so. backoff_secs = _compute_backoff_seconds(num_tries_so_far + 1) + logger.info( + "RPC call %s attempted %d times, retrying in %.1f seconds", + rpc_name, + num_tries_so_far, + backoff_secs, + ) clock.sleep(backoff_secs) async_call_with_retries( api_method=api_method, From e8847c99669f83666dd0d1f794b277b128ec84e4 Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Mon, 12 Apr 2021 12:10:21 -0400 Subject: [PATCH 06/11] Hide recursion-oriented args from public API --- tensorboard/util/grpc_util.py | 56 +++++++++++++++++++++++------- tensorboard/util/grpc_util_test.py | 9 ++--- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index c65faa698d..b34d53bd98 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -52,12 +52,7 @@ def async_call_with_retries( - api_method, - request, - completion_handler, - num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, - num_tries_so_far=0, - clock=None, + api_method, request, completion_handler, clock=None ): """Initiate an asynchronous call to a gRPC stub, with retry logic. @@ -66,11 +61,6 @@ def async_call_with_retries( caller must provide a `completion_handler` argument which will handle the result or exception rising from the gRPC completion. - Retries are handled by recursively calling into this API with fewer - remaining retries, as controlled through the `num_remaining_retries` - argument. Setting `num_remaining_retries` to zero will make just - one attempt at the gRPC call. - Retries are handled with jittered exponential backoff to spread out failures due to request spikes. @@ -92,6 +82,46 @@ def async_call_with_retries( like the standard `time` module; if not passed, uses the normal module. """ + return _async_call_with_retries( + api_method=api_method, + request=request, + completion_handler=completion_handler, + num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, + num_tries_so_far=0, + clock=clock, + ) + + +def _async_call_with_retries( + api_method, + request, + completion_handler, + clock=None, + num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, + num_tries_so_far=0, +): + """Helps `async_call_with_retries` recursion by exposing depth as args. + + Retries are handled by recursively calling into this API with fewer + remaining retries, as controlled through the `num_remaining_retries` + argument. Setting `num_remaining_retries` to zero will make just + one attempt at the gRPC call. + + See `async_call_with_retries` documentation for details on expected usage. + + Args: + api_method: See `async_call_with_retries`. + request: See `async_call_with_retries`. + completion_handler: See `async_call_with_retries`. + clock: See `async_call_with_retries`. + num_remaining_retries: A non-negative integer which indicates how many + more attempts should be made to the gRPC endpoint if this try fails + within an error code which could be recovered from. Set to zero + to call with no retries. + num_tries_so_far: A non-negative integer indicating how many attempts + have been made so far for this gRPC. Used to compute backoff time. + """ + if clock is None: clock = time if num_remaining_tries < 0: @@ -112,7 +142,7 @@ def async_call_with_retries( # The continuation should wrap the completion_handler such that: # * If the grpc call succeeds, we should invoke the completion_handler. # * If there are no more retries, we should invoke the completion_handler. - # * Otherwise, we should invoke async_call_with_retries with one less + # * Otherwise, we should invoke _async_call_with_retries with one less # retry. # def retry_handler(future): @@ -138,7 +168,7 @@ def retry_handler(future): backoff_secs, ) clock.sleep(backoff_secs) - async_call_with_retries( + _async_call_with_retries( api_method=api_method, request=request, completion_handler=completion_handler, diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py index 0d8f6689c3..9e8c4c8d2b 100644 --- a/tensorboard/util/grpc_util_test.py +++ b/tensorboard/util/grpc_util_test.py @@ -268,23 +268,20 @@ def handler(_, context): def check_exception(future): raised = future.exception() self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.code()) - self.assertEqual("just a sec 6.", raised.details()) - self.assertLen(attempt_times, 6) + self.assertEqual("just a sec 5.", raised.details()) + self.assertLen(attempt_times, 5) self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16) self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32) - self.assertBetween(attempt_times[5] - attempt_times[4], 32, 64) keep_alive_event.set() server = TestGrpcServer(handler) with server.run() as client: - # Execute `async_call_with_retries` with the callback. Call - # with explicit num retries (5) instead of default. + # Execute `async_call_with_retries` with the callback. grpc_util.async_call_with_retries( client.TestRpc, make_request(42), - num_remaining_tries=5, clock=fake_time, completion_handler=check_exception, ) From 92c310c2c041625b20aabacdbf3b668aa9325177 Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Mon, 12 Apr 2021 15:00:11 -0400 Subject: [PATCH 07/11] Rename completion_handler -> done_callback --- tensorboard/util/grpc_util.py | 31 +++++++++++++++--------------- tensorboard/util/grpc_util_test.py | 10 +++++----- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index b34d53bd98..7f84620b9b 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -52,13 +52,13 @@ def async_call_with_retries( - api_method, request, completion_handler, clock=None + api_method, request, done_callback, clock=None ): """Initiate an asynchronous call to a gRPC stub, with retry logic. This is similar to the `async_call` API, except that the call is handled asynchronously, and the completion may be handled by another thread. The - caller must provide a `completion_handler` argument which will handle the + caller must provide a `done_callback` argument which will handle the result or exception rising from the gRPC completion. Retries are handled with jittered exponential backoff to spread out failures @@ -69,9 +69,10 @@ def async_call_with_retries( Args: api_method: Callable for the API method to invoke. request: Request protocol buffer to pass to the API method. - completion_handler: A function which takes a `grpc.Future` object as an + done_callback: A function which takes a `grpc.Future` object as an argument and performs the necessary operations on the gRPC response - or error, as required. + or error, as required. See the gRPC documentation for more details + https://grpc.github.io/grpc/python/grpc.html#grpc.Future.add_done_callback num_remaining_retries: A non-negative integer which indicates how many more attempts should be made to the gRPC endpoint if this try fails within an error code which could be recovered from. Set to zero @@ -85,7 +86,7 @@ def async_call_with_retries( return _async_call_with_retries( api_method=api_method, request=request, - completion_handler=completion_handler, + done_callback=done_callback, num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, num_tries_so_far=0, clock=clock, @@ -95,7 +96,7 @@ def async_call_with_retries( def _async_call_with_retries( api_method, request, - completion_handler, + done_callback, clock=None, num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, num_tries_so_far=0, @@ -112,7 +113,7 @@ def _async_call_with_retries( Args: api_method: See `async_call_with_retries`. request: See `async_call_with_retries`. - completion_handler: See `async_call_with_retries`. + done_callback: See `async_call_with_retries`. clock: See `async_call_with_retries`. num_remaining_retries: A non-negative integer which indicates how many more attempts should be made to the gRPC endpoint if this try fails @@ -139,25 +140,25 @@ def _async_call_with_retries( timeout=_GRPC_DEFAULT_TIMEOUT_SECS, metadata=version_metadata(), ) - # The continuation should wrap the completion_handler such that: - # * If the grpc call succeeds, we should invoke the completion_handler. - # * If there are no more retries, we should invoke the completion_handler. + # The continuation should wrap the done_callback such that: + # * If the grpc call succeeds, we should invoke the done_callback. + # * If there are no more retries, we should invoke the done_callback. # * Otherwise, we should invoke _async_call_with_retries with one less # retry. # def retry_handler(future): e = future.exception() if e is None: - completion_handler(future) + done_callback(future) return else: logger.info("RPC call %s got error %s", rpc_name, e) - # If unable to retry, proceed to completion_handler. + # If unable to retry, proceed to done_callback. if e.code() not in _GRPC_RETRYABLE_STATUS_CODES: - completion_handler(future) + done_callback(future) return if num_remaining_tries <= 0: - completion_handler(future) + done_callback(future) return # If able to retry, wait then do so. backoff_secs = _compute_backoff_seconds(num_tries_so_far + 1) @@ -171,7 +172,7 @@ def retry_handler(future): _async_call_with_retries( api_method=api_method, request=request, - completion_handler=completion_handler, + done_callback=done_callback, num_remaining_tries=num_remaining_tries - 1, num_tries_so_far=num_tries_so_far + 1, clock=clock, diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py index 9e8c4c8d2b..59f12b37cc 100644 --- a/tensorboard/util/grpc_util_test.py +++ b/tensorboard/util/grpc_util_test.py @@ -185,7 +185,7 @@ def wrapped_mock_callback(future): grpc_util.async_call_with_retries( client.TestRpc, make_request(42), - completion_handler=wrapped_mock_callback, + done_callback=wrapped_mock_callback, ) # Keep the test alive until the event is triggered. keep_alive_event.wait() @@ -212,7 +212,7 @@ def check_value(future): with server.run() as client: # Execute `async_call_with_retries` with the callback. grpc_util.async_call_with_retries( - client.TestRpc, make_request(42), completion_handler=check_value + client.TestRpc, make_request(42), done_callback=check_value ) # Keep the test alive until the event is triggered. keep_alive_event.wait() @@ -242,7 +242,7 @@ def check_exception(future): grpc_util.async_call_with_retries( client.TestRpc, make_request(42), - completion_handler=check_exception, + done_callback=check_exception, ) # Keep the test alive until the event is triggered. keep_alive_event.wait() @@ -283,7 +283,7 @@ def check_exception(future): client.TestRpc, make_request(42), clock=fake_time, - completion_handler=check_exception, + done_callback=check_exception, ) # Keep the test alive until the event is triggered. keep_alive_event.wait() @@ -319,7 +319,7 @@ def check_value(future): client.TestRpc, make_request(42), clock=fake_time, - completion_handler=check_value, + done_callback=check_value, ) # Keep the test alive until the event is triggered. keep_alive_event.wait() From b9cfd5bdcf748329c4d87b9c2f8fc227a6906afc Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Mon, 12 Apr 2021 15:08:19 -0400 Subject: [PATCH 08/11] black --- tensorboard/util/grpc_util.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index 7f84620b9b..adcd3236c7 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -51,9 +51,7 @@ _VERSION_METADATA_KEY = "tensorboard-version" -def async_call_with_retries( - api_method, request, done_callback, clock=None -): +def async_call_with_retries(api_method, request, done_callback, clock=None): """Initiate an asynchronous call to a gRPC stub, with retry logic. This is similar to the `async_call` API, except that the call is handled From fa874c6cc053aad8bead0f5ce2da447ef5f5d4f3 Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Tue, 13 Apr 2021 13:17:49 -0400 Subject: [PATCH 09/11] Async call now returns an AsyncCallFuture object --- tensorboard/util/grpc_util.py | 183 ++++++++++++++--------------- tensorboard/util/grpc_util_test.py | 136 +++++++-------------- 2 files changed, 132 insertions(+), 187 deletions(-) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index adcd3236c7..d8f0e389b9 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -16,7 +16,9 @@ import enum +import functools import random +import threading import time import grpc @@ -51,7 +53,56 @@ _VERSION_METADATA_KEY = "tensorboard-version" -def async_call_with_retries(api_method, request, done_callback, clock=None): +class AsyncCallFuture: + """Encapsulates the future value of a retriable async gRPC request. + + Abstracts over the set of futures returned by a set of gRPC calls + comprising a single logical gRPC request with retries. Communicates + to the caller the result or exception resulting from the request. + + Args: + completion_event: The constructor should provide a thredding.Event which + will be used to communicate when the set of gRPC requests is complete. + """ + + def __init__(self, completion_event): + self._active_grpc_future = None + self._active_grpc_future_lock = threading.Lock() + self._completion_event = completion_event + + def _set_active_future(self, grpc_future): + if grpc_future is None: + raise RuntimeError( + "_set_active_future invoked with grpc_future=None" + ) + # TODO() maybe add validation logic here that future is not None. + with self._active_grpc_future_lock: + self._active_grpc_future = grpc_future + + def result(self, timeout): + """Analogous to `grpc.Future.result`. Returns the value or exception. + + This method will wait until the full set of gRPC requests is complete + and then act as `grpc.Future.result` for the single gRPC invocation + corresponding to the first successful call or final failure, as + appropriate. + + Args: + timeout: How long to wait in seconds before giving up and raising + + Returns: + The result of the future corresponding to the single gRPC + corresponding to the successful call or final failure, as + appropriate. + """ + self._completion_event.wait(timeout) + with self._active_grpc_future_lock: + if self._active_grpc_future is None: + raise RuntimeError("AsyncFuture never had an active future set") + return self._active_grpc_future.result() + + +def async_call_with_retries(api_method, request, clock=None): """Initiate an asynchronous call to a gRPC stub, with retry logic. This is similar to the `async_call` API, except that the call is handled @@ -67,116 +118,62 @@ def async_call_with_retries(api_method, request, done_callback, clock=None): Args: api_method: Callable for the API method to invoke. request: Request protocol buffer to pass to the API method. - done_callback: A function which takes a `grpc.Future` object as an - argument and performs the necessary operations on the gRPC response - or error, as required. See the gRPC documentation for more details - https://grpc.github.io/grpc/python/grpc.html#grpc.Future.add_done_callback - num_remaining_retries: A non-negative integer which indicates how many - more attempts should be made to the gRPC endpoint if this try fails - within an error code which could be recovered from. Set to zero - to call with no retries. - num_tries_so_far: A non-negative integer indicating how many attempts - have been made so far for this gRPC. Used to compute backoff time. clock: an interface object supporting `time()` and `sleep()` methods like the standard `time` module; if not passed, uses the normal module. + Returns: + An `AsyncCallFuture` which will encapsulate the `grpc.Future` + corresponding to the gRPC call which either completes successfully or + represents the final try. """ - return _async_call_with_retries( - api_method=api_method, - request=request, - done_callback=done_callback, - num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, - num_tries_so_far=0, - clock=clock, - ) - - -def _async_call_with_retries( - api_method, - request, - done_callback, - clock=None, - num_remaining_tries=_GRPC_RETRY_MAX_ATTEMPTS - 1, - num_tries_so_far=0, -): - """Helps `async_call_with_retries` recursion by exposing depth as args. - - Retries are handled by recursively calling into this API with fewer - remaining retries, as controlled through the `num_remaining_retries` - argument. Setting `num_remaining_retries` to zero will make just - one attempt at the gRPC call. - - See `async_call_with_retries` documentation for details on expected usage. - - Args: - api_method: See `async_call_with_retries`. - request: See `async_call_with_retries`. - done_callback: See `async_call_with_retries`. - clock: See `async_call_with_retries`. - num_remaining_retries: A non-negative integer which indicates how many - more attempts should be made to the gRPC endpoint if this try fails - within an error code which could be recovered from. Set to zero - to call with no retries. - num_tries_so_far: A non-negative integer indicating how many attempts - have been made so far for this gRPC. Used to compute backoff time. - """ - if clock is None: clock = time - if num_remaining_tries < 0: - # This should not happen in the course of normal operations and - # indicates a bug in the implementation. - raise ValueError( - "num_remaining_tries=%d. expected >= 0." % num_remaining_tries + logger.debug("Async RPC call %s with request: %r", api_method, request) + + completion_event = threading.Event() + async_future = AsyncCallFuture(completion_event) + + def async_call(handler): + """Invokes the gRPC future and orchestrates it via the AsyncCallFuture.""" + future = api_method.future( + request, + timeout=_GRPC_DEFAULT_TIMEOUT_SECS, + metadata=version_metadata(), ) - # We can't actually use api_method.__name__ because it's not a real method, - # it's a special gRPC callable instance that doesn't expose the method name. - rpc_name = request.__class__.__name__.replace("Request", "") - logger.debug("Async RPC call %s with request: %r", rpc_name, request) - future = api_method.future( - request, - timeout=_GRPC_DEFAULT_TIMEOUT_SECS, - metadata=version_metadata(), - ) - # The continuation should wrap the done_callback such that: - # * If the grpc call succeeds, we should invoke the done_callback. - # * If there are no more retries, we should invoke the done_callback. - # * Otherwise, we should invoke _async_call_with_retries with one less - # retry. - # - def retry_handler(future): + # Ensure we set the active future before invoking thedone callback, to avoid + # the case where the done callback completes immediately and triggers + # completion event while async_future still holds the old future. + async_future._set_active_future(future) + future.add_done_callback(handler) + + # retry_handler is the continuation of the `async_call`. It should: + # * If the grpc call succeeds: trigger the `completion_event`. + # * If there are no more retries: trigger the `completion_event`. + # * Otherwise, invoke a new async_call with the same + # retry_handler. + def retry_handler(future, num_attempts): e = future.exception() if e is None: - done_callback(future) + completion_event.set() return else: - logger.info("RPC call %s got error %s", rpc_name, e) - # If unable to retry, proceed to done_callback. + logger.info("RPC call %s got error %s", api_method, e) + # If unable to retry, proceed to completion. if e.code() not in _GRPC_RETRYABLE_STATUS_CODES: - done_callback(future) + completion_event.set() return - if num_remaining_tries <= 0: - done_callback(future) + if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS: + completion_event.set() return # If able to retry, wait then do so. - backoff_secs = _compute_backoff_seconds(num_tries_so_far + 1) - logger.info( - "RPC call %s attempted %d times, retrying in %.1f seconds", - rpc_name, - num_tries_so_far, - backoff_secs, - ) + backoff_secs = _compute_backoff_seconds(num_attempts) clock.sleep(backoff_secs) - _async_call_with_retries( - api_method=api_method, - request=request, - done_callback=done_callback, - num_remaining_tries=num_remaining_tries - 1, - num_tries_so_far=num_tries_so_far + 1, - clock=clock, + async_call( + functools.partial(retry_handler, num_attempts=num_attempts + 1) ) - future.add_done_callback(retry_handler) + async_call(functools.partial(retry_handler, num_attempts=1)) + return async_future def _compute_backoff_seconds(num_attempts): diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py index 59f12b37cc..2315e2e20b 100644 --- a/tensorboard/util/grpc_util_test.py +++ b/tensorboard/util/grpc_util_test.py @@ -163,89 +163,53 @@ def handler(request, context): class AsyncCallWithRetriesTest(tb_test.TestCase): - def test_aync_call_with_retries_invokes_callback(self): + def test_aync_call_with_retries_completes(self): # Setup: Basic server, echos input. def handler(request, _): return make_response(request.nonce) - # Set up a callback which: - # 1) Records that it has been executed (mock_callback). - # 2) Triggers the keep_alive_event, notifying when it is ok - # to kill the server. - keep_alive_event = threading.Event() - mock_callback = mock.Mock() - - def wrapped_mock_callback(future): - mock_callback(future) - keep_alive_event.set() - server = TestGrpcServer(handler) with server.run() as client: - # Execute `async_call_with_retries` with the callback. - grpc_util.async_call_with_retries( + # Execute `async_call_with_retries` + future = grpc_util.async_call_with_retries( client.TestRpc, make_request(42), - done_callback=wrapped_mock_callback, ) - # Keep the test alive until the event is triggered. - keep_alive_event.wait() - - # Expect: callback is invoked. - mock_callback.assert_called_once() + # Wait for completion & collect result + future.result(2) def test_aync_call_with_retries_succeeds(self): # Setup: Basic server, echos input. def handler(request, _): return make_response(request.nonce) - # Set up a callback which: - # 1) Verifies the correct value has been returned in the future. - # 2) Triggers the keep_alive_event, notifying when it is ok - # to kill the server. - keep_alive_event = threading.Event() - - def check_value(future): - self.assertEqual(make_response(42), future.result()) - keep_alive_event.set() - server = TestGrpcServer(handler) with server.run() as client: # Execute `async_call_with_retries` with the callback. - grpc_util.async_call_with_retries( - client.TestRpc, make_request(42), done_callback=check_value + future = grpc_util.async_call_with_retries( + client.TestRpc, make_request(42) ) - # Keep the test alive until the event is triggered. - keep_alive_event.wait() + # Verify the correct value has been returned in the future. + self.assertEqual(make_response(42), future.result(2)) def test_async_call_with_retries_fails_immediately_on_permanent_error(self): # Setup: Server which fails with an ISE. def handler(_, context): context.abort(grpc.StatusCode.INTERNAL, "death_ray") - # Set up a callback which: - # 1) Verifies the future raises an Exception which is the right type and - # carries the right message. - # 2) Triggers the keep_alive_event, notifying when it is ok - # to kill the server. - keep_alive_event = threading.Event() - - def check_exception(future): - raised = future.exception() - self.assertIsInstance(raised, grpc.RpcError) - self.assertEqual(grpc.StatusCode.INTERNAL, raised.code()) - self.assertEqual("death_ray", raised.details()) - keep_alive_event.set() - server = TestGrpcServer(handler) with server.run() as client: - # Execute `async_call_with_retries` with the callback. - grpc_util.async_call_with_retries( + # Execute `async_call_with_retries` + future = grpc_util.async_call_with_retries( client.TestRpc, make_request(42), - done_callback=check_exception, ) - # Keep the test alive until the event is triggered. - keep_alive_event.wait() + # Expect that the future raises an Exception which is the + # right type and carries the right message. + with self.assertRaises(grpc.RpcError) as raised: + future.result(2) + self.assertEqual(grpc.StatusCode.INTERNAL, raised.exception.code()) + self.assertEqual("death_ray", raised.exception.details()) def test_async_with_retries_fails_after_backoff_on_nonpermanent_error(self): attempt_times = [] @@ -258,35 +222,28 @@ def handler(_, context): grpc.StatusCode.UNAVAILABLE, f"just a sec {len(attempt_times)}." ) - # Set up a callback which: - # 1) Verifies the final raised Exception against expectations. - # 2) Verifies the number of attempts and delays between them. - # 3) Triggers the keep_alive_event, notifying when it is ok - # to kill the server. - keep_alive_event = threading.Event() - - def check_exception(future): - raised = future.exception() - self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.code()) - self.assertEqual("just a sec 5.", raised.details()) - self.assertLen(attempt_times, 5) - self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) - self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) - self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16) - self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32) - keep_alive_event.set() - server = TestGrpcServer(handler) with server.run() as client: - # Execute `async_call_with_retries` with the callback. - grpc_util.async_call_with_retries( + # Execute `async_call_with_retries` against the scripted server. + future = grpc_util.async_call_with_retries( client.TestRpc, make_request(42), clock=fake_time, - done_callback=check_exception, ) - # Keep the test alive until the event is triggered. - keep_alive_event.wait() + # Expect that the future raises an Exception which is the right + # type and carries the right message. + with self.assertRaises(grpc.RpcError) as raised: + future.result(2) + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, raised.exception.code() + ) + self.assertEqual("just a sec 5.", raised.exception.details()) + # Verify the number of attempts and delays between them. + self.assertLen(attempt_times, 5) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) + self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16) + self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32) def test_async_with_retries_succeeds_after_backoff_on_transient_error(self): attempt_times = [] @@ -299,30 +256,21 @@ def handler(request, context): context.abort(grpc.StatusCode.UNAVAILABLE, "foo") return make_response(request.nonce) - # Set up a callback which: - # 1) Verifies the response contains the expected value. - # 2) Verifies the number of attempts and delays between them. - # 3) Triggers the keep_alive_event, notifying when it is ok - # to kill the server. - keep_alive_event = threading.Event() - - def check_value(future): - self.assertEqual(make_response(42), future.result()) - self.assertLen(attempt_times, 3) - self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) - self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) - keep_alive_event.set() - server = TestGrpcServer(handler) with server.run() as client: - grpc_util.async_call_with_retries( + # Execute `async_call_with_retries` against the scripted server. + future = grpc_util.async_call_with_retries( client.TestRpc, make_request(42), clock=fake_time, - done_callback=check_value, ) - # Keep the test alive until the event is triggered. - keep_alive_event.wait() + # Expect: + # 1) The response contains the expected value. + # 2) The number of attempts and delays between them. + self.assertEqual(make_response(42), future.result(2)) + self.assertLen(attempt_times, 3) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) class VersionMetadataTest(tb_test.TestCase): From 1191037dfb6c9b397fa9e97bedc421b6821f16fc Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Tue, 13 Apr 2021 13:23:49 -0400 Subject: [PATCH 10/11] Remove unused import "mock" --- tensorboard/util/grpc_util_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py index 2315e2e20b..21d83ba8c8 100644 --- a/tensorboard/util/grpc_util_test.py +++ b/tensorboard/util/grpc_util_test.py @@ -18,7 +18,6 @@ import contextlib import hashlib import threading -from unittest import mock from concurrent import futures import grpc From 7a8a24d2065e7a86c5f62191d9cd748facdc87dd Mon Sep 17 00:00:00 2001 From: Stanley Bileschi Date: Wed, 14 Apr 2021 11:59:08 -0400 Subject: [PATCH 11/11] Reviewer comments. Fix an issue with timeout exception raising. --- tensorboard/util/grpc_util.py | 28 ++++++++++++++++++---------- tensorboard/util/grpc_util_test.py | 24 ++++++++++++++---------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index d8f0e389b9..7e5fce2179 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -61,7 +61,7 @@ class AsyncCallFuture: to the caller the result or exception resulting from the request. Args: - completion_event: The constructor should provide a thredding.Event which + completion_event: The constructor should provide a `threding.Event` which will be used to communicate when the set of gRPC requests is complete. """ @@ -73,9 +73,8 @@ def __init__(self, completion_event): def _set_active_future(self, grpc_future): if grpc_future is None: raise RuntimeError( - "_set_active_future invoked with grpc_future=None" + "_set_active_future invoked with grpc_future=None." ) - # TODO() maybe add validation logic here that future is not None. with self._active_grpc_future_lock: self._active_grpc_future = grpc_future @@ -88,14 +87,22 @@ def result(self, timeout): appropriate. Args: - timeout: How long to wait in seconds before giving up and raising + timeout: How long to wait in seconds before giving up and raising. Returns: The result of the future corresponding to the single gRPC - corresponding to the successful call or final failure, as - appropriate. + corresponding to the successful call. + + Raises: + * `grpc.FutureTimeoutError` if timeout seconds elapse before the gRPC + calls could complete, including waits and retries. + * The exception corresponding to the last non-retryable gRPC request + in the case that a successful gRPC request was not made. """ - self._completion_event.wait(timeout) + if not self._completion_event.wait(timeout): + raise grpc.FutureTimeoutError( + f"AsyncCallFuture timed out after {timeout} seconds" + ) with self._active_grpc_future_lock: if self._active_grpc_future is None: raise RuntimeError("AsyncFuture never had an active future set") @@ -140,9 +147,10 @@ def async_call(handler): timeout=_GRPC_DEFAULT_TIMEOUT_SECS, metadata=version_metadata(), ) - # Ensure we set the active future before invoking thedone callback, to avoid - # the case where the done callback completes immediately and triggers - # completion event while async_future still holds the old future. + # Ensure we set the active future before invoking the done callback, to + # avoid the case where the done callback completes immediately and + # triggers completion event while async_future still holds the old + # future. async_future._set_active_future(future) future.add_done_callback(handler) diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py index 21d83ba8c8..dc566401cf 100644 --- a/tensorboard/util/grpc_util_test.py +++ b/tensorboard/util/grpc_util_test.py @@ -18,6 +18,7 @@ import contextlib import hashlib import threading +import time from concurrent import futures import grpc @@ -162,24 +163,24 @@ def handler(request, context): class AsyncCallWithRetriesTest(tb_test.TestCase): - def test_aync_call_with_retries_completes(self): + def test_aync_call_with_retries_succeeds(self): # Setup: Basic server, echos input. def handler(request, _): return make_response(request.nonce) server = TestGrpcServer(handler) with server.run() as client: - # Execute `async_call_with_retries` + # Execute `async_call_with_retries` with the callback. future = grpc_util.async_call_with_retries( - client.TestRpc, - make_request(42), + client.TestRpc, make_request(42) ) - # Wait for completion & collect result - future.result(2) + # Verify the correct value has been returned in the future. + self.assertEqual(make_response(42), future.result(2)) - def test_aync_call_with_retries_succeeds(self): - # Setup: Basic server, echos input. + def test_aync_call_raises_at_timeout(self): + # Setup: Server waits 0.5 seconds before echoing. def handler(request, _): + time.sleep(0.5) return make_response(request.nonce) server = TestGrpcServer(handler) @@ -188,8 +189,11 @@ def handler(request, _): future = grpc_util.async_call_with_retries( client.TestRpc, make_request(42) ) - # Verify the correct value has been returned in the future. - self.assertEqual(make_response(42), future.result(2)) + # Request the result in 0.01s, critically less time than the server + # will take to respond. Verify that request will cause an + # appropriate exception. + with self.assertRaisesRegex(grpc.FutureTimeoutError, "timed out"): + future.result(0.01) def test_async_call_with_retries_fails_immediately_on_permanent_error(self): # Setup: Server which fails with an ISE.