Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold committed Feb 1, 2018
1 parent 412291f commit 18bba06
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/contrib/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def getdata(self):

def getlabel(self):
if self.getpad():
lshape = self._current_batch[0].shape
lshape = self._current_batch[1].shape
ret = nd.empty(shape=([self.batch_size] + list(lshape[1:])))
ret[:lshape[0]] = self._current_batch[1].astype(self.dtype)
return [ret]
Expand Down
32 changes: 19 additions & 13 deletions tests/python/unittest/test_contrib_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,25 @@
from mxnet.test_utils import *

def test_contrib_DataLoaderIter():
dataset = MNIST()
batch_size = 50
dataloader = DataLoader(dataset, batch_size)
test_iter = DataLoaderIter(dataloader)
batch = next(test_iter)
assert batch.data[0].shape == (batch_size, 28, 28, 1)
assert batch.label[0].shape == (batch_size,)
count = 0
test_iter.reset()
for batch in test_iter:
count += 1
expected = 60000 / batch_size
assert count == expected, "expected {} batches, given {}".format(expected, count)
def test_mnist_batches(batch_size, expected, last_batch='discard'):
dataset = MNIST(train=False)
dataloader = DataLoader(dataset, batch_size, last_batch=last_batch)
test_iter = DataLoaderIter(dataloader)
batch = next(test_iter)
assert batch.data[0].shape == (batch_size, 28, 28, 1)
assert batch.label[0].shape == (batch_size,)
count = 0
test_iter.reset()
for batch in test_iter:
count += 1
assert count == expected, "expected {} batches, given {}".format(expected, count)

num_examples = 10000
test_mnist_batches(50, num_examples // 50, 'discard')
test_mnist_batches(31, num_examples // 31, 'discard')
test_mnist_batches(31, num_examples // 31, 'rollover')
test_mnist_batches(31, num_examples // 31 + 1, 'keep')


if __name__ == "__main__":
test_contrib_DataLoaderIter()

0 comments on commit 18bba06

Please sign in to comment.