-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmemory.py
36 lines (28 loc) · 1.03 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from config import *
from collections import deque
import numpy as np
import random
class ReplayMemory:
    """Fixed-capacity FIFO buffer of transitions for experience replay.

    Each entry is a ``(history, action, reward, done)`` tuple appended by
    :meth:`push`. :meth:`sample_mini_batch` returns windows of
    ``HISTORY_SIZE + 1`` consecutive entries, stacking their ``history``
    fields so a caller can slice both the current and the next state out of
    one array.
    """

    def __init__(self):
        # deque(maxlen=...) silently drops the oldest entry once full.
        self.memory = deque(maxlen=Memory_capacity)

    def push(self, history, action, reward, done):
        """Append one transition, evicting the oldest one when at capacity."""
        self.memory.append((history, action, reward, done))

    def sample_mini_batch(self, frame):
        """Sample ``batch_size`` random windows of consecutive transitions.

        Args:
            frame: Number of frames the caller has pushed so far. The usable
                range is also capped by the current buffer length, so an
                over-counted ``frame`` cannot index past the end of memory.

        Returns:
            A list of ``batch_size`` tuples ``(histories, action, reward,
            done)``. ``histories`` is ``np.stack`` (new leading axis) of the
            ``history`` field of ``HISTORY_SIZE + 1`` consecutive entries;
            ``action``, ``reward`` and ``done`` are taken from window entry
            ``HISTORY_SIZE - 1``, the last transition of the first
            ``HISTORY_SIZE`` entries.

        Raises:
            ValueError: If fewer than ``batch_size`` complete windows exist.
        """
        window = HISTORY_SIZE + 1
        available = min(frame, len(self.memory))
        # A window starting at ``start`` covers start .. start + HISTORY_SIZE,
        # so valid starts are 0 .. available - window (inclusive), i.e.
        # ``available - HISTORY_SIZE`` of them.
        num_starts = available - HISTORY_SIZE
        if num_starts < batch_size:
            raise ValueError(
                f"need at least {batch_size} complete windows of {window} "
                f"transitions, but only {max(num_starts, 0)} are available"
            )

        # NOTE(review): a window may straddle an episode boundary (a `done`
        # entry before its last element); confirm callers tolerate this.
        # NOTE(perf): deque indexing is O(1) at the ends but O(n) toward the
        # middle, so random access into a large buffer is not constant-time.
        mini_batch = []
        for start in random.sample(range(num_starts), batch_size):
            entries = [self.memory[start + offset] for offset in range(window)]
            # Build the stacked history directly; np.array() over the ragged
            # (ndarray, scalar, ...) tuples raises ValueError on NumPy >= 1.24.
            histories = np.stack([entry[0] for entry in entries], axis=0)
            _, action, reward, done = entries[HISTORY_SIZE - 1]
            mini_batch.append((histories, action, reward, done))
        return mini_batch

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)