Skip to content

Commit

Permalink
Add EuroSAT datamodule (microsoft#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 authored Nov 15, 2021
1 parent d6d4368 commit 14c57af
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 3 deletions.
14 changes: 14 additions & 0 deletions conf/task_defaults/eurosat.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Default settings for the "eurosat" task.
# NOTE(review): indentation is not visible in this view; `task`, `module` and
# `datamodule` presumably nest under `experiment` -- confirm against the repo.
experiment:
# train.py maps the key "eurosat" to (ClassificationTask, EuroSATDataModule).
task: "eurosat"
module:
loss: "ce"
classification_model: "resnet18"
# NOTE(review): YAML 1.1 loaders such as PyYAML resolve `1e-3` (no decimal
# point) as a string -- confirm the config loader turns this into a float.
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
# 13 matches the number of per-band mean/std values in EuroSATDataModule.
in_channels: 13
num_classes: 10
datamodule:
# Same directory the EuroSATDataModule tests build their fixture from.
root_dir: "tests/data/eurosat"
batch_size: 128
num_workers: 0
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ EuroSAT
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: EuroSAT
.. autoclass:: EuroSATDataModule

GID-15 (Gaofen Image Dataset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
23 changes: 22 additions & 1 deletion tests/datasets/test_eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import EuroSAT
from torchgeo.datasets import EuroSAT, EuroSATDataModule


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
Expand Down Expand Up @@ -89,3 +89,24 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
EuroSAT(str(tmp_path))


class TestEuroSATDataModule:
    """Smoke tests: every ``EuroSATDataModule`` loader must yield a first batch."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> EuroSATDataModule:
        """Return a prepared and set-up data module, built once per test class.

        The module reads from ``tests/data/eurosat`` with a batch size of one
        and no worker processes.
        """
        data_root = os.path.join("tests", "data", "eurosat")
        module = EuroSATDataModule(root_dir=data_root, batch_size=1, num_workers=0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: EuroSATDataModule) -> None:
        """Fetch one batch from the training loader."""
        batches = iter(datamodule.train_dataloader())
        next(batches)

    def test_val_dataloader(self, datamodule: EuroSATDataModule) -> None:
        """Fetch one batch from the validation loader."""
        batches = iter(datamodule.val_dataloader())
        next(batches)

    def test_test_dataloader(self, datamodule: EuroSATDataModule) -> None:
        """Fetch one batch from the test loader."""
        batches = iter(datamodule.test_dataloader())
        next(batches)
3 changes: 2 additions & 1 deletion torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import CycloneDataModule, TropicalCycloneWindEstimation
from .etci2021 import ETCI2021, ETCI2021DataModule
from .eurosat import EuroSAT
from .eurosat import EuroSAT, EuroSATDataModule
from .geo import (
GeoDataset,
RasterDataset,
Expand Down Expand Up @@ -108,6 +108,7 @@
"ETCI2021",
"ETCI2021DataModule",
"EuroSAT",
"EuroSATDataModule",
"GID15",
"LandCoverAI",
"LandCoverAIDataModule",
Expand Down
139 changes: 138 additions & 1 deletion torchgeo/datasets/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
"""EuroSAT dataset."""

import os
from typing import Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional

import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize

from .geo import VisionClassificationDataset
from .utils import check_integrity, download_url, extract_archive, rasterio_loader
Expand Down Expand Up @@ -169,3 +173,136 @@ def _extract(self) -> None:
"""Extract the dataset."""
filepath = os.path.join(self.root, self.filename)
extract_archive(filepath)


class EuroSATDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the EuroSAT dataset.

    Relies on the train/val/test splits that the dataset itself provides and
    normalizes every image with the per-band statistics stored on the class.
    """

    # Per-band means passed to ``Normalize`` (13 values, one per band).
    band_means = torch.tensor(  # type: ignore[attr-defined]
        [
            1354.40546513,
            1118.24399958,
            1042.92983953,
            947.62620298,
            1199.47283961,
            1999.79090914,
            2369.22292565,
            2296.82608323,
            732.08340178,
            12.11327804,
            1819.01027855,
            1118.92391149,
            2594.14080798,
        ]
    )

    # Per-band standard deviations passed to ``Normalize`` (13 values).
    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [
            245.71762908,
            333.00778264,
            395.09249139,
            593.75055589,
            566.4170017,
            861.18399006,
            1086.63139075,
            1117.98170791,
            404.91978886,
            4.77584468,
            1002.58768311,
            761.30323499,
            1231.58581042,
        ]
    )

    def __init__(
        self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a LightningDataModule for EuroSAT based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the EuroSAT Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Accepted for call-site compatibility; not used by this class
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Built once here and reused by ``preprocess`` for every sample.
        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a sample's image to float and normalize it per band.

        The ``"image"`` entry of ``sample`` is replaced in place.

        Args:
            sample: input image dictionary

        Returns:
            the same dictionary with its ``"image"`` entry preprocessed
        """
        image = sample["image"].float()
        sample["image"] = self.norm(image)
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Instantiates ``EuroSAT`` on ``root_dir`` and discards the instance.

        This method is only called once per run.
        """
        EuroSAT(self.root_dir)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train, val and test ``Dataset`` objects.

        All three splits share the same preprocessing transform.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not consulted; all splits are always built)
        """
        pipeline = Compose([self.preprocess])
        root = self.root_dir

        self.train_dataset = EuroSAT(root, "train", transforms=pipeline)
        self.val_dataset = EuroSAT(root, "val", transforms=pipeline)
        self.test_dataset = EuroSAT(root, "test", transforms=pipeline)

    def _make_loader(self, dataset: EuroSAT, shuffle: bool) -> DataLoader[Any]:
        """Wrap ``dataset`` in a DataLoader using this module's batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader (shuffled)
        """
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader (not shuffled)
        """
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader (not shuffled)
        """
        return self._make_loader(self.test_dataset, shuffle=False)
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
COWCCountingDataModule,
CycloneDataModule,
ETCI2021DataModule,
EuroSATDataModule,
LandCoverAIDataModule,
NAIPChesapeakeDataModule,
RESISC45DataModule,
Expand Down Expand Up @@ -48,6 +49,7 @@
"cowc_counting": (RegressionTask, COWCCountingDataModule),
"cyclone": (RegressionTask, CycloneDataModule),
"etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
"eurosat": (ClassificationTask, EuroSATDataModule),
"landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule),
"naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
"resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
Expand Down

0 comments on commit 14c57af

Please sign in to comment.