diff --git a/conf/task_defaults/ucmerced.yaml b/conf/task_defaults/ucmerced.yaml new file mode 100644 index 00000000000..c3abc0ea7a8 --- /dev/null +++ b/conf/task_defaults/ucmerced.yaml @@ -0,0 +1,15 @@ +experiment: + task: "ucmerced" + module: + loss: "ce" + classification_model: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + datamodule: + batch_size: 128 + num_workers: 6 + unsupervised_mode: false + val_split_pct: 0.1 + test_split_pct: 0.1 diff --git a/tests/data/ucmerced/UCMerced_LandUse.zip b/tests/data/ucmerced/UCMerced_LandUse.zip index 06c622777ca..9f62ea44b0b 100644 Binary files a/tests/data/ucmerced/UCMerced_LandUse.zip and b/tests/data/ucmerced/UCMerced_LandUse.zip differ diff --git a/tests/data/ucmerced/UCMerced_LandUse/Images/agricultural/agricultural00.tif b/tests/data/ucmerced/UCMerced_LandUse/Images/agricultural/agricultural00.tif new file mode 100644 index 00000000000..aad10f5b771 Binary files /dev/null and b/tests/data/ucmerced/UCMerced_LandUse/Images/agricultural/agricultural00.tif differ diff --git a/tests/data/ucmerced/UCMerced_LandUse/Images/agricultural/agricultural01.tif b/tests/data/ucmerced/UCMerced_LandUse/Images/agricultural/agricultural01.tif new file mode 100644 index 00000000000..a72d6f4202c Binary files /dev/null and b/tests/data/ucmerced/UCMerced_LandUse/Images/agricultural/agricultural01.tif differ diff --git a/tests/data/ucmerced/UCMerced_LandUse/Images/agricultural/agricultural02.tif b/tests/data/ucmerced/UCMerced_LandUse/Images/agricultural/agricultural02.tif new file mode 100644 index 00000000000..45b48835d72 Binary files /dev/null and b/tests/data/ucmerced/UCMerced_LandUse/Images/agricultural/agricultural02.tif differ diff --git a/tests/data/ucmerced/UCMerced_LandUse/Images/airplane/airplane00.tif b/tests/data/ucmerced/UCMerced_LandUse/Images/airplane/airplane00.tif new file mode 100644 index 00000000000..678435af301 Binary files /dev/null and 
b/tests/data/ucmerced/UCMerced_LandUse/Images/airplane/airplane00.tif differ diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py index 993c3c3d9f5..fcd0ea28337 100644 --- a/tests/datasets/test_ucmerced.py +++ b/tests/datasets/test_ucmerced.py @@ -28,7 +28,7 @@ def dataset( monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.ucmerced, "download_url", download_url ) - md5 = "95e710774f3ef6d9ecb0cd42e4d0fc23" + md5 = "a42ef8779469d196d8f2971ee135f030" monkeypatch.setattr(UCMerced, "md5", md5) # type: ignore[attr-defined] url = os.path.join("tests", "data", "ucmerced", "UCMerced_LandUse.zip") monkeypatch.setattr(UCMerced, "url", url) # type: ignore[attr-defined] @@ -43,12 +43,12 @@ def test_getitem(self, dataset: UCMerced) -> None: assert isinstance(x["label"], torch.Tensor) def test_len(self, dataset: UCMerced) -> None: - assert len(dataset) == 2 + assert len(dataset) == 4 def test_add(self, dataset: UCMerced) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - assert len(ds) == 4 + assert len(ds) == 8 def test_already_downloaded(self, dataset: UCMerced, tmp_path: Path) -> None: UCMerced(root=str(tmp_path), download=True) diff --git a/tests/trainers/test_ucmerced.py b/tests/trainers/test_ucmerced.py new file mode 100644 index 00000000000..3604c8ad0d4 --- /dev/null +++ b/tests/trainers/test_ucmerced.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
import os

import pytest
from _pytest.fixtures import SubRequest

# UCMercedDataModule is re-exported from the package root by
# torchgeo/trainers/__init__.py (added to ``__all__`` in the same patch,
# together with UCMercedClassificationTask).
from torchgeo.trainers import UCMercedDataModule


