forked from yandexdataschool/Practical_RL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
framebuffer.py
45 lines (40 loc) · 1.84 KB
/
framebuffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np
from gym.spaces.box import Box
from gym.core import Wrapper
class FrameBuffer(Wrapper):
    """Gym wrapper that stacks the most recent ``n_frames`` observations.

    Frames are concatenated along the channel axis: the newest frame occupies
    the first ``n_channels`` channels, followed by progressively older frames.
    Once ``n_frames`` frames are held, the oldest one is dropped on every push.
    """

    def __init__(self, env, n_frames=4, dim_order='tensorflow'):
        """Wrap ``env`` so each observation is a stack of its last ``n_frames`` frames.

        :param env: environment whose ``observation_space.shape`` is
            ``(height, width, n_channels)`` when ``dim_order='tensorflow'`` or
            ``(n_channels, height, width)`` when ``dim_order='pytorch'``.
        :param n_frames: number of consecutive frames kept in the stack.
        :param dim_order: ``'tensorflow'`` (channels last) or ``'pytorch'``
            (channels first).
        :raises ValueError: if ``dim_order`` is neither supported value.
        """
        super(FrameBuffer, self).__init__(env)
        self.dim_order = dim_order
        if dim_order == 'tensorflow':
            height, width, n_channels = env.observation_space.shape
            obs_shape = [height, width, n_channels * n_frames]
        elif dim_order == 'pytorch':
            n_channels, height, width = env.observation_space.shape
            obs_shape = [n_channels * n_frames, height, width]
        else:
            raise ValueError(
                'dim_order should be "tensorflow" or "pytorch", got {}'.format(dim_order))
        # NOTE(review): bounds [0, 1] assume the wrapped env already yields
        # frames scaled to that range — confirm against the upstream wrapper.
        self.observation_space = Box(0.0, 1.0, obs_shape)
        self.framebuffer = np.zeros(obs_shape, 'float32')

    def reset(self, **kwargs):
        """Reset the wrapped env and return a stack holding only the first frame.

        Older frame slots are zero-filled. Keyword arguments are forwarded
        unchanged to ``env.reset``; calling with none behaves as before.
        """
        self.framebuffer = np.zeros_like(self.framebuffer)
        self.update_buffer(self.env.reset(**kwargs))
        return self.framebuffer

    def step(self, action):
        """Advance the wrapped env one step and return the updated frame stack.

        :returns: ``(framebuffer, reward, done, info)``; ``reward``, ``done``
            and ``info`` are passed through from ``env.step`` unchanged.
        """
        new_img, reward, done, info = self.env.step(action)
        self.update_buffer(new_img)
        return self.framebuffer, reward, done, info

    def update_buffer(self, img):
        """Push ``img`` as the newest frame and discard the oldest one.

        :param img: a single frame shaped like ``env.observation_space``.
        :raises ValueError: if ``self.dim_order`` was changed to an
            unsupported value after construction.
        """
        # Cast to the buffer dtype first: otherwise np.concatenate would promote
        # the whole stack to img's dtype (e.g. float64), making the observation
        # dtype drift away from the float32 buffer allocated in __init__.
        img = np.asarray(img, dtype=self.framebuffer.dtype)
        if self.dim_order == 'tensorflow':
            n_channels = self.env.observation_space.shape[-1]
            axis = -1
            older_frames = self.framebuffer[:, :, :-n_channels]
        elif self.dim_order == 'pytorch':
            n_channels = self.env.observation_space.shape[0]
            axis = 0
            older_frames = self.framebuffer[:-n_channels]
        else:
            raise ValueError(
                'dim_order should be "tensorflow" or "pytorch", got {}'.format(self.dim_order))
        # Newest frame goes in front; the oldest n_channels channels fall off.
        self.framebuffer = np.concatenate([img, older_frames], axis=axis)