Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove torchgeo.datamodules.utils.dataset_split #2005

Merged
merged 5 commits into from
Apr 25, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update syntax, add generator
adamjstewart committed Apr 17, 2024
commit 00b19f44e66334e00a33f279b3999a657925cc2f
4 changes: 3 additions & 1 deletion torchgeo/datamodules/deepglobelandcover.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from typing import Any

import kornia.augmentation as K
import torch
from torch.utils.data import random_split

from ..datasets import DeepGlobeLandCover
@@ -59,8 +60,9 @@ def setup(self, stage: str) -> None:
"""
if stage in ["fit", "validate"]:
self.dataset = DeepGlobeLandCover(split="train", **self.kwargs)
generator = torch.Generator().manual_seed(0)
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
self.train_dataset, self.val_dataset = random_split(
self.dataset, self.val_split_pct
self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
)
if stage in ["test"]:
self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs)
4 changes: 3 additions & 1 deletion torchgeo/datamodules/gid15.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from typing import Any

import kornia.augmentation as K
import torch
from torch.utils.data import random_split

from ..datasets import GID15
@@ -66,8 +67,9 @@ def setup(self, stage: str) -> None:
"""
if stage in ["fit", "validate"]:
self.dataset = GID15(split="train", **self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset = random_split(
self.dataset, self.val_split_pct
self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
)
if stage in ["test"]:
# Test set masks are not public, use for prediction instead
4 changes: 3 additions & 1 deletion torchgeo/datamodules/levircd.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from typing import Any

import kornia.augmentation as K
import torch
from torch.utils.data import random_split

from torchgeo.samplers.utils import _to_tuple
@@ -113,8 +114,9 @@ def setup(self, stage: str) -> None:
"""
if stage in ["fit", "validate"]:
self.dataset = LEVIRCDPlus(split="train", **self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset = random_split(
self.dataset, val_pct=self.val_split_pct
self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
)
if stage in ["test"]:
self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs)
9 changes: 8 additions & 1 deletion torchgeo/datamodules/nasa_marine_debris.py
Original file line number Diff line number Diff line change
@@ -62,6 +62,13 @@ def setup(self, stage: str) -> None:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.dataset = NASAMarineDebris(**self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset, self.test_dataset = random_split(
self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
self.dataset,
[
1 - self.val_split_pct - self.test_split_pct,
self.val_split_pct,
self.test_split_pct,
],
generator,
)
3 changes: 2 additions & 1 deletion torchgeo/datamodules/oscd.py
Original file line number Diff line number Diff line change
@@ -99,8 +99,9 @@ def setup(self, stage: str) -> None:
"""
if stage in ["fit", "validate"]:
self.dataset = OSCD(split="train", **self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset = random_split(
self.dataset, val_pct=self.val_split_pct
self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
)
if stage in ["test"]:
self.test_dataset = OSCD(split="test", **self.kwargs)
4 changes: 3 additions & 1 deletion torchgeo/datamodules/potsdam.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from typing import Any

import kornia.augmentation as K
import torch
from torch.utils.data import random_split

from ..datasets import Potsdam2D
@@ -61,8 +62,9 @@ def setup(self, stage: str) -> None:
"""
if stage in ["fit", "validate"]:
self.dataset = Potsdam2D(split="train", **self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset = random_split(
self.dataset, self.val_split_pct
self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
)
if stage in ["test"]:
self.test_dataset = Potsdam2D(split="test", **self.kwargs)
4 changes: 3 additions & 1 deletion torchgeo/datamodules/skippd.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@

from typing import Any

import torch
from torch.utils.data import random_split

from ..datasets import SKIPPD
@@ -48,8 +49,9 @@ def setup(self, stage: str) -> None:
"""
if stage in ["fit", "validate"]:
self.dataset = SKIPPD(split="trainval", **self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset = random_split(
self.dataset, val_pct=self.val_split_pct
self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
)
if stage in ["test"]:
self.test_dataset = SKIPPD(split="test", **self.kwargs)
10 changes: 9 additions & 1 deletion torchgeo/datamodules/spacenet.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from typing import Any

import kornia.augmentation as K
import torch
from torch import Tensor
from torch.utils.data import random_split

@@ -68,8 +69,15 @@ def setup(self, stage: str) -> None:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.dataset = SpaceNet1(**self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset, self.test_dataset = random_split(
self.dataset, self.val_split_pct, self.test_split_pct
self.dataset,
[
1 - self.val_split_pct - self.test_split_pct,
self.val_split_pct,
self.test_split_pct,
],
generator,
)

def on_after_batch_transfer(
4 changes: 3 additions & 1 deletion torchgeo/datamodules/vaihingen.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from typing import Any

import kornia.augmentation as K
import torch
from torch.utils.data import random_split

from ..datasets import Vaihingen2D
@@ -61,8 +62,9 @@ def setup(self, stage: str) -> None:
"""
if stage in ["fit", "validate"]:
self.dataset = Vaihingen2D(split="train", **self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset = random_split(
self.dataset, self.val_split_pct
self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
)
if stage in ["test"]:
self.test_dataset = Vaihingen2D(split="test", **self.kwargs)
12 changes: 10 additions & 2 deletions torchgeo/datamodules/vhr10.py
Original file line number Diff line number Diff line change
@@ -7,11 +7,12 @@

import kornia.augmentation as K
import torch
from torch.utils.data import AugPipe, collate_fn_detection, random_split
from torch.utils.data import random_split

from ..datasets import VHR10
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.utils import AugPipe, collate_fn_detection
from .geo import NonGeoDataModule


@@ -78,6 +79,13 @@ def setup(self, stage: str) -> None:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.dataset = VHR10(**self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset, self.test_dataset = random_split(
self.dataset, self.val_split_pct, self.test_split_pct
self.dataset,
[
1 - self.val_split_pct - self.test_split_pct,
self.val_split_pct,
self.test_split_pct,
],
generator,
)
4 changes: 3 additions & 1 deletion torchgeo/datamodules/xview.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@

from typing import Any

import torch
from torch.utils.data import random_split

from ..datasets import XView2
@@ -47,8 +48,9 @@ def setup(self, stage: str) -> None:
"""
if stage in ["fit", "validate"]:
self.dataset = XView2(split="train", **self.kwargs)
generator = torch.Generator().manual_seed(0)
self.train_dataset, self.val_dataset = random_split(
self.dataset, val_pct=self.val_split_pct
self.dataset, [1 - self.val_split_pct, self.val_split_pct], generator
)
if stage in ["test"]:
self.test_dataset = XView2(split="test", **self.kwargs)