@pytest.fixture(scope="module", params=[True, False])
def datamodule(request: SubRequest) -> UCMercedDataModule:
    """Return a prepared UCMercedDataModule built on the bundled test data.

    The fixture is module-scoped and parametrized over ``unsupervised_mode``
    (``True``, then ``False``), so each test below runs once per mode.

    Args:
        request: pytest request object; ``request.param`` carries the
            ``unsupervised_mode`` flag for the current parametrization

    Returns:
        a data module on which ``prepare_data()`` and ``setup()`` have
        already been called
    """
    data_root = os.path.join("tests", "data", "ucmerced")
    module = UCMercedDataModule(
        data_root,
        2,  # batch_size
        0,  # num_workers
        unsupervised_mode=request.param,
        val_split_pct=0.33,
        test_split_pct=0.33,
    )
    module.prepare_data()
    module.setup()
    return module


class TestUCMercedDataModule:
    """Smoke tests: every dataloader must be able to yield a first batch."""

    def test_train_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Draw one batch from the training loader."""
        batches = iter(datamodule.train_dataloader())
        next(batches)

    def test_val_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Draw one batch from the validation loader."""
        batches = iter(datamodule.val_dataloader())
        next(batches)

    def test_test_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Draw one batch from the test loader."""
        batches = iter(datamodule.test_dataloader())
        next(batches)
"""UC Merced trainer."""

from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
import torchvision.transforms.functional
from torch.nn.modules import Conv2d, Linear
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize

from ..datasets import UCMerced
from ..datasets.utils import dataset_split
from .tasks import ClassificationTask

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"


class UCMercedClassificationTask(ClassificationTask):
    """LightningModule for training models on the UC Merced Dataset."""

    # Number of target classes used by this task.
    num_classes = 21


class UCMercedDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the UC Merced dataset.

    Uses random train/val/test splits produced by ``dataset_split``. In
    unsupervised mode no split is made: the whole dataset is served by every
    dataloader.
    """

    # Per-band normalization statistics. Zero mean and unit std leave the
    # values unchanged when passed through ``Normalize``.
    band_means = torch.tensor([0, 0, 0])  # type: ignore[attr-defined]

    band_stds = torch.tensor([1, 1, 1])  # type: ignore[attr-defined]

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 4,
        unsupervised_mode: bool = False,
        val_split_pct: float = 0.2,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for UCMerced based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the UCMerced Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            unsupervised_mode: Makes the train dataloader return imagery from the
                train, val, and test sets
            val_split_pct: What percentage of the dataset to use as a validation set
            test_split_pct: What percentage of the dataset to use as a test set
            **kwargs: Accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]

        # Split configuration.
        self.unsupervised_mode = unsupervised_mode
        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct

        # Dataset location and DataLoader configuration.
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        The image is cast to float, divided by 255, resized to 256x256 when
        its height or width differs, and normalized with ``self.norm``.

        Args:
            sample: dictionary holding the image tensor under ``"image"``
                (unpacked as three dimensions: channels, height, width)

        Returns:
            the same dictionary with ``"image"`` replaced by the result
        """
        image = sample["image"].float()
        image /= 255.0
        _, height, width = image.shape
        if (height, width) != (256, 256):
            image = torchvision.transforms.functional.resize(image, size=(256, 256))
        sample["image"] = self.norm(image)
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Instantiates ``UCMerced`` with ``download=True`` and
        ``checksum=False``; the instance itself is discarded.

        This method is only called once per run.
        """
        UCMerced(self.root_dir, download=True, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: unused by this implementation
        """
        full_dataset = UCMerced(self.root_dir, transforms=Compose([self.preprocess]))

        if self.unsupervised_mode:
            # No held-out data: train on everything; val/test loaders fall
            # back to the training loader.
            self.train_dataset = full_dataset
            self.val_dataset, self.test_dataset = None, None  # type: ignore[assignment]
            return

        splits = dataset_split(
            full_dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

    def _build_loader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training (shuffled)."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Falls back to the training loader in unsupervised mode or when the
        validation split is empty (``val_split_pct == 0``).
        """
        if self.unsupervised_mode or self.val_split_pct == 0:
            return self.train_dataloader()
        return self._build_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Falls back to the training loader in unsupervised mode or when the
        test split is empty (``test_split_pct == 0``).
        """
        if self.unsupervised_mode or self.test_split_pct == 0:
            return self.train_dataloader()
        return self._build_loader(self.test_dataset, shuffle=False)


# NOTE: the same patch registers these classes in ``train.py``: they are added
# to its import list from the trainers package, and ``TASK_TO_MODULES_MAPPING``
# gains the entry
# "ucmerced": (UCMercedClassificationTask, UCMercedDataModule).