Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added tests, updated doc-strings for Dummy Datasets #865

Merged
merged 9 commits into from
Aug 16, 2022
61 changes: 48 additions & 13 deletions pl_bolts/datasets/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import torch
from torch.utils.data import Dataset

from pl_bolts.utils.stability import under_review


@under_review()
class DummyDataset(Dataset):
"""Generate a dummy dataset.

Expand All @@ -31,6 +28,10 @@ def __init__(self, *shapes, num_samples: int = 10000):
"""
super().__init__()
self.shapes = shapes

if num_samples < 1:
raise ValueError("Provide an argument greater than 0 for `num_samples`")

self.num_samples = num_samples

def __len__(self):
Expand All @@ -44,15 +45,23 @@ def __getitem__(self, idx: int):
return sample


@under_review()
class DummyDetectionDataset(Dataset):
"""Generate a dummy dataset for detection.
"""Generate a dummy dataset for object detection.

Example:
>>> from pl_bolts.datasets import DummyDetectionDataset
>>> from torch.utils.data import DataLoader
>>> ds = DummyDetectionDataset()
>>> dl = DataLoader(ds, batch_size=7)
>>> # get first batch
>>> batch = next(iter(dl))
>>> x,y = batch
>>> x.size()
torch.Size([7, 3, 256, 256])
>>> y['boxes'].size()
torch.Size([7, 1, 4])
>>> y['labels'].size()
torch.Size([7, 1])
"""

def __init__(
Expand All @@ -64,6 +73,9 @@ def __init__(
num_samples: how many samples to use in this dataset
"""
super().__init__()
if num_samples < 1:
raise ValueError("Provide an argument greater than 0 for `num_samples`")

self.img_shape = img_shape
self.num_samples = num_samples
self.num_boxes = num_boxes
Expand All @@ -85,7 +97,6 @@ def __getitem__(self, idx: int):
return img, {"boxes": boxes, "labels": labels}


@under_review()
class RandomDictDataset(Dataset):
"""Generate a dummy dataset with a dict structure.

Expand All @@ -94,35 +105,51 @@ class RandomDictDataset(Dataset):
>>> from torch.utils.data import DataLoader
>>> ds = RandomDictDataset(10)
>>> dl = DataLoader(ds, batch_size=7)
>>> batch = next(iter(dl))
>>> len(batch['a']),len(batch['a'][0])
(7, 10)
>>> len(batch['b']),len(batch['b'][0])
(7, 10)
"""

def __init__(self, size: int, num_samples: int = 250):
"""
Args:
size: tuple
size: integer representing the length of a feature_vector
num_samples: number of samples
"""
if num_samples < 1:
raise ValueError("Provide an argument greater than 0 for `num_samples`")

if size < 1:
raise ValueError("Provide an argument greater than 0 for `size`")

self.len = num_samples
self.data = torch.randn(num_samples, size)
self.data_a = torch.randn(num_samples, size)
self.data_b = torch.randn(num_samples, size)

def __getitem__(self, index):
a = self.data[index]
b = a + 2
a = self.data_a[index]
b = self.data_b[index]
return {"a": a, "b": b}

def __len__(self):
return self.len


@under_review()
class RandomDictStringDataset(Dataset):
"""Generate a dummy dataset with strings.
"""Generate a dummy dataset with in dict structure with strings as indexes.

Example:
>>> from pl_bolts.datasets import RandomDictStringDataset
>>> from torch.utils.data import DataLoader
>>> ds = RandomDictStringDataset(10)
>>> dl = DataLoader(ds, batch_size=7)
>>> batch = next(iter(dl))
>>> batch['id']
['0', '1', '2', '3', '4', '5', '6']
>>> len(batch['x'])
7
"""

def __init__(self, size: int, num_samples: int = 250):
Expand All @@ -131,6 +158,9 @@ def __init__(self, size: int, num_samples: int = 250):
size: integer representing the length of a feature_vector
num_samples: number of samples
"""
if num_samples < 1:
raise ValueError("Provide an argument greater than 0 for `num_samples`")

self.len = num_samples
self.data = torch.randn(num_samples, size)

Expand All @@ -141,7 +171,6 @@ def __len__(self):
return self.len


@under_review()
class RandomDataset(Dataset):
"""Generate a dummy dataset.

Expand All @@ -150,6 +179,9 @@ class RandomDataset(Dataset):
>>> from torch.utils.data import DataLoader
>>> ds = RandomDataset(10)
>>> dl = DataLoader(ds, batch_size=7)
>>> batch = next(iter(dl))
>>> len(batch),len(batch[0])
(7, 10)
"""

