Restart crashed Unity environments #5553

Merged · 7 commits · Oct 7, 2021
Changes from 5 commits
5 changes: 5 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -15,6 +15,11 @@ and this project adheres to
- Added the capacity to initialize behaviors from any checkpoint and not just the latest one (#5525)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Set gym version in gym-unity to gym release 0.20.0
- Changed default behavior to restart crashed Unity environments rather than exiting.
- Rate & lifetime limits on this are configurable via 3 new yaml options
1. env_params.max_lifetime_restarts (--max-lifetime-restarts) [default=10]
2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]
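
A quick illustrative check (not part of the changelog or this diff): these options live under the `env_settings` block of the trainer configuration, as shown in the docs change below. The snippet assumes PyYAML; the config text is a minimal made-up example.

```python
# Illustrative only: parse an env_settings block like the one documented below
# and read back the three new restart options. Assumes PyYAML is installed.
import yaml

config_text = """
env_settings:
  num_envs: 1
  seed: -1
  max_lifetime_restarts: 10
  restarts_rate_limit_n: 1
  restarts_rate_limit_period_s: 60
"""

env_settings = yaml.safe_load(config_text)["env_settings"]
assert env_settings["max_lifetime_restarts"] == 10          # lifetime cap per worker
assert env_settings["restarts_rate_limit_n"] == 1           # restarts allowed per window
assert env_settings["restarts_rate_limit_period_s"] == 60   # window length in seconds
```
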
### Bug Fixes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
3 changes: 3 additions & 0 deletions docs/Training-ML-Agents.md
@@ -207,6 +207,9 @@ env_settings:
base_port: 5005
num_envs: 1
seed: -1
max_lifetime_restarts: 10
restarts_rate_limit_n: 1
restarts_rate_limit_period_s: 60
```

#### Engine settings
5 changes: 4 additions & 1 deletion ml-agents-envs/mlagents_envs/rpc_communicator.py
@@ -58,7 +58,10 @@ def create_server(self):

try:
# Establish communication grpc
self.server = grpc.server(ThreadPoolExecutor(max_workers=10))
self.server = grpc.server(
thread_pool=ThreadPoolExecutor(max_workers=10),
options=(("grpc.so_reuseport", 1),),
)
self.unity_to_external = UnityToExternalServicerImplementation()
add_UnityToExternalProtoServicer_to_server(
self.unity_to_external, self.server
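
My reading of this change (an assumption, not stated in the diff) is that enabling `grpc.so_reuseport` lets the server created for a restarted worker bind the same port its crashed predecessor used, instead of failing with an address-in-use error. A standalone sketch of that behavior follows; it requires a platform with SO_REUSEPORT support (e.g. Linux), and 5005 is simply the documented default base port, not a value taken from this file.

```python
# Standalone sketch, not ml-agents code: with grpc.so_reuseport enabled, two servers
# in the same process can bind the same port, so the server built for a restarted
# worker does not collide with a socket left over from the crashed one.
from concurrent.futures import ThreadPoolExecutor
import grpc

def make_server(port: int) -> grpc.Server:
    server = grpc.server(
        thread_pool=ThreadPoolExecutor(max_workers=10),
        options=(("grpc.so_reuseport", 1),),
    )
    bound = server.add_insecure_port(f"localhost:{port}")
    assert bound == port, "bind failed (is SO_REUSEPORT available on this platform?)"
    return server

original = make_server(5005)     # stands in for the crashed worker's server
original.start()
replacement = make_server(5005)  # the restarted worker binds the same port
replacement.start()
original.stop(grace=None)
replacement.stop(grace=None)
```
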
20 changes: 20 additions & 0 deletions ml-agents/mlagents/trainers/cli_utils.py
@@ -177,6 +177,26 @@ def _create_parser() -> argparse.ArgumentParser:
"passed to the executable.",
action=DetectDefault,
)
argparser.add_argument(
"--max-lifetime-restarts",
default=10,
help="The max number of times a single Unity executable can crash over its lifetime before ml-agents exits. "
"Can be set to -1 if no limit is desired.",
action=DetectDefault,
)
argparser.add_argument(
"--restarts-rate-limit-n",
default=1,
help="The maximum number of times a single Unity executable can crash over a period of time (period set in "
"restarts-rate-limit-period-s). Can be set to -1 to not use rate limiting with restarts.",
action=DetectDefault,
)
argparser.add_argument(
"--restarts-rate-limit-period-s",
default=60,
help="The period of time --restarts-rate-limit-n applies to.",
action=DetectDefault,
)
argparser.add_argument(
"--torch",
default=False,
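
For illustration, a standalone parser showing the intended semantics of the three new flags, including `-1` to disable a limit. This is not ml-agents' actual `_create_parser`: the real arguments use the `DetectDefault` action shown above and leave type conversion to the settings layer, which the sketch skips.

```python
# Standalone sketch mirroring the three new flags; the DetectDefault action and the
# rest of the real parser are intentionally omitted.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max-lifetime-restarts", type=int, default=10)
parser.add_argument("--restarts-rate-limit-n", type=int, default=1)
parser.add_argument("--restarts-rate-limit-period-s", type=int, default=60)

# -1 disables the lifetime cap; the other two keep their defaults (1 restart per 60 s).
args = parser.parse_args(["--max-lifetime-restarts=-1"])
assert args.max_lifetime_restarts == -1
assert args.restarts_rate_limit_n == 1
assert args.restarts_rate_limit_period_s == 60
```
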
5 changes: 5 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
@@ -803,6 +803,11 @@ class EnvironmentSettings:
base_port: int = parser.get_default("base_port")
num_envs: int = attr.ib(default=parser.get_default("num_envs"))
seed: int = parser.get_default("seed")
max_lifetime_restarts: int = parser.get_default("max_lifetime_restarts")
restarts_rate_limit_n: int = parser.get_default("restarts_rate_limit_n")
restarts_rate_limit_period_s: int = parser.get_default(
"restarts_rate_limit_period_s"
)

@num_envs.validator
def validate_num_envs(self, attribute, value):
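
A reduced sketch of the new fields in isolation (a hypothetical `RestartSettings` class, not the real `EnvironmentSettings`, which pulls its defaults from the CLI parser via `parser.get_default`):

```python
# Reduced sketch of the three new fields only; defaults are hard-coded here, whereas
# the real class reads them from the argparse parser so CLI and YAML stay in sync.
import attr

@attr.s(auto_attribs=True)
class RestartSettings:
    max_lifetime_restarts: int = 10
    restarts_rate_limit_n: int = 1
    restarts_rate_limit_period_s: int = 60

settings = RestartSettings(max_lifetime_restarts=-1)  # -1 means no lifetime cap
assert settings.restarts_rate_limit_period_s == 60
```
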
122 changes: 117 additions & 5 deletions ml-agents/mlagents/trainers/subprocess_env_manager.py
@@ -1,3 +1,4 @@
import datetime
from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
import cloudpickle
import enum
@@ -251,6 +252,14 @@ def __init__(
self.env_workers: List[UnityEnvWorker] = []
self.step_queue: Queue = Queue()
self.workers_alive = 0
self.env_factory = env_factory
self.run_options = run_options
self.env_parameters: Optional[Dict] = None
# Each worker is correlated with a list of times they restarted within the last time period.
self.recent_restart_timestamps: List[List[datetime.datetime]] = [
[] for _ in range(n_env)
]
self.restart_counts: List[int] = [0] * n_env
for worker_idx in range(n_env):
self.env_workers.append(
self.create_worker(
@@ -293,6 +302,105 @@ def _queue_steps(self) -> None:
env_worker.send(EnvironmentCommand.STEP, env_action_info)
env_worker.waiting = True

def _restart_failed_workers(self, first_failure: EnvironmentResponse) -> None:
if first_failure.cmd != EnvironmentCommand.ENV_EXITED:
return
# Drain the step queue to make sure all workers are paused and we have found all concurrent errors.
# Pausing all training is needed since we need to reset all pending training steps as they could be corrupted.
other_failures: Dict[int, Exception] = self._drain_step_queue()
# TODO: Once we use python 3.9 switch to using the | operator to combine dicts.
failures: Dict[int, Exception] = {
**{first_failure.worker_id: first_failure.payload},
**other_failures,
}
for worker_id, ex in failures.items():
self._assert_worker_can_restart(worker_id, ex)
logger.warning(f"Restarting worker[{worker_id}] after '{ex}'")
self.recent_restart_timestamps[worker_id].append(datetime.datetime.now())
self.restart_counts[worker_id] += 1
self.env_workers[worker_id] = self.create_worker(
worker_id, self.step_queue, self.env_factory, self.run_options
)
# The restarts were successful, clear all the existing training trajectories so we don't use corrupted or
# outdated data.
self.reset(self.env_parameters)

def _drain_step_queue(self) -> Dict[int, Exception]:
"""
Drains all steps out of the step queue and returns all exceptions from crashed workers.
This will effectively pause all workers so that they won't do anything until _queue_steps is called.
"""
all_failures = {}
workers_still_pending = {w.worker_id for w in self.env_workers if w.waiting}
deadline = datetime.datetime.now() + datetime.timedelta(minutes=1)
while workers_still_pending and deadline > datetime.datetime.now():
try:
while True:
step: EnvironmentResponse = self.step_queue.get_nowait()
if step.cmd == EnvironmentCommand.ENV_EXITED:
workers_still_pending.add(step.worker_id)
all_failures[step.worker_id] = step.payload
else:
workers_still_pending.remove(step.worker_id)
self.env_workers[step.worker_id].waiting = False
except EmptyQueueException:
pass
if deadline < datetime.datetime.now():
still_waiting = {w.worker_id for w in self.env_workers if w.waiting}
raise TimeoutError(f"Workers {still_waiting} stuck in waiting state")
return all_failures

def _assert_worker_can_restart(self, worker_id: int, exception: Exception) -> None:
"""
Checks if we can recover from an exception from a worker.
If the restart limit is exceeded, or the exception is not a recoverable Unity exception, the exception is re-raised.
"""
if (
isinstance(exception, UnityCommunicationException)
or isinstance(exception, UnityTimeOutException)
or isinstance(exception, UnityEnvironmentException)
or isinstance(exception, UnityCommunicatorStoppedException)
):
if self._worker_has_restart_quota(worker_id):
return
else:
logger.error(
f"Worker {worker_id} exceeded the allowed number of restarts."
)
raise exception
raise exception

def _worker_has_restart_quota(self, worker_id: int) -> bool:
self._drop_old_restart_timestamps(worker_id)
max_lifetime_restarts = self.run_options.env_settings.max_lifetime_restarts
max_limit_check = (
max_lifetime_restarts == -1
or self.restart_counts[worker_id] < max_lifetime_restarts
)

rate_limit_n = self.run_options.env_settings.restarts_rate_limit_n
rate_limit_check = (
rate_limit_n == -1
or len(self.recent_restart_timestamps[worker_id]) < rate_limit_n
)

return rate_limit_check and max_limit_check

def _drop_old_restart_timestamps(self, worker_id: int) -> None:
"""
Drops environment restart timestamps that are outside of the current window.
"""

def _filter(t: datetime.datetime) -> bool:
return t > datetime.datetime.now() - datetime.timedelta(
seconds=self.run_options.env_settings.restarts_rate_limit_period_s
)

self.recent_restart_timestamps[worker_id] = list(
filter(_filter, self.recent_restart_timestamps[worker_id])
)

def _step(self) -> List[EnvironmentStep]:
# Queue steps for any workers which aren't in the "waiting" state.
self._queue_steps()
@@ -306,15 +414,18 @@ def _step(self) -> List[EnvironmentStep]:
while True:
step: EnvironmentResponse = self.step_queue.get_nowait()
if step.cmd == EnvironmentCommand.ENV_EXITED:
env_exception: Exception = step.payload
raise env_exception
self.env_workers[step.worker_id].waiting = False
if step.worker_id not in step_workers:
# If even one env exits try to restart all envs that failed.
self._restart_failed_workers(step)
# Clear state and restart this function.
worker_steps.clear()
step_workers.clear()
self._queue_steps()
elif step.worker_id not in step_workers:
self.env_workers[step.worker_id].waiting = False
worker_steps.append(step)
step_workers.add(step.worker_id)
except EmptyQueueException:
pass

step_infos = self._postprocess_steps(worker_steps)
return step_infos

@@ -339,6 +450,7 @@ def set_env_parameters(self, config: Dict = None) -> None:
EnvironmentParametersSideChannel for each worker.
:param config: Dict of environment parameter keys and values
"""
self.env_parameters = config
for ew in self.env_workers:
ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config)

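
Pulling the restart-budget logic above into a self-contained sketch (simplified and renamed; the real checks live in `_worker_has_restart_quota` and `_drop_old_restart_timestamps`): a worker may be restarted only while it is under both the lifetime cap and the sliding-window rate limit, and `-1` disables either check.

```python
# Self-contained sketch of the restart budget, not the SubprocessEnvManager code itself.
import datetime
from typing import List

class RestartBudget:
    def __init__(self, max_lifetime: int, rate_n: int, period_s: int) -> None:
        self.max_lifetime = max_lifetime
        self.rate_n = rate_n
        self.period_s = period_s
        self.lifetime_count = 0
        self.recent: List[datetime.datetime] = []

    def can_restart(self) -> bool:
        # Drop restart timestamps that fell out of the rate-limit window,
        # mirroring _drop_old_restart_timestamps above.
        cutoff = datetime.datetime.now() - datetime.timedelta(seconds=self.period_s)
        self.recent = [t for t in self.recent if t > cutoff]
        lifetime_ok = self.max_lifetime == -1 or self.lifetime_count < self.max_lifetime
        rate_ok = self.rate_n == -1 or len(self.recent) < self.rate_n
        return lifetime_ok and rate_ok

    def record_restart(self) -> None:
        self.recent.append(datetime.datetime.now())
        self.lifetime_count += 1

budget = RestartBudget(max_lifetime=10, rate_n=1, period_s=60)
assert budget.can_restart()        # nothing recorded yet
budget.record_restart()
assert not budget.can_restart()    # a second restart within 60 s exceeds the rate limit
```

With the defaults (10 lifetime restarts, at most 1 restart per 60-second window), a second crash of the same worker within a minute exhausts the quota, so `_assert_worker_can_restart` re-raises the exception and training exits.
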
65 changes: 63 additions & 2 deletions ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
@@ -1,5 +1,5 @@
from unittest import mock
from unittest.mock import Mock, MagicMock
from unittest.mock import Mock, MagicMock, call, ANY
import unittest
import pytest
from queue import Empty as EmptyQueue
@@ -14,7 +14,10 @@
from mlagents.trainers.env_manager import EnvironmentStep
from mlagents_envs.base_env import BaseEnv
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.exception import (
UnityEnvironmentException,
UnityCommunicationException,
)
from mlagents.trainers.tests.simple_test_envs import (
SimpleEnvironment,
UnexpectedExceptionEnvironment,
@@ -153,6 +156,64 @@ def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
manager.env_workers[1].previous_step,
]

@mock.patch(
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
)
def test_crashed_env_restarts(self, mock_create_worker):
crashing_worker = MockEnvWorker(
0, EnvironmentResponse(EnvironmentCommand.RESET, 0, 0)
)
restarting_worker = MockEnvWorker(
0, EnvironmentResponse(EnvironmentCommand.RESET, 0, 0)
)
healthy_worker = MockEnvWorker(
1, EnvironmentResponse(EnvironmentCommand.RESET, 1, 1)
)
mock_create_worker.side_effect = [
crashing_worker,
healthy_worker,
restarting_worker,
]
manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 2)
manager.step_queue = Mock()
manager.step_queue.get_nowait.side_effect = [
EnvironmentResponse(
EnvironmentCommand.ENV_EXITED,
0,
UnityCommunicationException("Test msg"),
),
EnvironmentResponse(EnvironmentCommand.CLOSED, 0, None),
EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(0, None, {})),
EmptyQueue(),
EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(1, None, {})),
EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(2, None, {})),
EmptyQueue(),
]
step_mock = Mock()
last_steps = [Mock(), Mock(), Mock()]
assert crashing_worker is manager.env_workers[0]
assert healthy_worker is manager.env_workers[1]
crashing_worker.previous_step = last_steps[0]
crashing_worker.waiting = True
healthy_worker.previous_step = last_steps[1]
healthy_worker.waiting = True
manager._take_step = Mock(return_value=step_mock)
manager._step()
healthy_worker.send.assert_has_calls(
[
call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
call(EnvironmentCommand.RESET, ANY),
call(EnvironmentCommand.STEP, ANY),
]
)
restarting_worker.send.assert_has_calls(
[
call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
call(EnvironmentCommand.RESET, ANY),
call(EnvironmentCommand.STEP, ANY),
]
)

@mock.patch("mlagents.trainers.subprocess_env_manager.SubprocessEnvManager._step")
@mock.patch(
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.training_behaviors",