-
Notifications
You must be signed in to change notification settings - Fork 1
/
mnist.py
35 lines (27 loc) · 1.54 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import numpy as np
from sklearn.model_selection import train_test_split
def load_mnist(batch_size, samples_per_epoch=None, is_training=True, use_val_only=False):
    """Load MNIST arrays from ``data/mnist`` and compute batch counts.

    The files are looked up relative to the current working directory under
    ``data/mnist`` using the standard MNIST file names.

    Args:
        batch_size: Number of samples per batch; used only to derive the
            batch counts by integer division.
        samples_per_epoch: If truthy, the number of training batches is
            ``samples_per_epoch // batch_size`` instead of being derived
            from the size of the training split.
        is_training: When true, load the training files and split off a
            validation set; otherwise load the test files.
        use_val_only: Only consulted when ``is_training`` is true. When
            set, empty lists are returned in place of the training arrays.

    Returns:
        When ``is_training`` is true, a 6-tuple
        ``(X_train, X_val, Y_train, Y_val, num_train_batches,
        num_val_batches)`` (``X_train`` and ``Y_train`` are ``[]`` if
        ``use_val_only`` is set). Otherwise a 3-tuple
        ``(X_test, Y_test, num_test_batches)``.

        Images are ``float32`` arrays of shape ``(N, 28, 28, 1)`` scaled by
        ``1 / 255``; labels are ``int32`` arrays of shape ``(N,)``.
    """
    path = os.path.join('data', 'mnist')

    if is_training:
        X = _read_mnist_images(os.path.join(path, 'train-images-idx3-ubyte'))
        Y = _read_mnist_labels(os.path.join(path, 'train-labels-idx1-ubyte'))

        # Hold out 5000 samples for validation. No random_state is passed,
        # so the split is not fixed between calls.
        X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=5000)

        if samples_per_epoch:
            num_train_batches = samples_per_epoch // batch_size
        else:
            num_train_batches = len(Y_train) // batch_size
        num_val_batches = len(Y_val) // batch_size

        if use_val_only:
            return [], X_val, [], Y_val, num_train_batches, num_val_batches
        return X_train, X_val, Y_train, Y_val, num_train_batches, num_val_batches

    X_test = _read_mnist_images(os.path.join(path, 't10k-images-idx3-ubyte'))
    Y_test = _read_mnist_labels(os.path.join(path, 't10k-labels-idx1-ubyte'))
    num_test_batches = len(Y_test) // batch_size
    return X_test, Y_test, num_test_batches


def _read_mnist_images(file_path):
    """Read an image file into a float32 ``(N, 28, 28, 1)`` array scaled by 1/255."""
    # Passing the path (not an open file object) lets NumPy open the file in
    # binary mode and close it again, so no handle is leaked.
    raw = np.fromfile(file_path, dtype=np.uint8)
    # The first 16 bytes are the file header; the rest is pixel data. The
    # sample count is inferred from the payload size.
    return raw[16:].reshape((-1, 28, 28, 1)).astype(np.float32) / 255


def _read_mnist_labels(file_path):
    """Read a label file into a 1-D int32 array."""
    raw = np.fromfile(file_path, dtype=np.uint8)
    # The first 8 bytes are the file header; each remaining byte is one label.
    return raw[8:].reshape(-1).astype(np.int32)