From 653f65dcbe7f3dce44ecf1c5abf759be7e867cc1 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Wed, 22 Sep 2021 17:21:32 +0100 Subject: [PATCH] [env] Wrap all RPC calls in reset() in retry loop. --- compiler_gym/envs/compiler_env.py | 54 +++++++++++++++++++------------ 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index 993cbad55..b62d010ba 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -705,6 +705,32 @@ def reset( # pylint: disable=arguments-differ :raises TypeError: If no benchmark has been set, and the environment does not have a default benchmark to select from. """ + + def _call_with_retry(stub_method, *args, **kwargs): + """Call the given stub method. If it fails with an "acceptable" + error, abort this reset and retry. + """ + try: + return self.service(stub_method, *args, **kwargs) + except (ServiceError, ServiceTransportError, TimeoutError) as e: + # Abort and retry on error. + self.logger.warning("%s on reset(): %s", type(e).__name__, e) + if self.service: + self.service.close() + self.service = None + + if retry_count >= self._connection_settings.init_max_attempts: + raise OSError( + f"Failed to reset environment after {retry_count - 1} attempts.\n" + f"Last error ({type(e).__name__}): {e}" + ) from e + else: + return self.reset( + benchmark=benchmark, + action_space=action_space, + retry_count=retry_count + 1, + ) + if not self._next_benchmark: raise TypeError( "No benchmark set. Set a benchmark using " @@ -723,7 +749,7 @@ def reset( # pylint: disable=arguments-differ # Stop an existing episode. if self.in_episode: self.logger.debug("Ending session %d", self._session_id) - self.service( + _call_with_retry( self.service.stub.EndSession, EndSessionRequest(session_id=self._session_id), ) @@ -759,7 +785,9 @@ def reset( # pylint: disable=arguments-differ ) try: - reply = self.service(self.service.stub.StartSession, start_session_request) + reply = _call_with_retry( + self.service.stub.StartSession, start_session_request + ) except FileNotFoundError: # The benchmark was not found, so try adding it and repeating the # request. @@ -767,25 +795,9 @@ def reset( # pylint: disable=arguments-differ self.service.stub.AddBenchmark, AddBenchmarkRequest(benchmark=[self._benchmark_in_use.proto]), ) - reply = self.service(self.service.stub.StartSession, start_session_request) - except (ServiceError, ServiceTransportError, TimeoutError) as e: - # Abort and retry on error. - self.logger.warning("%s on reset(): %s", type(e).__name__, e) - if self.service: - self.service.close() - self.service = None - - if retry_count >= self._connection_settings.init_max_attempts: - raise OSError( - f"Failed to reset environment after {retry_count - 1} attempts.\n" - f"Last error ({type(e).__name__}): {e}" - ) from e - else: - return self.reset( - benchmark=benchmark, - action_space=action_space, - retry_count=retry_count + 1, - ) + reply = _call_with_retry( + self.service.stub.StartSession, start_session_request + ) self._session_id = reply.session_id self.observation.session_id = reply.session_id