-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDataBatch.py
47 lines (39 loc) · 1.16 KB
/
DataBatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
class DataBatch(object):
    """Serve sequential mini-batches from a paired feature/label dataset.

    ``x`` and ``y`` are held as NumPy arrays and kept aligned: every shuffle
    applies the same permutation to both. Batches are read in order. When
    fewer than ``batch_size`` unread samples remain, those leftover samples
    are skipped and a new pass starts from the beginning (after reshuffling,
    if ``shuffle`` is true).
    """

    def __init__(self, x, y, shuffle=False):
        """Store the dataset, optionally in a random order.

        Args:
            x: Samples; array-like, indexed along its first axis.
            y: Targets; array-like with the same length as ``x``.
            shuffle: If true, permute the data now and again at the start of
                every later pass in ``next_batch``.

        Raises:
            ValueError: If ``x`` and ``y`` have different lengths.
        """
        # Explicit check rather than ``assert`` so it is not stripped by ``-O``.
        if len(x) != len(y):
            raise ValueError(
                "x and y must have the same length, got {0} and {1}".format(
                    len(x), len(y)))
        # Leave existing ndarrays (including subclasses) as they are.
        if not isinstance(x, np.ndarray):
            x = np.asarray(x)
        if not isinstance(y, np.ndarray):
            y = np.asarray(y)
        self._size = len(x)
        if shuffle:
            # One permutation for both arrays keeps sample/label pairs aligned.
            index = np.random.permutation(self._size)
            x = x[index]
            y = y[index]
        self._X = x
        self._y = y
        self._shuffle = shuffle
        self._index = 0  # position of the next unread sample

    def next_batch(self, batch_size):
        """Return the next ``(x_batch, y_batch)`` pair of ``batch_size`` rows.

        If the current pass has fewer than ``batch_size`` samples left, the
        remainder is skipped and the batch is taken from the start of a new
        pass. With ``shuffle`` enabled, the data is reshuffled first.

        Args:
            batch_size: Number of samples per batch; must satisfy
                ``1 <= batch_size <= size``.

        Returns:
            A tuple ``(x_batch, y_batch)`` of array slices.

        Raises:
            ValueError: If ``batch_size`` is out of range. No state is
                modified in that case.
        """
        # Validate before touching the cursor or reshuffling, so that a
        # rejected call leaves the iterator exactly as it was.
        if not 1 <= batch_size <= self._size:
            raise ValueError(
                "batch_size must be between 1 and {0}, got {1}".format(
                    self._size, batch_size))
        start = self._index
        end = start + batch_size
        if end > self._size:
            # Not enough unread samples for a full batch: begin a new pass.
            if self._shuffle:
                index = np.random.permutation(self._size)
                self._X = self._X[index]
                self._y = self._y[index]
            start = 0
            end = batch_size
        self._index = end
        return self._X[start:end], self._y[start:end]

    @property
    def x(self):
        """The (possibly shuffled) sample array."""
        return self._X

    @property
    def y(self):
        """The (possibly shuffled) target array, aligned with ``x``."""
        return self._y

    @property
    def size(self):
        """Number of samples in the dataset."""
        return self._size