Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainers: split tasks into separate files, add SemanticSegmentationTask #224

Merged
merged 5 commits into from
Nov 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion conf/chesapeake_cvpr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ experiment:
encoder_output_stride: 16
learning_rate: 1e-2
learning_rate_schedule_patience: 6
class_set: 7
in_channels: 4
num_classes: 7
num_filters: 256
datamodule:
batch_size: 64
num_workers: 6
Expand Down
3 changes: 3 additions & 0 deletions conf/landcoverai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ experiment:
encoder_output_stride: 16
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
num_classes: 6
num_filters: 256
datamodule:
batch_size: 32
num_workers: 6
6 changes: 4 additions & 2 deletions conf/task_defaults/chesapeake_cvpr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ experiment:
encoder_output_stride: 16
learning_rate: 1e-3
learning_rate_schedule_patience: 6
class_set: 7
in_channels: 4
num_classes: 7
num_filters: 256
datamodule:
train_state: "de"
patches_per_tile: 200
patch_size: 256
batch_size: 64
num_workers: 4
class_set: ${experiment.module.class_set}
num_classes: ${experiment.module.num_classes}
3 changes: 3 additions & 0 deletions conf/task_defaults/landcoverai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 3
num_classes: 6
num_filters: 256
datamodule:
batch_size: 32
num_workers: 4
3 changes: 3 additions & 0 deletions conf/task_defaults/naipchesapeake.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ experiment:
encoder_output_stride: 16
learning_rate: 1e-3
learning_rate_schedule_patience: 2
in_channels: 4
num_classes: 13
num_filters: 64
datamodule:
batch_size: 32
num_workers: 4
1 change: 1 addition & 0 deletions conf/task_defaults/sen12ms.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 2
in_channels: 15
num_classes: 11
datamodule:
batch_size: 32
num_workers: 4
2 changes: 1 addition & 1 deletion docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: LandCoverAI
.. autoclass:: LandcoverAIDataModule
.. autoclass:: LandCoverAIDataModule

LEVIR-CD+ (LEVIR Change Detection +)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion experiments/test_chesapeakecvpr_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch

from torchgeo.datasets import ChesapeakeCVPRDataModule
from torchgeo.trainers import ChesapeakeCVPRSegmentationTask
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask

ALL_TEST_SPLITS = [["de-val"], ["pa-test"], ["ny-test"], ["pa-test", "ny-test"]]

Expand Down
14 changes: 7 additions & 7 deletions tests/datasets/test_landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import LandCoverAI, LandcoverAIDataModule
from torchgeo.datasets import LandCoverAI, LandCoverAIDataModule


def download_url(url: str, root: str, *args: str) -> None:
Expand Down Expand Up @@ -69,22 +69,22 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
LandCoverAI(str(tmp_path))


class TestLandcoverAIDataModule:
class TestLandCoverAIDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> LandcoverAIDataModule:
def datamodule(self) -> LandCoverAIDataModule:
root = os.path.join("tests", "data", "landcoverai")
batch_size = 2
num_workers = 0
dm = LandcoverAIDataModule(root, batch_size, num_workers)
dm = LandCoverAIDataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(self, datamodule: LandcoverAIDataModule) -> None:
def test_train_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
next(iter(datamodule.train_dataloader()))

def test_val_dataloader(self, datamodule: LandcoverAIDataModule) -> None:
def test_val_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
next(iter(datamodule.val_dataloader()))

def test_test_dataloader(self, datamodule: LandcoverAIDataModule) -> None:
def test_test_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
next(iter(datamodule.test_dataloader()))
33 changes: 16 additions & 17 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,6 @@
from .test_utils import mocked_log


@pytest.fixture(scope="module")
def datamodule() -> ChesapeakeCVPRDataModule:
dm = ChesapeakeCVPRDataModule(
os.path.join("tests", "data", "chesapeake", "cvpr"),
["de-test"],
["de-test"],
["de-test"],
patch_size=4,
patches_per_tile=2,
batch_size=2,
num_workers=0,
)
dm.prepare_data()
dm.setup()
return dm


class TestBYOL:
def test_custom_augment_fn(self) -> None:
encoder = resnet18()
Expand All @@ -54,6 +37,22 @@ def test_custom_augment_fn(self) -> None:


class TestBYOLTask:
@pytest.fixture(scope="class")
def datamodule(self) -> ChesapeakeCVPRDataModule:
dm = ChesapeakeCVPRDataModule(
os.path.join("tests", "data", "chesapeake", "cvpr"),
["de-test"],
["de-test"],
["de-test"],
patch_size=4,
patches_per_tile=2,
batch_size=2,
num_workers=0,
)
dm.prepare_data()
dm.setup()
return dm

@pytest.fixture(params=["resnet18", "resnet50"])
def config(self, request: SubRequest) -> Dict[str, Any]:
task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "byol.yaml"))
Expand Down
94 changes: 25 additions & 69 deletions tests/trainers/test_chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,48 +10,41 @@
from omegaconf import OmegaConf

from torchgeo.datasets import ChesapeakeCVPRDataModule
from torchgeo.trainers import ChesapeakeCVPRSegmentationTask
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask

from .test_utils import FakeTrainer, mocked_log


