Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BigEarthNet Trainers #211

Merged
merged 20 commits into from
Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions conf/bigearthnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Experiment configuration for the "bigearthnet" task.
# NOTE(review): nesting (trainer / experiment / module / datamodule) is
# conveyed by indentation in the real file — confirm it is intact.
trainer:
gpus: 1 # single GPU training
min_epochs: 10
max_epochs: 40
benchmark: True

experiment:
task: "bigearthnet"
# Settings for the task's model and optimization.
module:
loss: "bce"
classification_model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
# NOTE(review): 14 presumably matches the channel count produced by
# bands: "all" below — confirm against the dataset.
in_channels: 14
# Settings for data loading.
datamodule:
batch_size: 128
num_workers: 6
bands: "all"
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
13 changes: 13 additions & 0 deletions conf/task_defaults/bigearthnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Default settings for the "bigearthnet" task.
# NOTE(review): nesting (experiment / module / datamodule) is conveyed by
# indentation in the real file — confirm it is intact.
experiment:
task: "bigearthnet"
# Settings for the task's model and optimization.
module:
loss: "bce"
classification_model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
# NOTE(review): 14 presumably matches the channel count produced by
# bands: "all" below — confirm against the dataset.
in_channels: 14
# Settings for data loading.
datamodule:
batch_size: 128
num_workers: 6
bands: "all"
Binary file modified tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz
Binary file not shown.
Binary file modified tests/data/bigearthnet/BigEarthNet-S2-v1.0.tar.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/datasets/test_bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_getitem(self, dataset: BigEarthNet) -> None:
assert x["image"].shape == (12, 120, 120)

def test_len(self, dataset: BigEarthNet) -> None:
assert len(dataset) == 2
assert len(dataset) == 4

def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True)
Expand Down
39 changes: 39 additions & 0 deletions tests/trainers/test_bigearthnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
from _pytest.fixtures import SubRequest

from torchgeo.trainers import BigEarthNetDataModule


class TestBigEarthNetDataModule:
    """Smoke tests for the dataloaders exposed by ``BigEarthNetDataModule``."""

    @pytest.fixture(
        scope="class", params=[("s1", True), ("s2", True), ("all", False)]
    )
    def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
        """Return a prepared datamodule for one (bands, unsupervised_mode) pair.

        The datamodule is pointed at the fixture data under
        ``tests/data/bigearthnet`` and is created with ``val_split_pct=0.3``
        and ``test_split_pct=0.3``; ``prepare_data()`` and ``setup()`` have
        already been called on the returned object.
        """
        band_set, unsupervised = request.param
        data_root = os.path.join("tests", "data", "bigearthnet")
        module = BigEarthNetDataModule(
            data_root,
            band_set,
            1,  # batch_size
            0,  # num_workers
            unsupervised,
            val_split_pct=0.3,
            test_split_pct=0.3,
        )
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """Fetching the first training batch must not raise."""
        train_loader = datamodule.train_dataloader()
        next(iter(train_loader))

    def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """Fetching the first validation batch must not raise."""
        val_loader = datamodule.val_dataloader()
        next(iter(val_loader))

    def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """Fetching the first test batch must not raise."""
        test_loader = datamodule.test_dataloader()
        next(iter(test_loader))
177 changes: 152 additions & 25 deletions tests/trainers/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,114 @@
# Licensed under the MIT License.

import os
from typing import Any, Dict, Generator, Tuple, cast
from typing import Any, Dict, Generator, Optional, cast

import pytest
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset

from torchgeo.trainers import (
ClassificationTask,
CycloneDataModule,
MultiLabelClassificationTask,
RegressionTask,
So2SatDataModule,
)

from .test_utils import mocked_log


@pytest.fixture(scope="module", params=[("rgb", 3), ("s2", 10)])
def bands(request: SubRequest) -> Tuple[str, int]:
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
return cast(Tuple[str, int], request.param)
class DummyDataset(Dataset): # type: ignore[type-arg]
def __init__(self, num_channels: int, num_classes: int, multilabel: bool) -> None:
x = torch.randn(10, num_channels, 128, 128) # (b, c, h, w)
y = torch.randint( # type: ignore[attr-defined]
0, num_classes, size=(10,)
) # (b,)

if multilabel:
y = F.one_hot(y, num_classes=num_classes) # (b, classes)

@pytest.fixture(scope="module", params=[True, False])
def datamodule(bands: Tuple[str, int], request: SubRequest) -> So2SatDataModule:
band_set = bands[0]
unsupervised_mode = request.param
root = os.path.join("tests", "data", "so2sat")
batch_size = 2
num_workers = 0
dm = So2SatDataModule(root, batch_size, num_workers, band_set, unsupervised_mode)
dm.prepare_data()
dm.setup()
return dm
self.dataset = TensorDataset(x, y)

