from collections import deque
import cv2
import numpy as np
from gym import spaces
from gym.core import Wrapper, ObservationWrapper, RewardWrapper
from gym.spaces import Box
"""
Observation preprocessing and environment tweaks.
Section 8 ("Experimental Setup") of the paper says:
"The Atari experiments used the same input preprocessing as (Mnih et al., 2015)
and an action repeat of 4."
'Mnih et al., 2015' is 'Human-level control through deep reinforcement learning'.
The relevant parts of that paper's Methods section are summarised below.
## Observation preprocessing
'Preprocessing':
- "First, to encode a single frame we take the maximum value for each pixel colour value over the
frame being encoded and the previous frame. This was necessary to remove flickering that is
present in games where some objects appear only in even frames while other objects appear only
in odd frames, an artefact caused by the limited number of sprites Atari 2600 can display at
once."
- "Second, we then extract the Y channel, also known as luminance, from the RGB frame and rescale
it to 84 x 84."
- "The function phi from algorithm 1 described below applies this preprocessing to the m most
recent frames and stacks them to produce the input to the Q-function, in which m = 4, although
the algorithm is robust to different values of m (for example, 3 or 5)."
'Training details':
- "Following previous approaches to playing Atari 2600 games, we also use a simple frame-skipping
technique. More precisely, the agent sees and selects actions on every kth frame instead of
every frame, and its last action is repeated on skipped frames. Because running the emulator
forward for one step requires much less computation than having the agent select an action,
this technique allows the agent to play roughly k times more games without significantly
increasing the runtime. We use k = 4 for all games."
There's some ambiguity about what order to apply these steps in. I think the right order should be:
1. Max over pairs of consecutive frames
So - observation 0: max. over frames 0 and 1
observation 1: max. over frames 1 and 2
etc.
2. Extract luminance and scale
3. Skip frames
So - observation 0: max. over frames 0 and 1
observation 1: max. over frames 4 and 5
etc.
4. Stack frames
So - frame stack 0: max. over frames 0 and 1
max. over frames 4 and 5
max. over frames 8 and 9
max. over frames 12 and 13
frame stack 1: max. over frames 4 and 5
max. over frames 8 and 9
max. over frames 12 and 13
max. over frames 16 and 17
The main ambiguity is whether frame skipping or frame stacking should be done first.
Above we've assumed frame skipping should be done first. If we did frame stacking first, we would
only look at every 4th frame stack, giving:
- Frame stack 0: max. over frames 0 and 1
max. over frames 1 and 2
max. over frames 2 and 3
max. over frames 3 and 4
- Frame stack 4: max. over frames 4 and 5
max. over frames 5 and 6
max. over frames 6 and 7
max. over frames 7 and 8
Note that there's a big difference: frame skip then frame stack gives the agent much more temporal
scope than frame stack then frame skip. In the former, the agent has access to 12 frames' worth of
observations, whereas in the latter, only 4 frames' worth.
## Environment tweaks
'Training details':
- "As the scale of scores varies greatly from game to game, we clipped all positive rewards at 1 and
all negative rewards at -1, leaving 0 rewards unchanged."
- "For games where there is a life counter, the Atari 2600 emulator also sends the number of lives
left in the game, which is then used to mark the end of an episode during training."
'Evaluation procedure':
- "The trained agents were evaluated by playing each game 30 times for up to 5 min each time with
different initial random conditions ('no-op'; see Extended Data Table 1)."
Extended Data Table 1 lists "no-op max" as 30 (set in params.py).
We implement all these steps using a modular set of wrappers, heavily inspired by Baselines'
atari_wrappers.py (https://git.io/vhWWG).
"""
def get_noop_action_index(env):
action_meanings = env.unwrapped.get_action_meanings()
try:
noop_action_index = action_meanings.index('NOOP')
return noop_action_index
except ValueError:
raise Exception("Unsure about environment's no-op action")
class MaxWrapper(Wrapper):
"""
Take maximum pixel values over pairs of frames.
"""
def __init__(self, env):
Wrapper.__init__(self, env)
self.frame_pairs = deque(maxlen=2)
def reset(self):
obs = self.env.reset()
self.frame_pairs.append(obs)
# The first frame returned should be the maximum of frames 0 and 1.
# We get frame 0 from env.reset(). For frame 1, we take a no-op action.
noop_action_index = get_noop_action_index(self.env)
obs, _, done, _ = self.env.step(noop_action_index)
if done:
raise Exception("Environment signalled done during initial frame "
"maxing")
self.frame_pairs.append(obs)
return np.max(self.frame_pairs, axis=0)
def step(self, action):
obs, reward, done, info = self.env.step(action)
self.frame_pairs.append(obs)
obs_maxed = np.max(self.frame_pairs, axis=0)
return obs_maxed, reward, done, info
class ExtractLuminanceAndScaleWrapper(ObservationWrapper):
"""
Convert observations from colour to grayscale, then scale to 84 x 84
"""
def __init__(self, env):
ObservationWrapper.__init__(self, env)
# Important so that gym's play.py picks up the right resolution
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)
def observation(self, obs):
obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
# Bilinear interpolation
obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_LINEAR)
return obs
class FrameStackWrapper(Wrapper):
"""
Stack the most recent 4 frames together.
"""
def __init__(self, env):
Wrapper.__init__(self, env)
self.frame_stack = deque(maxlen=4)
low = np.tile(env.observation_space.low[..., np.newaxis], 4)
high = np.tile(env.observation_space.high[..., np.newaxis], 4)
dtype = env.observation_space.dtype
self.observation_space = Box(low=low, high=high, dtype=dtype)
def _get_obs(self):
obs = np.array(self.frame_stack)
# Switch from (4, 84, 84) to (84, 84, 4), so that we have the right order for inputting
# directly into the convnet with the default channels_last
obs = np.moveaxis(obs, 0, -1)
return obs
def reset(self):
obs = self.env.reset()
self.frame_stack.append(obs)
# The first observation returned should be a stack of observations 0 through 3. We get
# observation 0 from env.reset(). For the rest, we take no-op actions.
noop_action_index = get_noop_action_index(self.env)
for _ in range(3):
obs, _, done, _ = self.env.step(noop_action_index)
if done:
raise Exception("Environment signalled done during initial "
"frame stack")
self.frame_stack.append(obs)
return self._get_obs()
def step(self, action):
obs, reward, done, info = self.env.step(action)
self.frame_stack.append(obs)
return self._get_obs(), reward, done, info
class FrameSkipWrapper(Wrapper):
"""
Repeat the chosen action for 4 frames, only returning the last frame.
"""
def reset(self):
return self.env.reset()
def step(self, action):
reward_sum = 0
for _ in range(4):
obs, reward, done, info = self.env.step(action)
reward_sum += reward
if done:
break
return obs, reward_sum, done, info
class RandomStartWrapper(Wrapper):
"""
Start each episode with a random number of no-ops.
"""
def __init__(self, env, max_n_noops):
Wrapper.__init__(self, env)
self.max_n_noops = max_n_noops
def step(self, action):
return self.env.step(action)
def reset(self):
obs = self.env.reset()
n_noops = np.random.randint(low=0, high=self.max_n_noops + 1)
noop_action_index = get_noop_action_index(self.env)
for _ in range(n_noops):
obs, _, done, _ = self.env.step(noop_action_index)
if done:
raise Exception("Environment signalled done during initial no-ops")
return obs
class NormalizeObservationsWrapper(ObservationWrapper):
"""
Normalize observations to range [0, 1].
"""
def __init__(self, env):
ObservationWrapper.__init__(self, env)
self.observation_space = spaces.Box(low=0.0, high=1.0, shape=env.observation_space.shape,
dtype=np.float32)
def observation(self, obs):
return obs / 255.0
class ClipRewardsWrapper(RewardWrapper):
"""
Clip rewards to range [-1, +1].
"""
def reward(self, reward):
return np.clip(reward, -1, +1)
class EndEpisodeOnLifeLossWrapper(Wrapper):
"""
Send 'episode done' when life lost. (Baselines' atari_wrappers.py claims that this helps with
value estimation. I guess it makes it clear that only actions since the last loss of life
contributed significantly to any rewards in the present.)
"""
def __init__(self, env):
Wrapper.__init__(self, env)
self.done_because_life_lost = False
self.reset_obs = None
def step(self, action):
lives_before = self.env.unwrapped.ale.lives()
obs, reward, done, info = self.env.step(action)
lives_after = self.env.unwrapped.ale.lives()
if done:
self.done_because_life_lost = False
elif lives_after < lives_before:
self.done_because_life_lost = True
self.reset_obs = obs
done = True
return obs, reward, done, info
    def reset(self):
        # done_because_life_lost is set to None once a fake life-loss 'done' has been consumed by
        # reset(), so this assert catches reset() being called again without an intervening 'done'.
        assert self.done_because_life_lost is not None
