-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
103 lines (78 loc) · 3.06 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
"""
This downloading script is modified from:
https://github.com/sorki/python-mnist
"""
import os
import struct
from array import array
import numpy as np
class MNIST(object):
    """Reader for MNIST stored as raw, uncompressed IDX files.

    All four files are looked up inside ``path`` under the file names set in
    ``__init__``.
    """

    def __init__(self, path='.'):
        self.path = path
        self.test_img_fname = 't10k-images.idx3-ubyte'
        self.test_lbl_fname = 't10k-labels.idx1-ubyte'
        self.train_img_fname = 'train-images.idx3-ubyte'
        self.train_lbl_fname = 'train-labels.idx1-ubyte'
        # Filled in (as numpy arrays) by load_testing() / load_training().
        self.test_images = []
        self.test_labels = []
        self.train_images = []
        self.train_labels = []

    def load_testing(self):
        """Load the test split, store it on the instance and return it.

        Returns:
            (images, labels) as numpy arrays; ``images`` has one row of
            ``rows * cols`` pixel values per sample.
        """
        ims, labels = self.load(os.path.join(self.path, self.test_img_fname),
                                os.path.join(self.path, self.test_lbl_fname))
        self.test_images = np.array(ims)
        self.test_labels = np.array(labels)
        return self.test_images, self.test_labels

    def load_training(self):
        """Load the training split, store it on the instance and return it.

        Returns:
            (images, labels) as numpy arrays; ``images`` has one row of
            ``rows * cols`` pixel values per sample.
        """
        ims, labels = self.load(os.path.join(self.path, self.train_img_fname),
                                os.path.join(self.path, self.train_lbl_fname))
        self.train_images = np.array(ims)
        self.train_labels = np.array(labels)
        return self.train_images, self.train_labels

    @classmethod
    def load(cls, path_img, path_lbl):
        """Parse one IDX image file and its matching IDX label file.

        Args:
            path_img: path to the image file (big-endian header: magic 2051,
                count, rows, cols; then one unsigned byte per pixel).
            path_lbl: path to the label file (big-endian header: magic 2049,
                count; then one unsigned byte per label).

        Returns:
            (images, labels): ``images`` is a list with one list of
            ``rows * cols`` ints per sample; ``labels`` is an ``array('B')``.

        Raises:
            ValueError: on a wrong magic number, a label count that disagrees
                with its header, an image count that differs from the label
                count, or image data shorter than the header declares.
        """
        with open(path_lbl, 'rb') as fh:
            magic, n_labels = struct.unpack(">II", fh.read(8))
            if magic != 2049:
                raise ValueError('Magic number mismatch, expected 2049, '
                                 'got {}'.format(magic))
            labels = array("B", fh.read())
        if len(labels) != n_labels:
            raise ValueError('Label file declares {} labels but contains '
                             '{}'.format(n_labels, len(labels)))

        with open(path_img, 'rb') as fh:
            magic, n_images, rows, cols = struct.unpack(">IIII", fh.read(16))
            if magic != 2051:
                raise ValueError('Magic number mismatch, expected 2051, '
                                 'got {}'.format(magic))
            image_data = array("B", fh.read())

        # Mismatched counts would silently pair images with the wrong labels.
        if n_images != n_labels:
            raise ValueError('Image count {} does not match label count '
                             '{}'.format(n_images, n_labels))
        pixels = rows * cols
        if len(image_data) < n_images * pixels:
            raise ValueError('Image file is truncated: expected {} bytes of '
                             'pixel data, got {}'.format(n_images * pixels,
                                                          len(image_data)))

        # tolist() yields plain Python ints, matching the original row lists.
        images = [image_data[i * pixels:(i + 1) * pixels].tolist()
                  for i in range(n_images)]
        return images, labels

    @classmethod
    def display(cls, img, width=28, threshold=200):
        """Render a flat image as ASCII art.

        Pixels greater than ``threshold`` become ``'@'``, all others ``'.'``.
        A newline is emitted before every ``width`` pixels, so the result
        starts with ``'\\n'`` for non-empty input.
        """
        parts = []
        for i, value in enumerate(img):
            if i % width == 0:
                parts.append('\n')
            parts.append('@' if value > threshold else '.')
        return ''.join(parts)
def load_data(path='data/data.npz'):
    """Load the cached MNIST arrays from a ``.npz`` archive.

    Args:
        path: archive location; defaults to the file the ``__main__`` block
            of this module writes.

    Returns:
        (x_train, y_train, x_test, y_test) as numpy arrays.
    """
    # np.load on an .npz returns an NpzFile holding an open file handle;
    # the context manager closes it after the arrays are read into memory.
    with np.load(path) as archive:
        return (archive['x_train'], archive['y_train'],
                archive['x_test'], archive['y_test'])
if __name__ == '__main__':
    # Convert the raw IDX files in ./data into a single .npz cache the first
    # time; subsequent runs reuse the cache.
    if not os.path.exists('data/data.npz'):
        data = MNIST('./data')
        data.load_training()
        data.load_testing()
        print("save data in data/data.npz!")
        np.savez('data/data.npz', x_train=data.train_images, y_train=data.train_labels,
                 x_test=data.test_images, y_test=data.test_labels)
    x_train, y_train, x_test, y_test = load_data()
    # Single-argument print(...) produces the same output on Python 2 and 3,
    # whereas the bare print statement is a SyntaxError on Python 3.
    print(x_test)
    print('success!')