diff --git a/tests/conf/skippd.yaml b/tests/conf/skippd.yaml
index a18a05abfe6..15a531437c5 100644
--- a/tests/conf/skippd.yaml
+++ b/tests/conf/skippd.yaml
@@ -9,6 +9,7 @@ data:
   class_path: SKIPPDDataModule
   init_args:
     batch_size: 1
+    val_split_pct: 0.4
   dict_kwargs:
     root: "tests/data/skippd"
     download: true
diff --git a/tests/datamodules/test_utils.py b/tests/datamodules/test_utils.py
index b96c6032c67..bf9b471fa63 100644
--- a/tests/datamodules/test_utils.py
+++ b/tests/datamodules/test_utils.py
@@ -5,28 +5,8 @@
 
 import numpy as np
 import pytest
-import torch
-from torch.utils.data import TensorDataset
 
-from torchgeo.datamodules.utils import dataset_split, group_shuffle_split
-
-
-def test_dataset_split() -> None:
-    num_samples = 24
-    x = torch.ones(num_samples, 5)
-    y = torch.randint(low=0, high=2, size=(num_samples,))
-    ds = TensorDataset(x, y)
-
-    # Test only train/val set split
-    train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
-    assert len(train_ds) == round(num_samples / 2)
-    assert len(val_ds) == round(num_samples / 2)
-
-    # Test train/val/test set split
-    train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
-    assert len(train_ds) == round(num_samples / 3)
-    assert len(val_ds) == round(num_samples / 3)
-    assert len(test_ds) == round(num_samples / 3)
+from torchgeo.datamodules.utils import group_shuffle_split
 
 
 def test_group_shuffle_split() -> None:
diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py
index 2195aca3f0f..92ce7b96e07 100644
--- a/torchgeo/datamodules/deepglobelandcover.py
+++ b/torchgeo/datamodules/deepglobelandcover.py
@@ -6,13 +6,14 @@
 from typing import Any
 
 import kornia.augmentation as K
+import torch
+from torch.utils.data import random_split
 
 from ..datasets import DeepGlobeLandCover
 from ..samplers.utils import _to_tuple
 from ..transforms import AugmentationSequential
 from ..transforms.transforms import _RandomNCrop
 from .geo import NonGeoDataModule
-from .utils import dataset_split
 
 
 class DeepGlobeLandCoverDataModule(NonGeoDataModule):
@@ -59,8 +60,9 @@ def setup(self, stage: str) -> None:
         """
         if stage in ["fit", "validate"]:
             self.dataset = DeepGlobeLandCover(split="train", **self.kwargs)
-            self.train_dataset, self.val_dataset = dataset_split(
-                self.dataset, self.val_split_pct
+            generator = torch.Generator().manual_seed(0)
+            self.train_dataset, self.val_dataset = random_split(
+                self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
             )
         if stage in ["test"]:
             self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs)
diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py
index 9f6b2f5da2f..3597cf3d4cf 100644
--- a/torchgeo/datamodules/gid15.py
+++ b/torchgeo/datamodules/gid15.py
@@ -6,13 +6,14 @@
 from typing import Any
 
 import kornia.augmentation as K
+import torch
+from torch.utils.data import random_split
 
 from ..datasets import GID15
 from ..samplers.utils import _to_tuple
 from ..transforms import AugmentationSequential
 from ..transforms.transforms import _RandomNCrop
 from .geo import NonGeoDataModule
-from .utils import dataset_split
 
 
 class GID15DataModule(NonGeoDataModule):
@@ -66,8 +67,9 @@ def setup(self, stage: str) -> None:
         """
         if stage in ["fit", "validate"]:
             self.dataset = GID15(split="train", **self.kwargs)
