Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added tests, updated doc-strings for Dummy Datasets #865

Merged
merged 9 commits into from
Aug 16, 2022
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# IDE Settings files
.vscode/


otaj marked this conversation as resolved.
Show resolved Hide resolved

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
54 changes: 44 additions & 10 deletions pl_bolts/datasets/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import torch
from torch.utils.data import Dataset

from pl_bolts.utils.stability import under_review


@under_review()
class DummyDataset(Dataset):
"""Generate a dummy dataset.

Expand All @@ -31,6 +28,10 @@ def __init__(self, *shapes, num_samples: int = 10000):
"""
super().__init__()
self.shapes = shapes

if num_samples < 1:
raise ValueError("Provide an argument greater than 0 for `num_samples`")

self.num_samples = num_samples

def __len__(self):
Expand All @@ -44,15 +45,23 @@ def __getitem__(self, idx: int):
return sample


@under_review()
class DummyDetectionDataset(Dataset):
"""Generate a dummy dataset for detection.
"""Generate a dummy dataset for object detection.

Example:
>>> from pl_bolts.datasets import DummyDetectionDataset
>>> from torch.utils.data import DataLoader
>>> ds = DummyDetectionDataset()
>>> dl = DataLoader(ds, batch_size=7)
>>> # get first batch
>>> batch = next(iter(dl))
>>> x,y = batch
>>> x.size()
torch.Size([7, 3, 256, 256])
>>> y['boxes'].size()
torch.Size([7, 1, 4])
>>> y['labels'].size()
torch.Size([7, 1])
"""

def __init__(
Expand All @@ -64,6 +73,9 @@ def __init__(
num_samples: how many samples to use in this dataset
"""
super().__init__()
if num_samples < 1:
raise ValueError("Provide an argument greater than 0 for `num_samples`")

self.img_shape = img_shape
self.num_samples = num_samples
self.num_boxes = num_boxes
Expand All @@ -85,7 +97,6 @@ def __getitem__(self, idx: int):
return img, {"boxes": boxes, "labels": labels}


@under_review()
class RandomDictDataset(Dataset):
"""Generate a dummy dataset with a dict structure.

Expand All @@ -94,14 +105,25 @@ class RandomDictDataset(Dataset):
>>> from torch.utils.data import DataLoader
>>> ds = RandomDictDataset(10)
>>> dl = DataLoader(ds, batch_size=7)
>>> batch = next(iter(dl))
>>> len(batch['a']),len(batch['a'][0])
(7, 10)
>>> len(batch['b']),len(batch['b'][0])
(7, 10)
"""

def __init__(self, size: int, num_samples: int = 250):
    """Build a random dict-structured dataset.

    Args:
        size: integer representing the length of a feature_vector
        num_samples: number of samples
    """
    if num_samples < 1:
        raise ValueError("Provide an argument greater than 0 for `num_samples`")

    if size < 1:
        raise ValueError("Provide an argument greater than 0 for `size`")

    self.len = num_samples
    self.data = torch.randn(num_samples, size)
otaj marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -114,15 +136,19 @@ def __len__(self):
return self.len


@under_review()
class RandomDictStringDataset(Dataset):
"""Generate a dummy dataset with strings.
"""Generate a dummy dataset with in dict structure with strings as indexes.