@pytest.fixture(scope="module", params=[5, 7])
def class_set(request: SubRequest) -> int:
return cast(int, request.param)


@pytest.fixture(scope="module")
def datamodule(class_set: int) -> ChesapeakeCVPRDataModule:
dm = ChesapeakeCVPRDataModule(
os.path.join("tests", "data", "chesapeake", "cvpr"),
["de-test"],
["de-test"],
["de-test"],
patch_size=32,
patches_per_tile=2,
batch_size=2,
num_workers=0,
class_set=class_set,
)
dm.prepare_data()
dm.setup()
return dm


class TestChesapeakeCVPRSegmentationTask:
@pytest.fixture(
params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "jaccard", "focal"])
)
def config(self, class_set: int, request: SubRequest) -> Dict[str, Any]:
@pytest.fixture(scope="class", params=[5, 7])
def class_set(self, request: SubRequest) -> int:
return cast(int, request.param)

@pytest.fixture(scope="class")
def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule:
dm = ChesapeakeCVPRDataModule(
os.path.join("tests", "data", "chesapeake", "cvpr"),
["de-test"],
["de-test"],
["de-test"],
patch_size=32,
patches_per_tile=2,
batch_size=2,
num_workers=0,
class_set=class_set,
)
dm.prepare_data()
dm.setup()
return dm

@pytest.fixture
def config(self, class_set: int) -> Dict[str, Any]:
task_conf = OmegaConf.load(
os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml")
)
task_args = OmegaConf.to_object(task_conf.experiment.module)
task_args = cast(Dict[str, Any], task_args)
segmentation_model, loss = request.param
task_args["class_set"] = class_set
task_args["segmentation_model"] = segmentation_model
task_args["loss"] = loss
task_args["num_classes"] = class_set
return task_args

@pytest.fixture
Expand All @@ -64,46 +57,9 @@ def task(
monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined]
return task

def test_configure_optimizers(self, task: ChesapeakeCVPRSegmentationTask) -> None:
out = task.configure_optimizers()
assert "optimizer" in out
assert "lr_scheduler" in out

def test_training(
self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask
) -> None:
batch = next(iter(datamodule.train_dataloader()))
task.training_step(batch, 0)
task.training_epoch_end(0)

def test_validation(
self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask
) -> None:
batch = next(iter(datamodule.val_dataloader()))
task.validation_step(batch, 0)
task.validation_epoch_end(0)

def test_test(
self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask
) -> None:
batch = next(iter(datamodule.test_dataloader()))
task.test_step(batch, 0)
task.test_epoch_end(0)

def test_invalid_class_set(self, config: Dict[str, Any]) -> None:
config["class_set"] = 6
error_message = "'class_set' must be either 5 or 7"
with pytest.raises(ValueError, match=error_message):
ChesapeakeCVPRSegmentationTask(**config)

def test_invalid_model(self, config: Dict[str, Any]) -> None:
config["segmentation_model"] = "invalid_model"
error_message = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=error_message):
ChesapeakeCVPRSegmentationTask(**config)

def test_invalid_loss(self, config: Dict[str, Any]) -> None:
config["loss"] = "invalid_loss"
error_message = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=error_message):
ChesapeakeCVPRSegmentationTask(**config)
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset

from torchgeo.datasets import CycloneDataModule
from torchgeo.trainers import (
ClassificationTask,
MultiLabelClassificationTask,
RegressionTask,
)
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask

from .test_utils import mocked_log

Expand Down Expand Up @@ -254,63 +249,3 @@ def test_invalid_loss(self, config: Dict[str, Any]) -> None:
error_message = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=error_message):
MultiLabelClassificationTask(**config)


class TestRegressionTask:
@pytest.fixture(scope="class")
def datamodule(self) -> CycloneDataModule:
root = os.path.join("tests", "data", "cyclone")
seed = 0
batch_size = 1
num_workers = 0
dm = CycloneDataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm

@pytest.fixture
def config(self) -> Dict[str, Any]:
task_conf = OmegaConf.load(
os.path.join("conf", "task_defaults", "cyclone.yaml")
)
task_args = OmegaConf.to_object(task_conf.experiment.module)
task_args = cast(Dict[str, Any], task_args)
return task_args

@pytest.fixture
def task(
self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
) -> RegressionTask:
task = RegressionTask(**config)
monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined]
return task

def test_configure_optimizers(self, task: RegressionTask) -> None:
out = task.configure_optimizers()
assert "optimizer" in out
assert "lr_scheduler" in out

def test_training(
self, datamodule: CycloneDataModule, task: RegressionTask
) -> None:
batch = next(iter(datamodule.train_dataloader()))
task.training_step(batch, 0)
task.training_epoch_end(0)

def test_validation(
self, datamodule: CycloneDataModule, task: RegressionTask
) -> None:
batch = next(iter(datamodule.val_dataloader()))
task.validation_step(batch, 0)
task.validation_epoch_end(0)

def test_test(self, datamodule: CycloneDataModule, task: RegressionTask) -> None:
batch = next(iter(datamodule.test_dataloader()))
task.test_step(batch, 0)
task.test_epoch_end(0)

def test_invalid_model(self, config: Dict[str, Any]) -> None:
config["model"] = "invalid_model"
error_message = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=error_message):
RegressionTask(**config)
Loading