Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SubprocVecEnv Thread Safe #218

Merged
merged 11 commits into from
Mar 10, 2019
7 changes: 3 additions & 4 deletions docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️

.. warning::

When using ``SubprocVecEnv``, Windows users must wrap the code
in an ``if __name__=="__main__":``.
See `stackoverflow question <https://stackoverflow.com/questions/24374288/where-to-put-freeze-support-in-a-python-script>`_
for more information about multiprocessing on Windows using python.
When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":``
if using the ``forkserver`` or ``spawn`` start method (the default). For more information, see Python's
`multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.


DummyVecEnv
Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ For download links, please look at `Github release page <https://github.com/hill
Pre-release 2.4.2a (WIP)
------------------------

- added suport for Dict spaces in DummyVecEnv and SubprocVecEnv. (@AdamGleave)
- added support for Dict spaces in DummyVecEnv and SubprocVecEnv. (@AdamGleave)
- made SubprocVecEnv thread-safe by default; support arbitrary multiprocessing start methods. (@AdamGleave)

Release 2.4.1 (2019-02-11)
--------------------------
Expand Down
34 changes: 25 additions & 9 deletions stable_baselines/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import multiprocessing
from collections import OrderedDict
from multiprocessing import Process, Pipe

import gym
import numpy as np
Expand Down Expand Up @@ -47,20 +47,36 @@ class SubprocVecEnv(VecEnv):
Creates a multiprocess vectorized wrapper for multiple environments

:param env_fns: ([Gym Environment]) Environments to run in subprocesses
:param start_method: (str) method used to start the subprocesses.
Must be one of the methods returned by multiprocessing.get_all_start_methods().
Defaults to 'forkserver' when available, and 'spawn' otherwise.
Both 'forkserver' and 'spawn' are thread-safe, which is important when TensorFlow
sessions or other non thread-safe libraries are used in the parent (see issue #217).
However, compared to 'fork' they incur a small start-up cost and have restrictions on
global variables. For more information, see the multiprocessing documentation.
"""

def __init__(self, env_fns):
def __init__(self, env_fns, start_method=None):
self.waiting = False
self.closed = False
n_envs = len(env_fns)
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(n_envs)])
self.processes = [Process(target=_worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
for process in self.processes:
process.daemon = True # if the main process crashes, we should not cause things to hang

if start_method is None:
# Use thread safe method, see issue #217.
# forkserver faster than spawn but not always available.
forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods()
start_method = 'forkserver' if forkserver_available else 'spawn'
ctx = multiprocessing.get_context(start_method)

self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
self.processes = []
for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
args = (work_remote, remote, CloudpickleWrapper(env_fn))
# daemon=True: if the main process crashes, we should not cause things to hang
process = ctx.Process(target=_worker, args=args, daemon=True)
process.start()
for remote in self.work_remotes:
remote.close()
self.processes.append(process)
work_remote.close()

self.remotes[0].send(('get_spaces', None))
observation_space, action_space = self.remotes[0].recv()
Expand Down
21 changes: 21 additions & 0 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections
import functools
import itertools
import multiprocessing
import pytest
import gym
import numpy as np
Expand Down Expand Up @@ -94,6 +96,8 @@ def make_env():
assert setattr_result == [None for _ in range(2)]
assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]

vec_env.close()


SPACES = collections.OrderedDict([
('discrete', gym.spaces.Discrete(2)),
Expand All @@ -116,6 +120,7 @@ def make_env():
actions = [vec_env.action_space.sample() for _ in range(N_ENVS)]
obs, _rews, dones, _infos = vec_env.step(actions)
obs_assert(obs)
vec_env.close()


def check_vecenv_obs(obs, space):
Expand Down Expand Up @@ -170,3 +175,19 @@ def obs_assert(obs):
check_vecenv_obs(values, inner_space)

return check_vecenv_spaces(vec_env_class, space, obs_assert)


def test_subproc_start_method():
    """SubprocVecEnv should work with every available start method.

    ``None`` exercises the default selection logic (forkserver when
    available, spawn otherwise); an unknown method name must be rejected
    by ``multiprocessing.get_context`` with a ``ValueError``.
    """
    space = gym.spaces.Discrete(2)

    def obs_assert(obs):
        return check_vecenv_obs(obs, space)

    all_methods = [None] + multiprocessing.get_all_start_methods()
    for method in all_methods:
        make_vec_env = functools.partial(SubprocVecEnv, start_method=method)
        check_vecenv_spaces(make_vec_env, space, obs_assert)

    # Invalid start methods propagate multiprocessing's ValueError.
    with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"):
        make_vec_env = functools.partial(SubprocVecEnv, start_method='illegal_method')
        check_vecenv_spaces(make_vec_env, space, obs_assert)