Example:
>>> from pl_bolts.datasets import RandomDictStringDataset
>>> from torch.utils.data import DataLoader
>>> ds = RandomDictStringDataset(10)
>>> dl = DataLoader(ds, batch_size=7)
>>> batch = next(iter(dl))
>>> batch['id']
['0', '1', '2', '3', '4', '5', '6']
>>> len(batch['x'])
7
"""

def __init__(self, size: int, num_samples: int = 250):
    """Build a random string-indexed dict dataset.

    Args:
        size: integer representing the length of each feature vector
        num_samples: number of samples

    Raises:
        ValueError: if ``num_samples`` or ``size`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError("Provide an argument greater than 0 for `num_samples`")

    # Consistent with RandomDictDataset: reject a non-positive feature size
    # instead of silently creating empty (num_samples, 0) tensors.
    if size < 1:
        raise ValueError("Provide an argument greater than 0 for `size`")

    self.len = num_samples
    self.data = torch.randn(num_samples, size)

Expand All @@ -141,7 +170,6 @@ def __len__(self):
return self.len


@under_review()
class RandomDataset(Dataset):
"""Generate a dummy dataset.

Expand All @@ -150,6 +178,9 @@ class RandomDataset(Dataset):
>>> from torch.utils.data import DataLoader
>>> ds = RandomDataset(10)
>>> dl = DataLoader(ds, batch_size=7)
>>> batch = next(iter(dl))
>>> len(batch),len(batch[0])
(7, 10)
"""

def __init__(self, size: int, num_samples: int = 250):
    """Build a random tensor dataset.

    Args:
        size: integer representing the length of each feature vector
        num_samples: number of samples

    Raises:
        ValueError: if ``num_samples`` or ``size`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError("Provide an argument greater than 0 for `num_samples`")

    # Consistent with RandomDictDataset: reject a non-positive feature size
    # instead of silently creating empty (num_samples, 0) tensors.
    if size < 1:
        raise ValueError("Provide an argument greater than 0 for `size`")

    self.len = num_samples
    self.data = torch.randn(num_samples, size)

Expand Down
145 changes: 124 additions & 21 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,144 @@
import pytest
import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset

from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset
from pl_bolts.datasets.dummy_dataset import DummyDetectionDataset
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST


@pytest.mark.parametrize("batch_size,num_samples", [(16, 100), (1, 0)])
def test_dummy_ds(catch_warnings, batch_size, num_samples):
    """DummyDataset yields correctly-shaped samples and batches, and rejects num_samples < 1."""
    if num_samples > 0:
        ds = DummyDataset((1, 28, 28), (1,), num_samples=num_samples)
        dl = DataLoader(ds, batch_size=batch_size)

        assert isinstance(ds, Dataset)
        assert len(ds) == num_samples

        # Per-sample shapes follow the requested shape tuples.
        x = next(iter(ds))
        assert x[0].shape == torch.Size([1, 28, 28])
        assert x[1].shape == torch.Size([1])

        # Batched shapes gain the leading batch dimension.
        batch = next(iter(dl))
        assert batch[0].shape == torch.Size([batch_size, 1, 28, 28])
        assert batch[1].shape == torch.Size([batch_size, 1])
    else:
        # Catch the specific exception up-front instead of a broad
        # `pytest.raises(Exception)` followed by post-hoc type checks.
        with pytest.raises(ValueError, match="Provide an argument greater than 0 for `num_samples`"):
            DummyDataset((1, 28, 28), (1,), num_samples=num_samples)


@pytest.mark.parametrize("batch_size,size,num_samples", [(16, 32, 100), (1, 0, 0)])
def test_rand_dict_ds(catch_warnings, batch_size, size, num_samples):
    """RandomDictDataset yields dicts of correctly-sized tensors and validates its arguments."""
    # Bug fix: the happy path needs BOTH arguments valid (`and`, not `or`);
    # with e.g. size > 0 but num_samples == 0 the old condition took the
    # success branch and crashed instead of asserting the ValueError.
    if num_samples > 0 and size > 0:
        ds = RandomDictDataset(size, num_samples=num_samples)
        dl = DataLoader(ds, batch_size=batch_size)

        assert isinstance(ds, Dataset)
        assert len(ds) == num_samples

        x = next(iter(ds))
        assert x["a"].shape == torch.Size([size])
        assert x["b"].shape == torch.Size([size])

        batch = next(iter(dl))
        # Compare as tuples; `assert a, b` treats b as the assertion
        # message, so the old comparison was never actually checked.
        assert (len(batch["a"]), len(batch["a"][0])) == (batch_size, size)
        assert (len(batch["b"]), len(batch["b"][0])) == (batch_size, size)
    else:
        with pytest.raises(ValueError) as e:
            RandomDictDataset(size, num_samples=num_samples)

        assert str(e.value) in {
            "Provide an argument greater than 0 for `num_samples`",
            "Provide an argument greater than 0 for `size`",
        }