-            self.train_dataset, self.val_dataset = dataset_split(
-                self.dataset, self.val_split_pct
+            generator = torch.Generator().manual_seed(0)
+            self.train_dataset, self.val_dataset = random_split(
+                self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
             )
         if stage in ["test"]:
             # Test set masks are not public, use for prediction instead
diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py
index 2107928d456..c77805f02a7 100644
--- a/torchgeo/datamodules/levircd.py
+++ b/torchgeo/datamodules/levircd.py
@@ -6,8 +6,9 @@
 from typing import Any
 
 import kornia.augmentation as K
+import torch
+from torch.utils.data import random_split
 
-from torchgeo.datamodules.utils import dataset_split
 from torchgeo.samplers.utils import _to_tuple
 
 from ..datasets import LEVIRCD, LEVIRCDPlus
@@ -113,8 +114,9 @@ def setup(self, stage: str) -> None:
         """
         if stage in ["fit", "validate"]:
             self.dataset = LEVIRCDPlus(split="train", **self.kwargs)
-            self.train_dataset, self.val_dataset = dataset_split(
-                self.dataset, val_pct=self.val_split_pct
+            generator = torch.Generator().manual_seed(0)
+            self.train_dataset, self.val_dataset = random_split(
+                self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
             )
         if stage in ["test"]:
             self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs)
diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py
index 76848bc4e4b..36e3e78d480 100644
--- a/torchgeo/datamodules/nasa_marine_debris.py
+++ b/torchgeo/datamodules/nasa_marine_debris.py
@@ -7,11 +7,12 @@
 
 import kornia.augmentation as K
 import torch
+from torch.utils.data import random_split
 
 from ..datasets import NASAMarineDebris
 from ..transforms import AugmentationSequential
 from .geo import NonGeoDataModule
-from .utils import AugPipe, collate_fn_detection, dataset_split
+from .utils import AugPipe, collate_fn_detection
 
 
 class NASAMarineDebrisDataModule(NonGeoDataModule):
@@ -61,6 +62,13 @@ def setup(self, stage: str) -> None:
             stage: Either 'fit', 'validate', 'test', or 'predict'.
""" self.dataset = NASAMarineDebris(**self.kwargs) - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset, self.test_dataset = random_split( + self.dataset, + [ + 1 - self.val_split_pct - self.test_split_pct, + self.val_split_pct, + self.test_split_pct, + ], + generator, ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 61c5ca9aa87..3a5cc283466 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -7,13 +7,13 @@ import kornia.augmentation as K import torch +from torch.utils.data import random_split from ..datasets import OSCD from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule -from .utils import dataset_split MEAN = { "B01": 1583.0741, @@ -99,8 +99,9 @@ def setup(self, stage: str) -> None: """ if stage in ["fit", "validate"]: self.dataset = OSCD(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - self.dataset, val_pct=self.val_split_pct + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset = random_split( + self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator ) if stage in ["test"]: self.test_dataset = OSCD(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 397ef25d7b5..610d1d35056 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -6,13 +6,14 @@ from typing import Any import kornia.augmentation as K +import torch +from torch.utils.data import random_split from ..datasets import Potsdam2D from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule -from .utils import dataset_split class Potsdam2DDataModule(NonGeoDataModule): @@ -61,8 +62,9 @@ def setup(self, stage: str) -> None: """ if stage in ["fit", "validate"]: self.dataset = Potsdam2D(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - self.dataset, self.val_split_pct + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset = random_split( + self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator ) if stage in ["test"]: self.test_dataset = Potsdam2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/skippd.py b/torchgeo/datamodules/skippd.py index b76eb3e1e92..e4915c4f18a 100644 --- a/torchgeo/datamodules/skippd.py +++ b/torchgeo/datamodules/skippd.py @@ -5,9 +5,11 @@ from typing import Any +import torch +from torch.utils.data import random_split + from ..datasets import SKIPPD from .geo import NonGeoDataModule -from .utils import dataset_split class SKIPPDDataModule(NonGeoDataModule): @@ -47,8 +49,9 @@ def setup(self, stage: str) -> None: """ if stage in ["fit", "validate"]: self.dataset = SKIPPD(split="trainval", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - self.dataset, val_pct=self.val_split_pct + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset = random_split( + self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator ) if stage in ["test"]: self.test_dataset = SKIPPD(split="test", **self.kwargs) diff --git 
a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 3a3b7531cba..6c69cd1df02 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -6,12 +6,13 @@ from typing import Any import kornia.augmentation as K +import torch from torch import Tensor +from torch.utils.data import random_split from ..datasets import SpaceNet1 from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import dataset_split class SpaceNet1DataModule(NonGeoDataModule): @@ -68,8 +69,15 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ self.dataset = SpaceNet1(**self.kwargs) - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - self.dataset, self.val_split_pct, self.test_split_pct + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset, self.test_dataset = random_split( + self.dataset, + [ + 1 - self.val_split_pct - self.test_split_pct, + self.val_split_pct, + self.test_split_pct, + ], + generator, ) def on_after_batch_transfer( diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index a6069250d71..b4a1b5f5d06 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -10,11 +10,8 @@ import numpy as np import torch from einops import rearrange -from torch import Generator, Tensor +from torch import Tensor from torch.nn import Module -from torch.utils.data import Subset, TensorDataset, random_split - -from ..datasets import NonGeoDataset # Based on lightning_lite.utilities.exceptions @@ -102,44 +99,6 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: return output -def dataset_split( - dataset: TensorDataset | NonGeoDataset, - val_pct: float, - test_pct: float | None = None, -) -> list[Subset[Any]]: - """Split a torch Dataset into train/val/test sets. - - If ``test_pct`` is not set then only train and validation splits are returned. - - .. deprecated:: 0.4 - Use :func:`torch.utils.data.random_split` instead, ``random_split`` - now supports percentages as of PyTorch 1.13. - - Args: - dataset: dataset to be split into train/val or train/val/test subsets - val_pct: percentage of samples to be in validation set - test_pct: (Optional) percentage of samples to be in test set - - Returns: - a list of the subset datasets. 
Either [train, val] or [train, val, test] - """ - if test_pct is None: - val_length = round(len(dataset) * val_pct) - train_length = len(dataset) - val_length - return random_split( - dataset, [train_length, val_length], generator=Generator().manual_seed(0) - ) - else: - val_length = round(len(dataset) * val_pct) - test_length = round(len(dataset) * test_pct) - train_length = len(dataset) - (val_length + test_length) - return random_split( - dataset, - [train_length, val_length, test_length], - generator=Generator().manual_seed(0), - ) - - def group_shuffle_split( groups: Iterable[Any], train_size: float | None = None, diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 441fafdd9d2..fa0597cc11f 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -6,13 +6,14 @@ from typing import Any import kornia.augmentation as K +import torch +from torch.utils.data import random_split from ..datasets import Vaihingen2D from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule -from .utils import dataset_split class Vaihingen2DDataModule(NonGeoDataModule): @@ -61,8 +62,9 @@ def setup(self, stage: str) -> None: """ if stage in ["fit", "validate"]: self.dataset = Vaihingen2D(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - self.dataset, self.val_split_pct + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset = random_split( + self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator ) if stage in ["test"]: self.test_dataset = Vaihingen2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 0bafef71b27..af42c7a16d1 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -7,12 +7,13 @@ import kornia.augmentation as K import torch +from torch.utils.data import random_split from ..datasets import VHR10 from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import AugPipe, collate_fn_detection, dataset_split +from .utils import AugPipe, collate_fn_detection class VHR10DataModule(NonGeoDataModule): @@ -78,6 +79,13 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" self.dataset = VHR10(**self.kwargs) - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - self.dataset, self.val_split_pct, self.test_split_pct + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset, self.test_dataset = random_split( + self.dataset, + [ + 1 - self.val_split_pct - self.test_split_pct, + self.val_split_pct, + self.test_split_pct, + ], + generator, ) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 8f96d786bea..5891d039ba0 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -5,9 +5,11 @@ from typing import Any +import torch +from torch.utils.data import random_split + from ..datasets import XView2 from .geo import NonGeoDataModule -from .utils import dataset_split class XView2DataModule(NonGeoDataModule): @@ -46,8 +48,9 @@ def setup(self, stage: str) -> None: """ if stage in ["fit", "validate"]: self.dataset = XView2(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - self.dataset, val_pct=self.val_split_pct + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset = random_split( + self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator ) if stage in ["test"]: self.test_dataset = XView2(split="test", **self.kwargs)