-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathenvs.py
117 lines (98 loc) · 3.61 KB
/
envs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import gym
from gym import spaces
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env import VecEnv
from multiprocessing import Process, Pipe
# cf. https://github.com/openai/baselines
def make_env(env_name, rank, seed):
    """Create the environment ``env_name`` via ``make_atari`` and wrap it with ``wrap_deepmind``.

    The environment is seeded with ``seed + rank`` before it is wrapped, so
    callers passing different ranks with the same base seed get different seeds.
    """
    atari_env = make_atari(env_name)
    atari_env.seed(rank + seed)
    return wrap_deepmind(atari_env)
def worker(remote, parent_remote, env_fn_wrapper):
    """Serve one environment, answering commands received over a pipe.

    remote: this process's end of the pipe; ``(cmd, data)`` tuples are read
        from it and replies are written back to it
    parent_remote: the parent's end of the pipe; closed right away because
        this function never uses it
    env_fn_wrapper: object whose ``x`` attribute is a zero-argument callable
        returning the environment to serve

    Supported commands:
        'step'       -- ``data`` is the action; replies ``(ob, reward, done, info)``.
                        When ``done`` is true the env is reset first and ``ob``
                        is the observation returned by that reset.
        'reset'      -- replies with ``env.reset()``
        'reset_task' -- replies with ``env.reset_task()``
        'get_spaces' -- replies ``(action_space, observation_space)``
        'render'     -- calls ``env.render()``; no reply is sent
        'close'      -- closes ``remote`` and leaves the loop

    Raises NotImplementedError for any other command.
    """
    parent_remote.close()
    env = env_fn_wrapper.x()
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == 'step':
                ob, reward, done, info = env.step(data)
                if done:
                    # Start the next episode immediately so the caller never
                    # has to reset an individual environment itself.
                    ob = env.reset()
                remote.send((ob, reward, done, info))
            elif cmd == 'reset':
                ob = env.reset()
                remote.send(ob)
            elif cmd == 'reset_task':
                ob = env.reset_task()
                remote.send(ob)
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces':
                remote.send((env.action_space, env.observation_space))
            elif cmd == 'render':
                env.render()
            else:
                raise NotImplementedError('unknown worker command: %r' % (cmd,))
    finally:
        # Release the environment on every exit path (normal 'close' as well
        # as an exception), so its resources are not leaked.
        env.close()
class CloudpickleWrapper(object):
    """
    Container for a single object ``x`` that is serialized with cloudpickle
    when the wrapper itself is pickled (multiprocessing would otherwise use pickle)
    """

    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        # Imported here rather than at module level, as in the original code.
        from cloudpickle import dumps
        payload = dumps(self.x)
        return payload

    def __setstate__(self, ob):
        # The state is the byte string produced by __getstate__.
        from pickle import loads
        self.x = loads(ob)
class RenderSubprocVecEnv(VecEnv):
    """Vectorized environment running each env in its own subprocess, with periodic rendering.

    Every environment is served by a ``worker`` process and driven over a
    ``multiprocessing.Pipe``. A 'render' command is sent to all workers each
    time the internal step counter reaches ``render_interval``.
    """

    def __init__(self, env_fns, render_interval):
        """ Minor addition to SubprocVecEnv, automatically renders environments
        env_fns: list of callables; each one is handed to a worker process, which calls it to build its environment
        render_interval: number of step() calls between two render requests
        """
        self.closed = False
        pipes = [Pipe() for _ in range(len(env_fns))]
        self.remotes, self.work_remotes = zip(*pipes)
        # Build every Process object first, then start them in a second pass.
        self.ps = []
        for child_end, parent_end, env_fn in zip(self.work_remotes, self.remotes, env_fns):
            proc = Process(target=worker,
                           args=(child_end, parent_end, CloudpickleWrapper(env_fn)))
            self.ps.append(proc)
        for proc in self.ps:
            proc.daemon = True  # if the main process crashes, we should not cause things to hang
            proc.start()
        # The worker ends of the pipes are closed in this (parent) process.
        for child_end in self.work_remotes:
            child_end.close()
        # The spaces are queried from the first worker only.
        first = self.remotes[0]
        first.send(('get_spaces', None))
        self.action_space, self.observation_space = first.recv()
        self.render_interval = render_interval
        self.render_timer = 0

    def _broadcast(self, cmd):
        """Send ``(cmd, None)`` to every worker, in order."""
        for pipe in self.remotes:
            pipe.send((cmd, None))

    def step(self, actions):
        """Step every environment with its action and return stacked results.

        Returns ``(obs, rews, dones, infos)`` where the first three are
        ``np.stack``-ed across environments and ``infos`` is a tuple.
        """
        for pipe, act in zip(self.remotes, actions):
            pipe.send(('step', act))
        replies = []
        for pipe in self.remotes:
            replies.append(pipe.recv())
        obs, rews, dones, infos = zip(*replies)
        self.render_timer += 1
        if self.render_timer == self.render_interval:
            # Workers do not reply to 'render', so nothing is read back here.
            self._broadcast('render')
            self.render_timer = 0
        stacked_obs = np.stack(obs)
        stacked_rews = np.stack(rews)
        stacked_dones = np.stack(dones)
        return stacked_obs, stacked_rews, stacked_dones, infos

    def reset(self):
        """Reset every environment and return the stacked replies."""
        self._broadcast('reset')
        replies = [pipe.recv() for pipe in self.remotes]
        return np.stack(replies)

    def reset_task(self):
        """Send 'reset_task' to every worker and return the stacked replies."""
        self._broadcast('reset_task')
        replies = [pipe.recv() for pipe in self.remotes]
        return np.stack(replies)

    def close(self):
        """Tell every worker to close and wait for the processes; later calls do nothing."""
        if not self.closed:
            self._broadcast('close')
            for proc in self.ps:
                proc.join()
            self.closed = True

    @property
    def num_envs(self):
        """Number of environments, i.e. the number of parent-side pipe ends."""
        return len(self.remotes)