@pytest.mark.parametrize("batch_size,size,num_samples", [(16, 32, 100), (1, 0, 0)])
def test_rand_ds(catch_warnings, batch_size, size, num_samples):
    """RandomDataset yields correctly-sized tensors and validates its arguments."""
    if num_samples > 0 and size > 0:
        ds = RandomDataset(size=size, num_samples=num_samples)
        dl = DataLoader(ds, batch_size=batch_size)

        assert isinstance(ds, Dataset)
        assert len(ds) == num_samples

        x = next(iter(ds))
        assert x.shape == torch.Size([size])

        batch = next(iter(dl))
        # Compare as a tuple; `assert a, b` treats b as the assertion
        # message, so the old comparison was never actually checked.
        assert (len(batch), len(batch[0])) == (batch_size, size)
    else:
        # Catch the specific exception instead of a broad Exception.
        with pytest.raises(ValueError) as e:
            RandomDataset(size, num_samples=num_samples)

        assert str(e.value) in {
            "Provide an argument greater than 0 for `num_samples`",
            "Provide an argument greater than 0 for `size`",
        }


@pytest.mark.parametrize("batch_size,size,num_samples", [(16, 32, 100), (1, 0, 0)])
def test_rand_str_dict_ds(catch_warnings, batch_size, size, num_samples):
    """RandomDictStringDataset yields string-id dicts and validates its arguments."""
    if num_samples > 0 and size > 0:
        # Bug fix: pass the parametrized num_samples instead of a hard-coded
        # 100, which only passed because the parametrize value happened to
        # also be 100 and would silently break for any other value.
        ds = RandomDictStringDataset(size=size, num_samples=num_samples)
        dl = DataLoader(ds, batch_size=batch_size)

        assert isinstance(ds, Dataset)
        assert len(ds) == num_samples

        x = next(iter(ds))
        assert isinstance(x["id"], str)
        assert x["x"].shape == torch.Size([size])

        batch = next(iter(dl))
        assert len(batch["x"]) == batch_size
        assert len(batch["id"]) == batch_size
    else:
        # Catch the specific exception instead of a broad Exception.
        with pytest.raises(ValueError) as e:
            RandomDictStringDataset(size, num_samples=num_samples)

        assert str(e.value) in {
            "Provide an argument greater than 0 for `num_samples`",
            "Provide an argument greater than 0 for `size`",
        }


@pytest.mark.parametrize("batch_size,img_shape,num_samples", [(16, (3, 256, 256), 100), (1, (256, 256), 0)])
def test_dummy_detection_ds(catch_warnings, batch_size, img_shape, num_samples):
    """DummyDetectionDataset yields (image, target-dict) batches and rejects num_samples < 1."""
    if num_samples > 0:
        ds = DummyDetectionDataset(img_shape=img_shape, num_boxes=3, num_classes=3, num_samples=num_samples)
        dl = DataLoader(ds, batch_size=batch_size)

        assert isinstance(ds, Dataset)
        assert len(ds) == num_samples

        x, y = next(iter(dl))
        assert x.size() == torch.Size([batch_size, *img_shape])
        # One (x1, y1, x2, y2) box and one label per requested num_boxes.
        assert y["boxes"].size() == torch.Size([batch_size, 3, 4])
        assert y["labels"].size() == torch.Size([batch_size, 3])
    else:
        # Catch the specific exception, and pin the exact message:
        # DummyDetectionDataset has no `size` argument, so the old
        # two-message set could mask a wrong error text.
        with pytest.raises(ValueError) as e:
            DummyDetectionDataset(img_shape=img_shape, num_boxes=3, num_classes=3, num_samples=num_samples)

        assert str(e.value) == "Provide an argument greater than 0 for `num_samples`"


@pytest.mark.parametrize("scale_factor", [2, 4])
Expand Down