Skip to content

Commit

Permalink
add dataset_split unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Oct 10, 2021
1 parent 65ad36e commit a71fa87
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import torch
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torch.utils.data import TensorDataset

import torchgeo.datasets.utils
from torchgeo.datasets.utils import (
BoundingBox,
collate_dict,
dataset_split,
disambiguate_timestamp,
download_and_extract_archive,
download_radiant_mlhub_collection,
Expand Down Expand Up @@ -335,3 +337,21 @@ def test_nonexisting_directory(tmp_path: Path) -> None:

with working_dir(str(subdir), create=True):
assert subdir.cwd() == subdir


def test_dataset_split() -> None:
num_samples = 24
x = torch.ones(num_samples, 5) # type: ignore[attr-defined]
y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined]
ds = TensorDataset(x, y)

# Test only train/val set split
train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
assert len(train_ds) == num_samples // 2
assert len(val_ds) == num_samples // 2

# Test train/val/test set split
train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
assert len(train_ds) == num_samples // 3
assert len(val_ds) == num_samples // 3
assert len(test_ds) == num_samples // 3

0 comments on commit a71fa87

Please sign in to comment.