def __len__(self) -> int:
return len(self.dataset)

def __getitem__(self, idx: int) -> Dict[str, Tensor]:
x, y = self.dataset[idx]
sample = {"image": x, "label": y}
return sample


class DummyDataModule(pl.LightningDataModule):
    """Datamodule that serves one ``DummyDataset`` from all three dataloaders."""

    def __init__(
        self,
        num_channels: int,
        num_classes: int,
        multilabel: bool,
        batch_size: int = 1,
        num_workers: int = 0,
    ) -> None:
        """Store the dataset and dataloader settings.

        Args:
            num_channels: forwarded to ``DummyDataset`` in :meth:`setup`
            num_classes: forwarded to ``DummyDataset`` in :meth:`setup`
            multilabel: forwarded to ``DummyDataset`` in :meth:`setup`
            batch_size: batch size used by every dataloader
            num_workers: worker count used by every dataloader
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.num_channels, self.num_classes = num_channels, num_classes
        self.multilabel = multilabel
        self.batch_size, self.num_workers = batch_size, num_workers

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the dataset shared by all dataloaders; ``stage`` is unused."""
        self.dataset = DummyDataset(
            self.num_channels, self.num_classes, self.multilabel
        )

    def _make_loader(self) -> DataLoader:  # type: ignore[type-arg]
        """Build a fresh loader over ``self.dataset`` with the stored settings."""
        loader_kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
        }
        return DataLoader(self.dataset, **loader_kwargs)

    def train_dataloader(self) -> DataLoader:  # type: ignore[type-arg]
        """Return the dataloader used for training."""
        return self._make_loader()

    def val_dataloader(self) -> DataLoader:  # type: ignore[type-arg]
        """Return the dataloader used for validation."""
        return self._make_loader()

    def test_dataloader(self) -> DataLoader:  # type: ignore[type-arg]
        """Return the dataloader used for testing."""
        return self._make_loader()


class TestClassificationTask:
@pytest.fixture(scope="class", params=[2, 3, 5])
def datamodule(self, request: SubRequest) -> DummyDataModule:
dm = DummyDataModule(
num_channels=request.param,
num_classes=45,
multilabel=False,
batch_size=2,
num_workers=0,
)
dm.prepare_data()
dm.setup()
return dm

@pytest.fixture(
params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"])
scope="class",
params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]),
)
def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]:
task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml"))
task_args = OmegaConf.to_object(task_conf.experiment.module)
task_args = cast(Dict[str, Any], task_args)
task_args["in_channels"] = bands[1]
def config(
self, request: SubRequest, datamodule: DummyDataModule
) -> Dict[str, Any]:
task_args = {}
task_args["classification_model"] = "resnet18"
task_args["learning_rate"] = 3e-4 # type: ignore[assignment]
task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment]
task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment]
loss, weights = request.param
task_args["loss"] = loss
task_args["weights"] = weights
Expand All @@ -65,20 +129,20 @@ def test_configure_optimizers(self, task: ClassificationTask) -> None:
assert "lr_scheduler" in out

def test_training(
self, datamodule: So2SatDataModule, task: ClassificationTask
self, datamodule: DummyDataModule, task: ClassificationTask
) -> None:
batch = next(iter(datamodule.train_dataloader()))
task.training_step(batch, 0)
task.training_epoch_end(0)

def test_validation(
self, datamodule: So2SatDataModule, task: ClassificationTask
self, datamodule: DummyDataModule, task: ClassificationTask
) -> None:
batch = next(iter(datamodule.val_dataloader()))
task.validation_step(batch, 0)
task.validation_epoch_end(0)

def test_test(self, datamodule: So2SatDataModule, task: ClassificationTask) -> None:
def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None:
batch = next(iter(datamodule.test_dataloader()))
task.test_step(batch, 0)
task.test_epoch_end(0)
Expand All @@ -99,6 +163,7 @@ def test_invalid_model(self, config: Dict[str, Any]) -> None:

def test_invalid_loss(self, config: Dict[str, Any]) -> None:
config["loss"] = "invalid_loss"
config["classification_model"] = "resnet18"
error_message = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=error_message):
ClassificationTask(**config)
Expand All @@ -117,6 +182,68 @@ def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> No
ClassificationTask(**config)


