experience_replay.py
import numpy as np


class ReplayMemory(object):
    def __init__(self, config):
        """
        input : config object exposing rep_max_size (maximum number of
        stored transitions) and state_shape (shape of a single state)
        """
        self.max_size = config.rep_max_size
        self.shape = config.state_shape
        # Pre-allocated circular buffers; uint8 keeps image states compact.
        self.states = np.zeros([self.max_size] + self.shape, dtype=np.uint8)
        self.next_states = np.zeros([self.max_size] + self.shape, dtype=np.uint8)
        self.actions = np.zeros(self.max_size, dtype=np.int8)
        self.rewards = np.zeros(self.max_size, dtype=np.int8)
        # np.bool was removed in NumPy 1.24; the builtin bool is the replacement.
        self.done = np.zeros(self.max_size, dtype=bool)
        self.idx = 0  # next write position
        self.cnt = 0  # number of transitions stored so far

    def push(self, state, next_state, action, reward, done):
        """
        add a transition to memory, overwriting the oldest entry once full
        """
        assert state.shape == tuple(self.shape)
        self.states[self.idx] = state
        self.next_states[self.idx] = next_state
        self.actions[self.idx] = action
        self.rewards[self.idx] = reward
        self.done[self.idx] = done
        self.idx = (self.idx + 1) % self.max_size  # wrap around (circular buffer)
        self.cnt = min(self.cnt + 1, self.max_size)

    def get_batch(self, batch_size=64):
        """
        returns : states, next_states, actions, rewards, done
        sampled uniformly (with replacement) from the stored transitions
        """
        idxs = np.random.choice(self.cnt, batch_size)
        return (self.states[idxs], self.next_states[idxs],
                self.actions[idxs], self.rewards[idxs], self.done[idxs])
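

# A minimal usage sketch, assuming a hypothetical config object that carries
# the two attributes the constructor reads (rep_max_size and state_shape).
# The buffer size and 84x84 state shape below are illustrative only, not
# values taken from the original repository.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical config: a 1000-transition buffer of 84x84 grayscale states.
    config = SimpleNamespace(rep_max_size=1000, state_shape=[84, 84])
    memory = ReplayMemory(config)

    # Fill part of the buffer with random transitions.
    for _ in range(200):
        state = np.random.randint(0, 256, size=(84, 84), dtype=np.uint8)
        next_state = np.random.randint(0, 256, size=(84, 84), dtype=np.uint8)
        memory.push(state, next_state, action=1, reward=0, done=False)

    # Draw a training batch; each array's leading dimension is batch_size.
    states, next_states, actions, rewards, done = memory.get_batch(batch_size=32)
    print(states.shape)  # (32, 84, 84)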