Data reader for api #1326
Closed
@@ -0,0 +1,51 @@
import os, sys
import struct
import numpy as np
import paddle.v2 as paddle


def load_data(filename, dir='./data/raw_data/'):
    image = '-images-idx3-ubyte'
    label = '-labels-idx1-ubyte'
    if filename == 'train':
        image_file = os.path.join(dir, filename + image)
        label_file = os.path.join(dir, filename + label)
    else:
        image_file = os.path.join(dir, 't10k' + image)
        label_file = os.path.join(dir, 't10k' + label)

    with open(image_file, "rb") as f:
        num_magic, n, num_row, num_col = struct.unpack(">IIII", f.read(16))
        images = np.fromfile(f, 'ubyte', count=n * num_row * num_col).\
            reshape(n, num_row * num_col).astype('float32')
        images = images / 255.0 * 2.0 - 1.0  # scale [0, 255] to [-1, 1]

    with open(label_file, "rb") as fn:
        num_magic, num_label = struct.unpack(">II", fn.read(8))
        labels = np.fromfile(fn, 'ubyte', count=num_label).astype('int32')

    return images, labels


def data(images, labels):
    for i in xrange(len(labels)):
        yield {"pixel": images[i, :], 'label': labels[i]}


def main():
    train_images, train_label = load_data('train')
    train_gen = data(train_images, train_label)
    train_data = paddle.data.CacheAllDataPool(train_gen, 128,
                                              ['pixel', 'label'])

    test_images, test_label = load_data('test')
    test_gen = data(test_images[0:128], test_label[0:128])
    test_data = paddle.data.CacheAllDataPool(test_gen, 128, ['pixel', 'label'],
                                             False)

    for data_batch in test_data:
        print data_batch


if __name__ == "__main__":
    main()
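The header layout that `load_data` relies on can be checked in isolation. The sketch below is written for Python 3 (the PR itself targets Python 2) and uses plain lists instead of NumPy so it runs with the standard library alone. It assumes the usual MNIST IDX layout: a big-endian header of four 32-bit integers followed by unsigned bytes. The helper names `write_idx3` and `read_idx3` are hypothetical and exist only for this illustration.

```python
import os
import struct
import tempfile

IDX3_MAGIC = 2051  # 0x00000803: unsigned-byte data with 3 dimensions


def write_idx3(path, images):
    """Write nested lists shaped (n, rows, cols) in IDX3 layout (hypothetical helper)."""
    n, rows, cols = len(images), len(images[0]), len(images[0][0])
    with open(path, "wb") as f:
        # Big-endian header: magic, image count, rows, columns.
        f.write(struct.pack(">IIII", IDX3_MAGIC, n, rows, cols))
        for image in images:
            for row in image:
                f.write(bytes(row))


def read_idx3(path):
    """Read an IDX3 file and scale each pixel from [0, 255] to [-1, 1]."""
    with open(path, "rb") as f:
        magic, n, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != IDX3_MAGIC:
            raise ValueError("unexpected magic number: %d" % magic)
        size = rows * cols
        payload = f.read(n * size)
    if len(payload) != n * size:
        raise ValueError("file is shorter than its header claims")
    # One flat row of floats per image, as in the reshape(n, rows * cols) call.
    return [[b / 255.0 * 2.0 - 1.0 for b in payload[i * size:(i + 1) * size]]
            for i in range(n)]


# Round-trip two tiny 2x2 "images" through a temporary file.
sample = [[[0, 255], [255, 0]], [[255, 255], [0, 0]]]
tmp_path = os.path.join(tempfile.mkdtemp(), "demo-images-idx3-ubyte")
write_idx3(tmp_path, sample)
decoded = read_idx3(tmp_path)
```

Checking the magic number and the payload length is an addition of this sketch; the PR's `load_data` reads the header but does not validate it.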
@@ -0,0 +1,88 @@
import collections
import random

__all__ = [
    'IDataPool',
    'CacheAllDataPool',
]


class IDataPool(object):
    """
    Interface of DataPool. Python uses duck typing, so it is not necessary
    to inherit from this interface.

    NOTE: For Paddle developers, NEVER CHECK isinstance(obj, IDataPool).

    It contains two methods:

    * next(): Return the next batch of data in the pool. Raise
              StopIteration if there is no more data in the pool.

    * reset(): Reset the data pool to its initial status.

    The basic usage of this API is the same as a normal Python iterator:

    ..  code-block:: python

        pool = DataPool()

        for batch in pool:
            process_batch(batch)


    NOTE: The Data Pool API is not thread-safe.
    """

    def __iter__(self):
        self.reset()
        return self

    def next(self):
        raise NotImplementedError()

    def __next__(self):
        return self.next()

    def reset(self):
        raise NotImplementedError()


def input_order_mapper(iterable, input_order):
    assert isinstance(input_order, collections.Sequence)
    for each_input_name in input_order:
        assert isinstance(each_input_name, basestring)

    tmp = [None] * len(input_order)
    for each_item in iterable:
        for i in xrange(len(input_order)):
            tmp[i] = each_item[input_order[i]]
        yield list(tmp)  # copy, so cached samples do not share one list


class CacheAllDataPool(IDataPool):
    """
    Load all samples into memory.
    """

    def __init__(self, iterable, batch_size, input_order, shuffle=True):
        self.__pool__ = list(
            input_order_mapper(
                iterable=iterable, input_order=input_order))
        self.__batch_size__ = batch_size
        self.__shuffle__ = shuffle
        self.__idx__ = 0

    def reset(self):
        self.__idx__ = 0
        if self.__shuffle__:
            random.shuffle(self.__pool__)

    def next(self):
        if self.__idx__ >= len(self.__pool__):
            raise StopIteration()

        begin = self.__idx__
        end = min(self.__idx__ + self.__batch_size__, len(self.__pool__))
        self.__idx__ = end
        return self.__pool__[begin:end]
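How batches fall out of the reset()/next() protocol is easier to see with a small run. The following is a minimal Python 3 sketch of the same idea, not the PR's code: the class name `ListBatchPool` is hypothetical, and shuffling is left out so the result is deterministic.

```python
class ListBatchPool:
    """Hypothetical Python 3 stand-in that mirrors the reset()/next() protocol."""

    def __init__(self, samples, batch_size):
        self._pool = list(samples)
        self._batch_size = batch_size
        self._idx = 0

    def reset(self):
        self._idx = 0

    def __iter__(self):
        # Starting a new for-loop rewinds the pool to the first sample.
        self.reset()
        return self

    def __next__(self):
        if self._idx >= len(self._pool):
            raise StopIteration()
        begin = self._idx
        end = min(begin + self._batch_size, len(self._pool))
        self._idx = end
        return self._pool[begin:end]


pool = ListBatchPool(range(5), batch_size=2)
first_pass = [batch for batch in pool]   # [[0, 1], [2, 3], [4]]
second_pass = [batch for batch in pool]  # same batches again after reset()
```

Two things are worth noting in this sketch: the final batch is shorter than `batch_size` when the pool size is not a multiple of it, and because `__iter__` calls `reset()`, the same pool object can be looped over once per pass.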
The tutorials I have seen online use `raise StopIteration`. Does `raise StopIteration()` work as well?
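For reference, both spellings behave the same: when a `raise` statement is given an exception class rather than an instance, Python instantiates the class with no arguments before raising it. A small Python 3 sketch to confirm this; the function names are made up for the demonstration.

```python
def raise_class():
    raise StopIteration      # class: Python instantiates it implicitly


def raise_instance():
    raise StopIteration()    # instance: created explicitly


def caught(func):
    """Call func and return the StopIteration it raises (hypothetical helper)."""
    try:
        func()
    except StopIteration as exc:
        return exc
    return None


from_class = caught(raise_class)
from_instance = caught(raise_instance)
```

In both cases the caller receives a `StopIteration` instance with empty `args`, so the choice is a matter of style.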