Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding UC Merced trainer #208

Merged
merged 3 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions conf/task_defaults/ucmerced.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Default experiment settings for the "ucmerced" task.
experiment:
  task: "ucmerced"
  # NOTE(review): presumably forwarded to the task's LightningModule
  # constructor -- confirm against train.py.
  module:
    loss: "ce"
    classification_model: "resnet18"
    weights: null
    # NOTE(review): some YAML loaders read `1e-3` (no decimal point) as a
    # string rather than a float -- confirm the config loader in use.
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    in_channels: 3
  # Key names below match keyword parameters of UCMercedDataModule.__init__.
  datamodule:
    batch_size: 128
    num_workers: 6
    unsupervised_mode: false
    val_split_pct: 0.1
    test_split_pct: 0.1
Binary file modified tests/data/ucmerced/UCMerced_LandUse.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
6 changes: 3 additions & 3 deletions tests/datasets/test_ucmerced.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def dataset(
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.ucmerced, "download_url", download_url
)
md5 = "95e710774f3ef6d9ecb0cd42e4d0fc23"
md5 = "a42ef8779469d196d8f2971ee135f030"
monkeypatch.setattr(UCMerced, "md5", md5) # type: ignore[attr-defined]
url = os.path.join("tests", "data", "ucmerced", "UCMerced_LandUse.zip")
monkeypatch.setattr(UCMerced, "url", url) # type: ignore[attr-defined]
Expand All @@ -43,12 +43,12 @@ def test_getitem(self, dataset: UCMerced) -> None:
assert isinstance(x["label"], torch.Tensor)

def test_len(self, dataset: UCMerced) -> None:
    """Check that the dataset fixture exposes exactly four samples."""
    # Only one length expectation can hold; the stale ``== 2`` assertion that
    # contradicted this one has been dropped.
    assert len(dataset) == 4

def test_add(self, dataset: UCMerced) -> None:
    """Check that adding the dataset to itself gives an 8-sample ConcatDataset."""
    ds = dataset + dataset
    assert isinstance(ds, ConcatDataset)
    # Only one length expectation can hold; the stale ``== 4`` assertion that
    # contradicted this one has been dropped.
    assert len(ds) == 8

def test_already_downloaded(self, dataset: UCMerced, tmp_path: Path) -> None:
UCMerced(root=str(tmp_path), download=True)
Expand Down
39 changes: 39 additions & 0 deletions tests/trainers/test_ucmerced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
from _pytest.fixtures import SubRequest

from torchgeo.trainers import UCMercedDataModule


@pytest.fixture(scope="module", params=[True, False])
def datamodule(request: SubRequest) -> UCMercedDataModule:
    """Return a prepared and set-up UCMercedDataModule over the test data.

    The fixture is parametrized so that it is built once with
    ``unsupervised_mode=True`` and once with ``unsupervised_mode=False``.
    """
    data_root = os.path.join("tests", "data", "ucmerced")
    module = UCMercedDataModule(
        data_root,
        2,  # batch_size
        0,  # num_workers
        val_split_pct=0.33,
        test_split_pct=0.33,
        unsupervised_mode=request.param,
    )
    module.prepare_data()
    module.setup()
    return module


class TestUCMercedDataModule:
    """Smoke tests: every dataloader must be able to produce a first batch."""

    def test_train_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Pull one batch from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Pull one batch from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Pull one batch from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
3 changes: 3 additions & 0 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
from .so2sat import So2SatClassificationTask, So2SatDataModule
from .tasks import ClassificationTask
from .ucmerced import UCMercedClassificationTask, UCMercedDataModule

__all__ = (
# Tasks
Expand All @@ -32,6 +33,8 @@
"SEN12MSSegmentationTask",
"So2SatDataModule",
"So2SatClassificationTask",
"UCMercedClassificationTask",
"UCMercedDataModule",
)

