Skip to content

Commit

Permalink
Merge pull request #423 from ChrisCummins/reset-rety
Browse files Browse the repository at this point in the history
[env] Wrap all RPC calls in reset() in retry loop.
  • Loading branch information
ChrisCummins authored Sep 23, 2021
2 parents 03f0a3b + 653f65d commit 0f76d42
Showing 1 changed file with 33 additions and 21 deletions.
54 changes: 33 additions & 21 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,32 @@ def reset( # pylint: disable=arguments-differ
:raises TypeError: If no benchmark has been set, and the environment
does not have a default benchmark to select from.
"""

def _call_with_retry(stub_method, *args, **kwargs):
"""Call the given stub method. If it fails with an "acceptable"
error, abort this reset and retry.
"""
try:
return self.service(stub_method, *args, **kwargs)
except (ServiceError, ServiceTransportError, TimeoutError) as e:
# Abort and retry on error.
self.logger.warning("%s on reset(): %s", type(e).__name__, e)
if self.service:
self.service.close()
self.service = None

if retry_count >= self._connection_settings.init_max_attempts:
raise OSError(
f"Failed to reset environment after {retry_count - 1} attempts.\n"
f"Last error ({type(e).__name__}): {e}"
) from e
else:
return self.reset(
benchmark=benchmark,
action_space=action_space,
retry_count=retry_count + 1,
)

if not self._next_benchmark:
raise TypeError(
"No benchmark set. Set a benchmark using "
Expand All @@ -723,7 +749,7 @@ def reset( # pylint: disable=arguments-differ
# Stop an existing episode.
if self.in_episode:
self.logger.debug("Ending session %d", self._session_id)
self.service(
_call_with_retry(
self.service.stub.EndSession,
EndSessionRequest(session_id=self._session_id),
)
Expand Down Expand Up @@ -759,33 +785,19 @@ def reset( # pylint: disable=arguments-differ
)

try:
reply = self.service(self.service.stub.StartSession, start_session_request)
reply = _call_with_retry(
self.service.stub.StartSession, start_session_request
)
except FileNotFoundError:
# The benchmark was not found, so try adding it and repeating the
# request.
self.service(
self.service.stub.AddBenchmark,
AddBenchmarkRequest(benchmark=[self._benchmark_in_use.proto]),
)
reply = self.service(self.service.stub.StartSession, start_session_request)
except (ServiceError, ServiceTransportError, TimeoutError) as e:
# Abort and retry on error.
self.logger.warning("%s on reset(): %s", type(e).__name__, e)
if self.service:
self.service.close()
self.service = None

if retry_count >= self._connection_settings.init_max_attempts:
raise OSError(
f"Failed to reset environment after {retry_count - 1} attempts.\n"
f"Last error ({type(e).__name__}): {e}"
) from e
else:
return self.reset(
benchmark=benchmark,
action_space=action_space,
retry_count=retry_count + 1,
)
reply = _call_with_retry(
self.service.stub.StartSession, start_session_request
)

self._session_id = reply.session_id
self.observation.session_id = reply.session_id
Expand Down

0 comments on commit 0f76d42

Please sign in to comment.