From 893eb2ff6cbe1db230fa74f8dee655a67ccdf4e6 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sat, 22 Aug 2020 06:33:30 +0000 Subject: [PATCH 1/4] fix RandomSampler & BatchSampler. test=develop --- python/paddle/fluid/dataloader/batch_sampler.py | 10 ++++++---- python/paddle/fluid/dataloader/sampler.py | 6 +++++- .../fluid/tests/unittests/test_batch_sampler.py | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/dataloader/batch_sampler.py b/python/paddle/fluid/dataloader/batch_sampler.py index 8043237c0d97d..294101e51f7f8 100644 --- a/python/paddle/fluid/dataloader/batch_sampler.py +++ b/python/paddle/fluid/dataloader/batch_sampler.py @@ -118,14 +118,16 @@ def __init__(self, "dataset should not be a paddle.io.IterableDataset" assert sampler is None, \ "should not set both dataset and sampler" - self.sampler = SequenceSampler(dataset) + assert isinstance(shuffle, bool), \ + "shuffle should be a boolean value, but got {}".format(type(shuffle)) + if shuffle: + self.sampler = RandomSampler(dataset) + else: + self.sampler = SequenceSampler(dataset) assert isinstance(batch_size, int) and batch_size > 0, \ "batch_size should be a positive integer, but got {}".format(batch_size) self.batch_size = batch_size - assert isinstance(shuffle, bool), \ - "shuffle should be a boolean value, but got {}".format(type(shuffle)) - self.shuffle = shuffle assert isinstance(drop_last, bool), \ "drop_last should be a boolean value, but got {}".format(type(drop_last)) self.drop_last = drop_last diff --git a/python/paddle/fluid/dataloader/sampler.py b/python/paddle/fluid/dataloader/sampler.py index d2f3231cc6b12..06e8077fe0646 100644 --- a/python/paddle/fluid/dataloader/sampler.py +++ b/python/paddle/fluid/dataloader/sampler.py @@ -216,7 +216,11 @@ def num_samples(self): def __iter__(self): n = len(self.data_source) if self.generator: - for index in self.generator: + for i in range(self.num_samples): + try: + index = next(self.generator) + except StopIteration: + return yield index else: if self.replacement: diff --git a/python/paddle/fluid/tests/unittests/test_batch_sampler.py b/python/paddle/fluid/tests/unittests/test_batch_sampler.py index 2e2a6144fd011..0424fc3f88f60 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_sampler.py +++ b/python/paddle/fluid/tests/unittests/test_batch_sampler.py @@ -88,6 +88,20 @@ def test_with_generator(self): rets.append(i) assert tuple(sorted(rets)) == tuple(range(0, 60)) + def test_with_generator_num_samples(self): + dataset = RandomDataset(100, 10) + generator = iter(range(0, 60)) + sampler = RandomSampler(dataset, + generator=generator, + num_samples=50, + replacement=True) + assert len(sampler) == 50 + + rets = [] + for i in iter(sampler): + rets.append(i) + assert tuple(sorted(rets)) == tuple(range(0, 50)) + class TestBatchSampler(unittest.TestCase): def setUp(self): From d2eefc90ec198579cb728b596e05c8532a369c9d Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sat, 22 Aug 2020 07:04:14 +0000 Subject: [PATCH 2/4] fix format. test=develop --- python/paddle/fluid/tests/unittests/test_batch_sampler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_sampler.py b/python/paddle/fluid/tests/unittests/test_batch_sampler.py index 0424fc3f88f60..6ec6fdb59f200 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_sampler.py +++ b/python/paddle/fluid/tests/unittests/test_batch_sampler.py @@ -91,10 +91,8 @@ def test_with_generator(self): def test_with_generator_num_samples(self): dataset = RandomDataset(100, 10) generator = iter(range(0, 60)) - sampler = RandomSampler(dataset, - generator=generator, - num_samples=50, - replacement=True) + sampler = RandomSampler( + dataset, generator=generator, num_samples=50, replacement=True) assert len(sampler) == 50 rets = [] From 4b6c6fee5a6e41d274b4c16b5522bc0ddd9462b8 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sat, 22 Aug 2020 07:36:02 +0000 Subject: [PATCH 3/4] fix import. test=develop --- python/paddle/fluid/dataloader/batch_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/dataloader/batch_sampler.py b/python/paddle/fluid/dataloader/batch_sampler.py index 294101e51f7f8..76fb2c97fa579 100644 --- a/python/paddle/fluid/dataloader/batch_sampler.py +++ b/python/paddle/fluid/dataloader/batch_sampler.py @@ -16,7 +16,7 @@ from __future__ import division import numpy as np -from .sampler import Sampler, SequenceSampler +from .sampler import Sampler, SequenceSampler, RandomSampler from .dataset import Dataset, IterableDataset __all__ = ["BatchSampler"] From f19e0d382937ccd7c5523ae335901473d7ad5394 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sat, 22 Aug 2020 08:59:40 +0000 Subject: [PATCH 4/4] fix examples. test=develop --- python/paddle/fluid/dataloader/batch_sampler.py | 1 - python/paddle/fluid/dataloader/sampler.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/fluid/dataloader/batch_sampler.py b/python/paddle/fluid/dataloader/batch_sampler.py index 76fb2c97fa579..1d180329b7251 100644 --- a/python/paddle/fluid/dataloader/batch_sampler.py +++ b/python/paddle/fluid/dataloader/batch_sampler.py @@ -86,7 +86,6 @@ def __len__(self): # init with sampler sampler = RandomSampler(RandomDataset(100)) bs = BatchSampler(sampler=sampler, - shuffle=True, batch_size=8, drop_last=True) diff --git a/python/paddle/fluid/dataloader/sampler.py b/python/paddle/fluid/dataloader/sampler.py index 06e8077fe0646..5c75fafe8b223 100644 --- a/python/paddle/fluid/dataloader/sampler.py +++ b/python/paddle/fluid/dataloader/sampler.py @@ -177,7 +177,7 @@ def __getitem__(self, idx): def __len__(self): return self.num_samples - sampler = RandomSampler(data_souce=RandomDataset(100)) + sampler = RandomSampler(data_source=RandomDataset(100)) for index in sampler: print(index)