# If we sent the 'episode done' signal after a loss of a life, then we'll probably get a
# reset signal next. But we shouldn't actually reset! We should just keep on playing until
# the /real/ end-of-episode.
if self.done_because_life_lost:
self.done_because_life_lost = None
return self.reset_obs
else:
return self.env.reset()
def generic_preprocess(env, max_n_noops, clip_rewards=True):
"""
Apply the full sequence of preprocessing steps.
"""
env = RandomStartWrapper(env, max_n_noops)
env = MaxWrapper(env)
env = ExtractLuminanceAndScaleWrapper(env)
env = NormalizeObservationsWrapper(env)
env = FrameSkipWrapper(env)
env = FrameStackWrapper(env)
env = EndEpisodeOnLifeLossWrapper(env)
if clip_rewards:
env = ClipRewardsWrapper(env)
return env
"""
We also have a wrapper to extract hand-crafted features from Pong for early
debug testing.
"""
class PongFeaturesWrapper(ObservationWrapper):
"""
Manually extract the Pong game area, setting paddles/ball to 1.0 and the background to 0.0.
"""
def __init__(self, env):
ObservationWrapper.__init__(self, env)
self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(84, 84), dtype=np.float32)
def observation(self, obs):
"""
Based on Andrej Karpathy's code for Pong with policy gradients:
https://gist.github.com/karpathy/a4166c7fe253700972fcbc77e4ea32c5
"""
obs = np.mean(obs, axis=2) / 255.0 # Convert to [0, 1] grayscale
obs = obs[34:194] # Extract game area
obs = obs[::2, ::2] # Downsample by a factor of 2
obs = np.pad(obs, pad_width=2, mode='constant') # Pad to 84x84
obs[obs <= 0.4] = 0 # Erase background
obs[obs > 0.4] = 1 # Set balls, paddles to 1
return obs
def pong_preprocess(env, max_n_noops):
env = RandomStartWrapper(env, max_n_noops)
env = PongFeaturesWrapper(env)
env = FrameSkipWrapper(env)
env = FrameStackWrapper(env)
return env