data reader for mnist #1325

Closed · wants to merge 2 commits
81 changes: 81 additions & 0 deletions demo/mnist/reader.py
@@ -0,0 +1,81 @@
import os
import struct

import numpy as np


class DataReader(object):
    def __init__(self, data, labels, batch_size, is_shuffle=False):
Contributor: If there are multiple data sources, or cases where data or labels are empty, can this interface be reused?

Contributor Author: This reader is specific to MNIST, not general-purpose; other tasks would need their own.

        assert data.shape[0] == labels.shape[0], (
            'data.shape: %s labels.shape: %s' % (data.shape, labels.shape))
        self.num_examples = data.shape[0]
        self.data = data
        self.labels = labels
        self.batch_size = batch_size
        self.is_shuffle = is_shuffle
        self.index_in_epoch = 0

    def shuffle(self):
        # Shuffle data and labels with the same permutation.
        perm = np.arange(self.num_examples)
        np.random.shuffle(perm)
        self.data = self.data[perm]
        self.labels = self.labels[perm]

    def __iter__(self):
        if self.is_shuffle:
            self.shuffle()
        return self

    def next(self):
        if self.index_in_epoch >= self.num_examples:
            self.index_in_epoch = 0
            raise StopIteration

        start = self.index_in_epoch
        self.index_in_epoch += self.batch_size
        end = min(self.index_in_epoch, self.num_examples)
        return {'pixel': self.data[start:end], 'label': self.labels[start:end]}

    __next__ = next  # Python 3 iterator protocol
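As a sanity check of the batching logic above, a stripped-down version of the reader (shuffling omitted, shapes made up for illustration) can be driven with toy arrays; the last batch carries the remainder:

```python
import numpy as np


class DataReader(object):
    # Minimal mirror of the reader under review (no shuffling).
    def __init__(self, data, labels, batch_size):
        assert data.shape[0] == labels.shape[0]
        self.num_examples = data.shape[0]
        self.data, self.labels = data, labels
        self.batch_size = batch_size
        self.index_in_epoch = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.index_in_epoch >= self.num_examples:
            self.index_in_epoch = 0
            raise StopIteration
        start = self.index_in_epoch
        self.index_in_epoch += self.batch_size
        end = min(self.index_in_epoch, self.num_examples)
        return {'pixel': self.data[start:end], 'label': self.labels[start:end]}

    next = __next__  # Python 2 compatibility


data = np.zeros((10, 4), dtype='float32')  # 10 toy "images" of 4 pixels each
labels = np.arange(10)
sizes = [len(b['label']) for b in DataReader(data, labels, 4)]
print(sizes)  # → [4, 4, 2]
```

Note the epoch counter resets before `StopIteration`, so the same reader object can be iterated again for the next epoch.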


def create_datasets(dir='./data/raw_data/'):
    '''
    Data download and loading can be simplified following
    https://github.com/PaddlePaddle/Paddle/pull/872
    '''

Contributor: I don't see a download function; downloading the data automatically would be more convenient for users.

Contributor Author: After PR #872 is merged the data will be downloaded automatically, so it is not written here.
    def load_data(filename, dir):
        image = '-images-idx3-ubyte'
        label = '-labels-idx1-ubyte'
        if filename == 'train':
            image_file = os.path.join(dir, filename + image)
            label_file = os.path.join(dir, filename + label)
        else:
            image_file = os.path.join(dir, 't10k' + image)
            label_file = os.path.join(dir, 't10k' + label)

        with open(image_file, "rb") as f:
            num_magic, n, num_row, num_col = struct.unpack(">IIII", f.read(16))
            images = np.fromfile(f, 'ubyte', count=n * num_row * num_col).\
                reshape(n, num_row * num_col).astype('float32')
            images = images / 255.0 * 2.0 - 1.0  # scale pixels to [-1, 1]
Contributor (@helinwang, Feb 14, 2017): Curious how much `images = images / 255.0 * 2.0 - 1.0`, which pulls the mean closer to 0.0, improves over `images = images / 255.0`? (e.g. 98.55% -> 98.57%, or 98.5% -> 98.9%) A very rough estimate is fine.

Contributor Author: `images = images / 255.0 * 2.0 - 1.0` normalizes to [-1, 1]; `images = images / 255.0` normalizes to [0, 1]. Comparing the two would take an experiment; the difference is probably not large. This keeps the preprocessing of the original MNIST demo.

Contributor: The two value ranges are simply different: one is [-1, 1], the other [0, 1].

        with open(label_file, "rb") as fn:
            num_magic, num_label = struct.unpack(">II", fn.read(8))
            labels = np.fromfile(fn, 'ubyte', count=num_label).astype('int')

        return images, labels

    train_image, train_label = load_data('train', dir)
    test_image, test_label = load_data('test', dir)

    trainset = DataReader(train_image, train_label, 128, True)
    testset = DataReader(test_image, test_label, 128, False)
    return trainset, testset
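The `struct.unpack(">IIII", f.read(16))` call above reads the big-endian IDX image-file header (magic number, image count, rows, columns). A round-trip sketch with synthetic bytes, using the conventional IDX image magic 2051:

```python
import struct

# Build a fake IDX image-file header: magic 2051, 3 images of 28x28.
header = struct.pack(">IIII", 2051, 3, 28, 28)
magic, n, num_row, num_col = struct.unpack(">IIII", header)
print(magic, n, num_row, num_col)  # → 2051 3 28 28
```

The `>` prefix forces big-endian byte order, which is what the MNIST files use; omitting it would read garbage counts on little-endian machines.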


def main():
    train_data, test_data = create_datasets()
    for data_batch in test_data:
        print(data_batch)


if __name__ == "__main__":
    main()