def __init__(self, size: int, num_samples: int = 250):
Expand All @@ -158,6 +190,9 @@ def __init__(self, size: int, num_samples: int = 250):
size: integer representing the length of a feature_vector
num_samples: number of samples
"""
if num_samples < 1:
raise ValueError("Provide an argument greater than 0 for `num_samples`")

self.len = num_samples
self.data = torch.randn(num_samples, size)

Expand Down
119 changes: 98 additions & 21 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,118 @@
import pytest
import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset

from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset
from pl_bolts.datasets.dummy_dataset import DummyDetectionDataset
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST


def test_dummy_ds():
ds = DummyDataset((1, 2), num_samples=100)
dl = DataLoader(ds)
@pytest.mark.parametrize("batch_size,num_samples", [(16, 100), (1, 0)])
def test_dummy_ds(catch_warnings, batch_size, num_samples):

for b in dl:
pass
if num_samples > 0:

ds = DummyDataset((1, 28, 28), (1,), num_samples=num_samples)
dl = DataLoader(ds, batch_size=batch_size)

def test_rand_ds():
ds = RandomDataset(32, num_samples=100)
dl = DataLoader(ds)
assert isinstance(ds, Dataset)
assert num_samples == len(ds)

for b in dl:
pass
x = next(iter(ds))
assert x[0].shape == torch.Size([1, 28, 28])
assert x[1].shape == torch.Size([1])

batch = next(iter(dl))
assert batch[0].shape == torch.Size([batch_size, 1, 28, 28])
assert batch[1].shape == torch.Size([batch_size, 1])

def test_rand_dict_ds():
ds = RandomDictDataset(32, num_samples=100)
dl = DataLoader(ds)
else:
with pytest.raises(ValueError, match="Provide an argument greater than 0"):
ds = DummyDataset((1, 28, 28), (1,), num_samples=num_samples)

for b in dl:
pass

@pytest.mark.parametrize("batch_size,size,num_samples", [(16, 32, 100), (1, 0, 0)])
def test_rand_dict_ds(catch_warnings, batch_size, size, num_samples):

def test_rand_str_dict_ds():
ds = RandomDictStringDataset(32, num_samples=100)
dl = DataLoader(ds)
if num_samples > 0 or size > 0:
ds = RandomDictDataset(size, num_samples=num_samples)
dl = DataLoader(ds, batch_size=batch_size)

for b in dl:
pass
assert isinstance(ds, Dataset)
assert num_samples == len(ds)

x = next(iter(ds))
assert x["a"].shape == torch.Size([size])
assert x["b"].shape == torch.Size([size])

batch = next(iter(dl))
assert len(batch["a"]), len(batch["a"][0]) == (batch_size, size)
assert len(batch["b"]), len(batch["b"][0]) == (batch_size, size)
else:
with pytest.raises(ValueError, match="Provide an argument greater than 0"):
ds = RandomDictDataset(size, num_samples=num_samples)


@pytest.mark.parametrize("batch_size,size,num_samples", [(16, 32, 100), (1, 0, 0)])
def test_rand_ds(catch_warnings, batch_size, size, num_samples):
    """Check ``RandomDataset`` sample/batch shapes and its rejection of invalid arguments.

    For positive ``size`` and ``num_samples`` the dataset must report the
    requested length, yield samples of shape ``(size,)`` and collate into
    batches of ``batch_size`` rows with ``size`` entries each. Otherwise
    constructing the dataset is expected to raise ``ValueError``.
    """
    if num_samples > 0 and size > 0:
        ds = RandomDataset(size=size, num_samples=num_samples)
        dl = DataLoader(ds, batch_size=batch_size)

        assert isinstance(ds, Dataset)
        assert num_samples == len(ds)

        x = next(iter(ds))
        assert x.shape == torch.Size([size])

        batch = next(iter(dl))
        # Compare as one tuple: `assert a, b == c` would treat `b == c` as the
        # assertion *message* and never check the shape at all.
        assert (len(batch), len(batch[0])) == (batch_size, size)

    else:
        with pytest.raises(ValueError, match="Provide an argument greater than 0"):
            RandomDataset(size, num_samples=num_samples)


@pytest.mark.parametrize("batch_size,size,num_samples", [(16, 32, 100), (1, 0, 0)])
def test_rand_str_dict_ds(catch_warnings, batch_size, size, num_samples):
    """Check ``RandomDictStringDataset`` samples, batches and argument validation.

    For positive ``size`` and ``num_samples`` each sample must be a dict with
    a string ``"id"`` and an ``"x"`` tensor of shape ``(size,)``, and a batch
    must hold ``batch_size`` entries under both keys. Otherwise constructing
    the dataset is expected to raise ``ValueError``.
    """
    if num_samples > 0 and size > 0:
        # Use the parametrized value rather than a hard-coded 100 so the
        # length assertion below actually exercises the argument.
        ds = RandomDictStringDataset(size=size, num_samples=num_samples)
        dl = DataLoader(ds, batch_size=batch_size)

        assert isinstance(ds, Dataset)
        assert num_samples == len(ds)

        x = next(iter(ds))
        assert isinstance(x["id"], str)
        assert x["x"].shape == torch.Size([size])

        batch = next(iter(dl))
        assert len(batch["x"]) == batch_size
        assert len(batch["id"]) == batch_size
    else:
        with pytest.raises(ValueError, match="Provide an argument greater than 0"):
            RandomDictStringDataset(size, num_samples=num_samples)


@pytest.mark.parametrize("batch_size,img_shape,num_samples", [(16, (3, 256, 256), 100), (1, (256, 256), 0)])
def test_dummy_detection_ds(catch_warnings, batch_size, img_shape, num_samples):
    """Check ``DummyDetectionDataset`` batch shapes and its ``num_samples`` validation.

    A non-positive ``num_samples`` must make the constructor raise
    ``ValueError``. Otherwise the first batch must contain images of shape
    ``(batch_size, *img_shape)``, boxes of shape ``(batch_size, 3, 4)`` and
    labels of shape ``(batch_size, 3)``.
    """
    ds_kwargs = dict(img_shape=img_shape, num_boxes=3, num_classes=3, num_samples=num_samples)

    # Invalid configuration: only the constructor error is checked.
    if num_samples <= 0:
        with pytest.raises(ValueError, match="Provide an argument greater than 0"):
            DummyDetectionDataset(**ds_kwargs)
        return

    dataset = DummyDetectionDataset(**ds_kwargs)
    loader = DataLoader(dataset, batch_size=batch_size)

    assert isinstance(dataset, Dataset)
    assert num_samples == len(dataset)

    images, targets = next(iter(loader))
    assert images.size() == torch.Size([batch_size, *img_shape])
    assert targets["boxes"].size() == torch.Size([batch_size, 3, 4])
    assert targets["labels"].size() == torch.Size([batch_size, 3])


@pytest.mark.parametrize("scale_factor", [2, 4])
Expand Down