class TestMultiLabelClassificationTask:
    """Tests for ``MultiLabelClassificationTask`` driven by dummy multilabel data."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> DummyDataModule:
        """Return a set-up ``DummyDataModule`` with 3 channels and 43 classes."""
        dm = DummyDataModule(
            num_channels=3,
            num_classes=43,
            multilabel=True,
            batch_size=2,
            num_workers=0,
        )
        dm.prepare_data()
        dm.setup()
        return dm

    @pytest.fixture(
        scope="class", params=[("bce", "imagenet"), ("bce", "random")]
    )
    def config(
        self, datamodule: DummyDataModule, request: SubRequest
    ) -> Dict[str, Any]:
        """Return task keyword arguments for one (loss, weights) combination.

        The fixture is class-scoped, so the returned dict is shared by every
        test that runs with the same parameter; tests must not mutate it.
        """
        loss, weights = request.param
        task_args: Dict[str, Any] = {
            "classification_model": "resnet18",
            "learning_rate": 3e-4,
            "learning_rate_schedule_patience": 6,
            "in_channels": datamodule.num_channels,
            "loss": loss,
            "weights": weights,
        }
        return task_args

    @pytest.fixture
    def task(
        self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
    ) -> MultiLabelClassificationTask:
        """Return a task built from ``config`` with ``log`` replaced by a mock."""
        task = MultiLabelClassificationTask(**config)
        monkeypatch.setattr(task, "log", mocked_log)  # type: ignore[attr-defined]
        return task

    def test_training(
        self, datamodule: DummyDataModule, task: MultiLabelClassificationTask
    ) -> None:
        """Run one training step and the training epoch-end hook."""
        batch = next(iter(datamodule.train_dataloader()))
        task.training_step(batch, 0)
        task.training_epoch_end(0)

    def test_validation(
        self, datamodule: DummyDataModule, task: MultiLabelClassificationTask
    ) -> None:
        """Run one validation step and the validation epoch-end hook."""
        batch = next(iter(datamodule.val_dataloader()))
        task.validation_step(batch, 0)
        task.validation_epoch_end(0)

    def test_test(
        self, datamodule: DummyDataModule, task: MultiLabelClassificationTask
    ) -> None:
        """Run one test step and the test epoch-end hook."""
        batch = next(iter(datamodule.test_dataloader()))
        task.test_step(batch, 0)
        task.test_epoch_end(0)

    def test_invalid_loss(self, config: Dict[str, Any]) -> None:
        """An unknown loss name must raise ``ValueError``."""
        # Work on a copy: ``config`` is class-scoped, so mutating it in place
        # would leak the invalid loss into other tests sharing the fixture.
        bad_config = {**config, "loss": "invalid_loss"}
        error_message = "Loss type 'invalid_loss' is not valid."
        with pytest.raises(ValueError, match=error_message):
            MultiLabelClassificationTask(**bad_config)


class TestRegressionTask:
@pytest.fixture(scope="class")
def datamodule(self) -> CycloneDataModule:
Expand Down
23 changes: 15 additions & 8 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""TorchGeo trainers."""

from .bigearthnet import BigEarthNetClassificationTask, BigEarthNetDataModule
from .byol import BYOLTask
from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
from .cyclone import CycloneDataModule
Expand All @@ -11,29 +12,35 @@
from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
from .so2sat import So2SatClassificationTask, So2SatDataModule
from .tasks import ClassificationTask, RegressionTask
from .tasks import ClassificationTask, MultiLabelClassificationTask, RegressionTask
from .ucmerced import UCMercedClassificationTask, UCMercedDataModule

__all__ = (
# Tasks
"ClassificationTask",
"RegressionTask",
# Trainers
"BigEarthNetClassificationTask",
"BYOLTask",
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
"ChesapeakeCVPRSegmentationTask",
"ChesapeakeCVPRDataModule",
"ClassificationTask",
"CycloneDataModule",
"LandcoverAIDataModule",
"LandcoverAISegmentationTask",
"NAIPChesapeakeDataModule",
"MultiLabelClassificationTask",
"NAIPChesapeakeSegmentationTask",
"RESISC45ClassificationTask",
"RESISC45DataModule",
"SEN12MSDataModule",
"RegressionTask",
"SEN12MSSegmentationTask",
"So2SatDataModule",
"So2SatClassificationTask",
"UCMercedClassificationTask",
# DataModules
"BigEarthNetDataModule",
"ChesapeakeCVPRDataModule",
"CycloneDataModule",
"LandcoverAIDataModule",
"NAIPChesapeakeDataModule",
"RESISC45DataModule",
"SEN12MSDataModule",
"So2SatDataModule",
"UCMercedDataModule",
)

Expand Down
Loading