# https://stackoverflow.com/questions/40018681
Expand Down
142 changes: 142 additions & 0 deletions torchgeo/trainers/ucmerced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""UC Merced trainer."""

from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
import torchvision.transforms.functional
from torch.nn.modules import Conv2d, Linear
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize

from ..datasets import UCMerced
from ..datasets.utils import dataset_split
from .tasks import ClassificationTask

# Override the ``__module__`` attribute reported by these imported classes.
# NOTE(review): presumably a workaround so documentation cross-references
# resolve correctly -- confirm against the linked issue and pull request.
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"


class UCMercedClassificationTask(ClassificationTask):
    """Classification LightningModule specialised for the UC Merced Dataset.

    All behaviour is inherited from ``ClassificationTask``; this subclass only
    fixes the number of target classes.
    """

    # Number of target classes used for this dataset.
    num_classes = 21


class UCMercedDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the UC Merced dataset.

    Uses random train/val/test splits.
    """

    # Placeholder normalization statistics: zero mean and unit standard
    # deviation for each of the three bands.
    band_means = torch.tensor([0, 0, 0])  # type: ignore[attr-defined]

    band_stds = torch.tensor([1, 1, 1])  # type: ignore[attr-defined]

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 4,
        unsupervised_mode: bool = False,
        val_split_pct: float = 0.2,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for UCMerced based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the UCMerced Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            unsupervised_mode: Makes the train dataloader return imagery from the
                train, val, and test sets
            val_split_pct: Share of the dataset to use as a validation set, given
                as a fraction (the default is ``0.2``); forwarded to
                ``dataset_split`` as ``val_pct``
            test_split_pct: Share of the dataset to use as a test set, given as a
                fraction (the default is ``0.2``); forwarded to ``dataset_split``
                as ``test_pct``
            **kwargs: Accepted for call-site compatibility; not used by this class
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.unsupervised_mode = unsupervised_mode

        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct

        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        The image is converted to float, divided by 255, resized to 256x256 when
        its last two dimensions differ from that, and then normalized with
        ``self.norm``.

        Args:
            sample: Dictionary holding the image tensor under the ``"image"`` key

        Returns:
            The same dictionary, with ``"image"`` replaced by the processed tensor
        """
        # Divide out-of-place: ``Tensor.float()`` hands back the very same tensor
        # when it is already float32, so an in-place ``/=`` would silently rescale
        # the tensor owned by the caller.
        image = sample["image"].float() / 255.0
        _, height, width = image.shape
        if height != 256 or width != 256:
            image = torchvision.transforms.functional.resize(image, size=(256, 256))
        sample["image"] = self.norm(image)
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        UCMerced(self.root_dir, download=True, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: Unused by this implementation
        """
        transforms = Compose([self.preprocess])

        if not self.unsupervised_mode:
            dataset = UCMerced(self.root_dir, transforms=transforms)
            self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
                dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
            )
        else:
            # No split in unsupervised mode: the whole dataset is used for training
            # and the val/test dataloaders fall back to the train dataloader.
            self.train_dataset = UCMerced(self.root_dir, transforms=transforms)
            self.val_dataset, self.test_dataset = None, None  # type: ignore[assignment]

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Falls back to the training DataLoader in unsupervised mode or when
        ``val_split_pct`` is 0.
        """
        if self.unsupervised_mode or self.val_split_pct == 0:
            return self.train_dataloader()
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Falls back to the training DataLoader in unsupervised mode or when
        ``test_split_pct`` is 0.
        """
        if self.unsupervised_mode or self.test_split_pct == 0:
            return self.train_dataloader()
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
3 changes: 3 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
SEN12MSSegmentationTask,
So2SatClassificationTask,
So2SatDataModule,
UCMercedClassificationTask,
UCMercedDataModule,
)

TASK_TO_MODULES_MAPPING: Dict[
Expand All @@ -42,6 +44,7 @@
"resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
"sen12ms": (SEN12MSSegmentationTask, SEN12MSDataModule),
"so2sat": (So2SatClassificationTask, So2SatDataModule),
"ucmerced": (UCMercedClassificationTask, UCMercedDataModule),
}


Expand Down