Skip to content

Commit

Permalink
fix RandomSampler & BatchSampler. test=develop (#26559)
Browse files Browse the repository at this point in the history
* fix RandomSampler & BatchSampler. test=develop
  • Loading branch information
heavengate authored Aug 24, 2020
1 parent d6e888c commit dd3df69
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
13 changes: 7 additions & 6 deletions python/paddle/fluid/dataloader/batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import division

import numpy as np
from .sampler import Sampler, SequenceSampler
from .sampler import Sampler, SequenceSampler, RandomSampler
from .dataset import Dataset, IterableDataset

__all__ = ["BatchSampler"]
Expand Down Expand Up @@ -86,7 +86,6 @@ def __len__(self):
# init with sampler
sampler = RandomSampler(RandomDataset(100))
bs = BatchSampler(sampler=sampler,
shuffle=True,
batch_size=8,
drop_last=True)
Expand Down Expand Up @@ -118,14 +117,16 @@ def __init__(self,
"dataset should not be a paddle.io.IterableDataset"
assert sampler is None, \
"should not set both dataset and sampler"
self.sampler = SequenceSampler(dataset)
assert isinstance(shuffle, bool), \
"shuffle should be a boolean value, but got {}".format(type(shuffle))
if shuffle:
self.sampler = RandomSampler(dataset)
else:
self.sampler = SequenceSampler(dataset)

assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer, but got {}".format(batch_size)
self.batch_size = batch_size
assert isinstance(shuffle, bool), \
"shuffle should be a boolean value, but got {}".format(type(shuffle))
self.shuffle = shuffle
assert isinstance(drop_last, bool), \
"drop_last should be a boolean value, but got {}".format(type(drop_last))
self.drop_last = drop_last
Expand Down
8 changes: 6 additions & 2 deletions python/paddle/fluid/dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __getitem__(self, idx):
def __len__(self):
return self.num_samples
sampler = RandomSampler(data_souce=RandomDataset(100))
sampler = RandomSampler(data_source=RandomDataset(100))
for index in sampler:
print(index)
Expand Down Expand Up @@ -216,7 +216,11 @@ def num_samples(self):
def __iter__(self):
n = len(self.data_source)
if self.generator:
for index in self.generator:
for i in range(self.num_samples):
try:
index = next(self.generator)
except StopIteration:
return
yield index
else:
if self.replacement:
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/fluid/tests/unittests/test_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ def test_with_generator(self):
rets.append(i)
assert tuple(sorted(rets)) == tuple(range(0, 60))

def test_with_generator_num_samples(self):
dataset = RandomDataset(100, 10)
generator = iter(range(0, 60))
sampler = RandomSampler(
dataset, generator=generator, num_samples=50, replacement=True)
assert len(sampler) == 50

rets = []
for i in iter(sampler):
rets.append(i)
assert tuple(sorted(rets)) == tuple(range(0, 50))


class TestBatchSampler(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit dd3df69

Please sign in to comment.